import warnings
from typing import Optional, Union
import numpy as np
import numpy.typing as npt
from ._pytblis_impl import add, mult, shift
from .typecheck import _accepted_types, _check_strides, _check_tblis_types, _valid_labels
scalar = Union[float, complex]
[docs]
def transpose_add(
subscripts: str,
a: npt.ArrayLike,
alpha: scalar = 1.0,
beta: scalar = 0.0,
out: Optional[npt.ArrayLike] = None,
conja: bool = False,
conjout: bool = False,
) -> npt.ArrayLike:
"""
Perform tensor transpose and addition based on the provided subscripts.
High-level wrapper for tblis_tensor_add.
B (stored in `out` if provided) is computed as:
B = alpha * transpose(a, subscripts) + beta * B.
Optionally, conjugation can be applied to `a` and `out`.
Parameters
----------
subscripts : str
Subscripts defining the contraction.
a : array_like
Tensor operand
alpha : scalar, optional
Scaling factor for `a`.
beta : scalar, optional
Scaling factor for the output tensor `b`. Must be 0.0 if `out` is None.
conja: bool, optional
If True, conjugate the first tensor `a`. Alpha is not conjugated.
conjout: bool, optional
If True, conjugate the output tensor `b`. Beta is not conjugated.
out : array_like, optional
Output tensor containing `b`.
Returns
-------
ndarray
Result of the tensor contraction.
Examples
--------
>>> import numpy as np
>>> from pytblis import transpose_add
>>> a = np.random.rand(3, 4, 5)
>>> b = transpose_add("ijk->ikj", a, alpha=1.0)
"""
a = np.asarray(a)
scalar_type = _check_tblis_types(a, out) if out is not None else _check_tblis_types(a)
input_strides_ok, output_strides_ok = _check_strides(a, out=out)
if scalar_type is None:
raise TypeError(
"TBLIS only supports float32, float64, complex64, and complex128. "
"Types do not match or unsupported type detected. "
)
# It's okay if A has bad strides if its size is 0, because nothing will be read.
if not input_strides_ok and a.size != 0:
msg = f"Input tensor of shape {a.shape} has non-positive strides: {a.strides}"
raise ValueError(msg)
if not output_strides_ok and out.size != 0:
msg = f"Output tensor of shape {out.shape} has non-positive strides: {out.strides}"
raise ValueError(msg)
subscripts = subscripts.replace(" ", "")
a_idx, b_idx = subscripts.split("->")
if not set(a_idx) >= set(b_idx):
msg = f"Invalid subscripts '{subscripts}'"
raise ValueError(msg)
a_shape_dic = dict(zip(a_idx, a.shape))
b_shape = tuple(a_shape_dic[x] for x in b_idx)
if out is None:
out = np.empty(b_shape, dtype=scalar_type)
assert beta == 0.0, "beta must be 0.0 if out is None"
else:
out = np.asarray(out)
if out.shape != b_shape:
msg = f"Output shape {out.shape} does not match expected shape {b_shape} for subscripts '{subscripts}'"
raise ValueError(msg)
# handle zero-sized input
# if a is size zero, tensor_add quits early and does not scale B.
if a.size == 0:
shift(out, b_idx, alpha=0.0, beta=beta)
else:
add(a, out, a_idx, b_idx, alpha=alpha, beta=beta, conja=conja, conjb=conjout)
return out
[docs]
def contract(
subscripts: str,
a: npt.ArrayLike,
b: npt.ArrayLike,
alpha: scalar = 1.0,
beta: scalar = 0.0,
out: Optional[npt.ArrayLike] = None,
conja: bool = False,
conjb: bool = False,
allow_partial_trace: bool = False,
) -> npt.ArrayLike:
"""
Perform tensor contraction based on the provided subscripts.
C (stored in `out` if provided) is computed as:
C = alpha * einsum(subscripts, a, b) + beta * C if `out` is provided.
Parameters
----------
subscripts : str
Subscripts defining the contraction.
a : array_like
First tensor operand.
b : array_like
Second tensor operand.
alpha : scalar, optional
Scaling factor for the product of `a` and `b`.
beta : scalar, optional
Scaling factor for the output tensor. Must be 0.0 if `out` is None.
conja: bool, optional
If True, conjugate the first tensor `a` before contraction. Alpha is not conjugated.
conjb: bool, optional
If True, conjugate the second tensor `b` before contraction. Beta is not conjugated.
allow_partial_trace : bool, optional
If True, handle redundant indices in subscripts for `a` and `b` by doing partial trace before contraction.
out : array_like, optional
Output tensor to store the result.
Returns
-------
ndarray
Result of the tensor contraction.
"""
a = np.asarray(a)
b = np.asarray(b)
scalar_type = _check_tblis_types(a, b, out=out)
input_strides_ok, output_strides_ok = _check_strides(a, b, out=out)
is_trivial = a.size == 0 or b.size == 0
fallback = False
if scalar_type is None:
warnings.warn(
"TBLIS only supports float32, float64, complex64, and complex128. "
"Types do not match or unsupported type detected. "
"Will attempt to fall back to numpy tensordot.",
stacklevel=2,
)
fallback = True
if not output_strides_ok and not is_trivial:
warnings.warn(
f"Output tensor of shape {out.shape} has non-positive strides: {out.strides}. "
"Will attempt to fall back to numpy tensordot.",
stacklevel=2,
)
fallback = True
if not input_strides_ok and not is_trivial:
warnings.warn(
f"Input tensor of shape {a.shape} has non-positive strides: {a.strides}. "
"Will attempt to fall back to numpy tensordot.",
stacklevel=2,
)
fallback = True
if fallback:
if alpha != 1.0 or beta != 0.0:
msg = "Cannot fall back to numpy tensordot unless alpha = 1.0 and beta = 0.0"
raise ValueError(msg)
return np.einsum(subscripts, a, b)
subscripts = subscripts.replace(" ", "")
input_str, subscript_c = subscripts.split("->")
subscript_a, subscript_b = input_str.split(",")
idx_a = frozenset(subscript_a)
idx_b = frozenset(subscript_b)
idx_c = frozenset(subscript_c)
idx_redundant_a = idx_a - idx_b - idx_c
idx_redundant_b = idx_b - idx_a - idx_c
idx_redundant_c = idx_c - idx_a - idx_b
if idx_redundant_c:
raise RuntimeError("Should never have redundant indices in the output. This is probably a bug.")
if not allow_partial_trace and (idx_redundant_a or idx_redundant_b):
msg = (
f"Subscripts '{subscripts}' require partial trace on "
f"{'a ' if idx_redundant_a else ''}{'b' if idx_redundant_b else ''}. "
"Pass allow_partial_trace=True to enable."
)
raise ValueError(msg)
if idx_redundant_a:
# partial trace on a
subscript_a_traced = "".join([i for i in subscript_a if i not in idx_redundant_a])
einsum_str_traced = f"{subscript_a}->{subscript_a_traced}"
a = transpose_add(einsum_str_traced, a)
subscript_a = subscript_a_traced
if idx_redundant_b:
# partial trace on b
subscript_b_traced = "".join([i for i in subscript_b if i not in idx_redundant_b])
einsum_str_traced = f"{subscript_b}->{subscript_b_traced}"
b = transpose_add(einsum_str_traced, b)
subscript_b = subscript_b_traced
if not (set(subscript_a) | set(subscript_b)) >= set(subscript_c):
msg = f"Invalid subscripts '{subscripts}'"
raise ValueError(msg)
a_shape_dic = dict(zip(subscript_a, a.shape))
b_shape_dic = dict(zip(subscript_b, b.shape))
if any(a_shape_dic[x] != b_shape_dic[x] for x in set(subscript_a) & set(subscript_b)):
msg = f"Shape mismatch for subscripts '{subscripts}': {a.shape} {b.shape}"
raise ValueError(msg)
ab_shape_dic = {**a_shape_dic, **b_shape_dic}
c_shape = tuple(ab_shape_dic[x] for x in subscript_c)
if out is None:
out = np.empty(c_shape, dtype=scalar_type)
assert beta == 0.0, "beta must be 0.0 if out is None"
else:
out = np.asarray(out)
if out.shape != c_shape:
msg = f"Output shape {out.shape} does not match expected shape {c_shape} for subscripts '{subscripts}'"
raise ValueError(msg)
# handle zero-sized input
# if A or B is size zero, mult quits early and does not scale C.
if is_trivial:
shift(out, subscript_c, alpha=0.0, beta=beta)
else:
mult(a, b, out, subscript_a, subscript_b, subscript_c, alpha=alpha, beta=beta, conja=conja, conjb=conjb)
return out
[docs]
def ascontiguousarray(a):
"""Parallel transpose the input to C-contiguous layout.
Parameters
----------
a : array_like
Input array.
Returns
-------
ndarray
Contiguous array.
"""
a = np.asarray(a)
if a.flags.c_contiguous:
return a
if not _check_strides(a):
warnings.warn(
f"Input tensor of shape {a.shape} has non-positive strides: {a.strides}. Falling back to numpy ascontiguousarray.",
stacklevel=2,
)
return np.ascontiguousarray(a)
if a.dtype.type not in _accepted_types:
warnings.warn(
"TBLIS only supports float32, float64, complex64, and complex128. Falling back to numpy ascontiguousarray.",
stacklevel=2,
)
return np.ascontiguousarray(a)
out = np.empty(a.shape, dtype=a.dtype, order="C")
assert len(a.shape) < len(_valid_labels), (
f"a.ndim is {len(a.shape)}, but only {len(_valid_labels)} labels are valid."
)
a_inds = _valid_labels[: len(a.shape)]
a_inds = "".join(a_inds)
add(a, out, a_inds, a_inds, alpha=1.0, beta=0.0)
return out