import warnings
from typing import Optional, Union
import numpy as np
import numpy.typing as npt
from ._pytblis_impl import add, mult
from .typecheck import _accepted_types, _check_strides, _check_tblis_types, _valid_labels
# Type alias used to annotate the `alpha` / `beta` scaling factors of the
# wrappers in this module: either a real or a complex Python number.
scalar = Union[float, complex]
[docs]
def transpose_add(
    subscripts: str,
    a: npt.ArrayLike,
    alpha: scalar = 1.0,
    beta: scalar = 0.0,
    out: Optional[npt.ArrayLike] = None,
    conja: bool = False,
    conjout: bool = False,
) -> npt.ArrayLike:
    """
    Perform tensor transpose and addition based on the provided subscripts.

    High-level wrapper for tblis_tensor_add.
    B (stored in `out` if provided) is computed as:

        B = alpha * transpose(a, subscripts) + beta * B.

    Optionally, conjugation can be applied to `a` and `out`.

    Parameters
    ----------
    subscripts : str
        Subscripts defining the transpose, in the form ``"<a labels>-><b labels>"``
        (e.g. ``"ijk->ikj"``). Spaces are ignored. Every output label must
        also appear among the input labels.
    a : array_like
        Tensor operand.
    alpha : scalar, optional
        Scaling factor for `a`.
    beta : scalar, optional
        Scaling factor for the output tensor `b`. Must be 0.0 if `out` is None.
    out : array_like, optional
        Output tensor containing `b`.
    conja : bool, optional
        If True, conjugate the first tensor `a`. Alpha is not conjugated.
    conjout : bool, optional
        If True, conjugate the output tensor `b`. Beta is not conjugated.

    Returns
    -------
    ndarray
        Result of the transpose/add operation (`out` when it was provided).

    Raises
    ------
    TypeError
        If the operand dtypes are rejected by the TBLIS type check.
    ValueError
        If an operand fails the stride check, the subscripts are malformed or
        do not match the dimensionality of `a`, the shape of `out` does not
        match the shape implied by the subscripts, or `beta` is non-zero
        while `out` is None.

    Examples
    --------
    >>> import numpy as np
    >>> from pytblis import transpose_add
    >>> a = np.random.rand(3, 4, 5)
    >>> b = transpose_add("ijk->ikj", a, alpha=1.0)
    """
    a = np.asarray(a)
    operands = (a,) if out is None else (a, out)
    scalar_type = _check_tblis_types(*operands)
    strides_ok = _check_strides(*operands)
    if scalar_type is None:
        raise TypeError(
            "TBLIS only supports float32, float64, complex64, and complex128. "
            "Types do not match or unsupported type detected. "
        )
    if not strides_ok:
        msg = "Input tensor has non-positive strides."
        raise ValueError(msg)

    subscripts = subscripts.replace(" ", "")
    a_idx, arrow, b_idx = subscripts.partition("->")
    # Exactly one "->" is required, and every output label must come from `a`.
    if not arrow or "->" in b_idx or not set(a_idx) >= set(b_idx):
        msg = f"Invalid subscripts '{subscripts}'"
        raise ValueError(msg)
    # zip() below would silently truncate on a label/dimension count mismatch.
    if len(a_idx) != a.ndim:
        msg = f"Subscripts '{a_idx}' do not match the {a.ndim} dimension(s) of the input tensor"
        raise ValueError(msg)

    a_shape_dic = dict(zip(a_idx, a.shape))
    b_shape = tuple(a_shape_dic[x] for x in b_idx)

    if out is None:
        # A freshly allocated output is uninitialised, so it must not be
        # accumulated into. Raise explicitly: an assert vanishes under -O.
        if beta != 0.0:
            msg = "beta must be 0.0 if out is None"
            raise ValueError(msg)
        out = np.empty(b_shape, dtype=scalar_type)
    else:
        out = np.asarray(out)
        if out.shape != b_shape:
            msg = f"Output shape {out.shape} does not match expected shape {b_shape} for subscripts '{subscripts}'"
            raise ValueError(msg)

    add(a, out, a_idx, b_idx, alpha=alpha, beta=beta, conja=conja, conjb=conjout)
    return out
[docs]
def contract(
    subscripts: str,
    a: npt.ArrayLike,
    b: npt.ArrayLike,
    alpha: scalar = 1.0,
    beta: scalar = 0.0,
    out: Optional[npt.ArrayLike] = None,
    conja: bool = False,
    conjb: bool = False,
) -> npt.ArrayLike:
    """
    Perform tensor contraction based on the provided subscripts.

    C (stored in `out` if provided) is computed as:

        C = alpha * einsum(subscripts, a, b) + beta * C

    If the operands fail the TBLIS dtype or stride checks, a warning is
    issued and the contraction is delegated to ``numpy.einsum``. That
    fallback is only available when ``alpha == 1.0`` and ``beta == 0.0``;
    it still honours `conja`, `conjb` and `out`.

    Parameters
    ----------
    subscripts : str
        Subscripts defining the contraction, in the form
        ``"<a labels>,<b labels>-><c labels>"``. Spaces are ignored.
    a : array_like
        First tensor operand.
    b : array_like
        Second tensor operand.
    alpha : scalar, optional
        Scaling factor for the product of `a` and `b`.
    beta : scalar, optional
        Scaling factor for the output tensor. Must be 0.0 if `out` is None.
    out : array_like, optional
        Output tensor to store the result.
    conja : bool, optional
        If True, conjugate the first tensor `a` before contraction. Alpha is not conjugated.
    conjb : bool, optional
        If True, conjugate the second tensor `b` before contraction. Beta is not conjugated.

    Returns
    -------
    ndarray
        Result of the tensor contraction (`out` when it was provided).

    Raises
    ------
    ValueError
        If the NumPy fallback is required but `alpha`/`beta` are not 1.0/0.0,
        the subscripts are malformed or inconsistent with the operand shapes,
        the shape of `out` does not match the expected result shape, or
        `beta` is non-zero while `out` is None.
    TypeError
        If, on the NumPy fallback path, the result cannot be cast into `out`
        under NumPy's "same_kind" casting rule.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    operands = (a, b) if out is None else (a, b, out)
    scalar_type = _check_tblis_types(*operands)
    strides_ok = _check_strides(*operands)

    if scalar_type is None or not strides_ok:
        if scalar_type is None:
            warnings.warn(
                "TBLIS only supports float32, float64, complex64, and complex128. "
                "Types do not match or unsupported type detected. "
                "Will attempt to fall back to numpy tensordot.",
                stacklevel=2,
            )
        if not strides_ok:
            warnings.warn(
                "Input tensor has non-positive strides. Will attempt to fall back to numpy tensordot.", stacklevel=2
            )
        if alpha != 1.0 or beta != 0.0:
            msg = "Cannot fall back to numpy tensordot unless alpha = 1.0 and beta = 0.0"
            raise ValueError(msg)
        # Honour the conjugation flags and `out` on the fallback path too, so
        # the result does not depend on which backend ends up being used.
        lhs = a.conj() if conja else a
        rhs = b.conj() if conjb else b
        result = np.einsum(subscripts, lhs, rhs)
        if out is None:
            return result
        out = np.asarray(out)
        if out.shape != result.shape:
            msg = f"Output shape {out.shape} does not match expected shape {result.shape} for subscripts '{subscripts}'"
            raise ValueError(msg)
        np.copyto(out, result, casting="same_kind")
        return out

    subscripts = subscripts.replace(" ", "")
    input_str, arrow, c_idx = subscripts.partition("->")
    input_labels = input_str.split(",")
    # Exactly one "->" and exactly two comma-separated operand label groups.
    if not arrow or "->" in c_idx or len(input_labels) != 2:
        msg = f"Invalid subscripts '{subscripts}'"
        raise ValueError(msg)
    a_idx, b_idx = input_labels
    if not (set(a_idx) | set(b_idx)) >= set(c_idx):
        msg = f"Invalid subscripts '{subscripts}'"
        raise ValueError(msg)
    # zip() below would silently truncate on a label/dimension count mismatch.
    if len(a_idx) != a.ndim or len(b_idx) != b.ndim:
        msg = f"Subscripts '{subscripts}' do not match operand dimensions: {a.ndim} and {b.ndim}"
        raise ValueError(msg)

    a_shape_dic = dict(zip(a_idx, a.shape))
    b_shape_dic = dict(zip(b_idx, b.shape))
    # Labels shared by both operands must refer to axes of equal length.
    if any(a_shape_dic[x] != b_shape_dic[x] for x in set(a_idx) & set(b_idx)):
        msg = f"Shape mismatch for subscripts '{subscripts}': {a.shape} {b.shape}"
        raise ValueError(msg)
    ab_shape_dic = {**a_shape_dic, **b_shape_dic}
    c_shape = tuple(ab_shape_dic[x] for x in c_idx)

    if out is None:
        # A freshly allocated output is uninitialised, so it must not be
        # accumulated into. Raise explicitly: an assert vanishes under -O.
        if beta != 0.0:
            msg = "beta must be 0.0 if out is None"
            raise ValueError(msg)
        out = np.empty(c_shape, dtype=scalar_type)
    else:
        out = np.asarray(out)
        if out.shape != c_shape:
            msg = f"Output shape {out.shape} does not match expected shape {c_shape} for subscripts '{subscripts}'"
            raise ValueError(msg)

    mult(a, b, out, a_idx, b_idx, c_idx, alpha=alpha, beta=beta, conja=conja, conjb=conjb)
    return out
[docs]
def ascontiguousarray(a):
    """Parallel transpose the input to C-contiguous layout.

    The input is returned unchanged (no copy) when it is already
    C-contiguous. If it fails the stride check or has a dtype outside the
    accepted TBLIS types, a warning is issued and
    ``numpy.ascontiguousarray`` is used instead.

    Parameters
    ----------
    a : array_like
        Input array.

    Returns
    -------
    ndarray
        Contiguous array.

    Raises
    ------
    ValueError
        If `a` has more dimensions than there are valid index labels.
    """
    a = np.asarray(a)
    if not _check_strides(a):
        warnings.warn("Input tensor has non-positive strides. Falling back to numpy ascontiguousarray.", stacklevel=2)
        return np.ascontiguousarray(a)
    if a.flags.c_contiguous:
        return a
    if a.dtype.type not in _accepted_types:
        warnings.warn(
            "TBLIS only supports float32, float64, complex64, and complex128. Falling back to numpy ascontiguousarray.",
            stacklevel=2,
        )
        return np.ascontiguousarray(a)
    # One label per axis is needed; having exactly as many axes as labels is
    # still fine. Raise explicitly: an assert would be stripped under -O.
    if a.ndim > len(_valid_labels):
        msg = f"a.ndim is {a.ndim}, but only {len(_valid_labels)} labels are valid."
        raise ValueError(msg)
    out = np.empty(a.shape, dtype=a.dtype, order="C")
    # Identical labels on both sides: a plain copy into the C-ordered buffer.
    a_inds = "".join(_valid_labels[: a.ndim])
    add(a, out, a_inds, a_inds, alpha=1.0, beta=0.0)
    return out