Source code for pytblis.wrappers

import warnings
from typing import Optional, Union

import numpy as np
import numpy.typing as npt

from ._pytblis_impl import add, mult, shift
from .defaultorder import get_default_array_order
from .typecheck import _accepted_types, _check_strides, _check_tblis_types, _valid_labels, contraction_result_shape

# Type alias for the scalar factors (``alpha``/``beta``) accepted by the wrappers in this module.
scalar = Union[float, complex]


def transpose_add(
    subscripts: str,
    a: npt.ArrayLike,
    alpha: scalar = 1.0,
    beta: scalar = 0.0,
    out: Optional[npt.ArrayLike] = None,
    conja: bool = False,
    conjout: bool = False,
) -> npt.ArrayLike:
    """
    Perform tensor transpose and addition based on the provided subscripts.

    High-level wrapper for tblis_tensor_add. B (stored in `out` if provided) is computed as:
    B = alpha * transpose(a, subscripts) + beta * B.
    Optionally, conjugation can be applied to `a` and `out`.

    Parameters
    ----------
    subscripts : str
        Subscripts of the form ``"<a indices>-><b indices>"`` defining the
        transpose. Every output index must appear among the input indices.
        Spaces are ignored.
    a : array_like
        Tensor operand.
    alpha : scalar, optional
        Scaling factor for `a`.
    beta : scalar, optional
        Scaling factor for the output tensor `b`. Must be 0.0 if `out` is None.
    out : array_like, optional
        Output tensor containing `b`.
    conja : bool, optional
        If True, conjugate the first tensor `a`. Alpha is not conjugated.
    conjout : bool, optional
        If True, conjugate the output tensor `b`. Beta is not conjugated.

    Returns
    -------
    ndarray
        The output tensor `b` (``out`` itself when it was provided).

    Raises
    ------
    TypeError
        If the dtypes are unsupported by TBLIS or do not match.
    ValueError
        If the subscripts are malformed or inconsistent with ``a.shape``, if a
        non-empty operand has non-positive strides, if `out` has the wrong
        shape, or if `beta` is nonzero while `out` is None.

    Examples
    --------
    >>> import numpy as np
    >>> from pytblis import transpose_add
    >>> a = np.random.rand(3, 4, 5)
    >>> b = transpose_add("ijk->ikj", a, alpha=1.0)
    """
    # Without `out` there is no existing B for `beta` to scale, and the buffer
    # allocated below is uninitialized. Checked with a real exception (not an
    # `assert`, which `python -O` strips) and before any allocation.
    if out is None and beta != 0.0:
        raise ValueError("beta must be 0.0 if out is None")

    a = np.asarray(a)
    if out is not None:
        # Convert up front so the dtype/stride checks below always see an ndarray.
        out = np.asarray(out)
    scalar_type = _check_tblis_types(a, out) if out is not None else _check_tblis_types(a)
    input_strides_ok, output_strides_ok = _check_strides(a, out=out)
    if scalar_type is None:
        raise TypeError(
            "TBLIS only supports float32, float64, complex64, and complex128. "
            "Types do not match or unsupported type detected. "
        )
    # It's okay if A has bad strides if its size is 0, because nothing will be read.
    if not input_strides_ok and a.size != 0:
        msg = f"Input tensor of shape {a.shape} has non-positive strides: {a.strides}"
        raise ValueError(msg)
    if not output_strides_ok and out.size != 0:
        msg = f"Output tensor of shape {out.shape} has non-positive strides: {out.strides}"
        raise ValueError(msg)

    subscripts = subscripts.replace(" ", "")
    if subscripts.count("->") != 1:
        msg = f"Invalid subscripts '{subscripts}': expected exactly one '->'"
        raise ValueError(msg)
    a_idx, b_idx = subscripts.split("->")
    if len(a_idx) != a.ndim:
        # zip() below would otherwise silently truncate the mismatch.
        msg = (
            f"Subscripts '{subscripts}' have {len(a_idx)} input indices, "
            f"but the input tensor has {a.ndim} dimensions"
        )
        raise ValueError(msg)
    if not set(a_idx) >= set(b_idx):
        msg = f"Invalid subscripts '{subscripts}'"
        raise ValueError(msg)

    # Map each label to its extent; a repeated label must have one extent.
    a_shape_dic = {}
    for label, extent in zip(a_idx, a.shape):
        if a_shape_dic.setdefault(label, extent) != extent:
            msg = f"Index '{label}' in subscripts '{subscripts}' has inconsistent extents in input shape {a.shape}"
            raise ValueError(msg)
    b_shape = tuple(a_shape_dic[x] for x in b_idx)

    if out is None:
        out = np.empty(b_shape, dtype=scalar_type, order=get_default_array_order())
    elif out.shape != b_shape:
        msg = f"Output shape {out.shape} does not match expected shape {b_shape} for subscripts '{subscripts}'"
        raise ValueError(msg)

    # handle zero-sized input
    # if a is size zero, tensor_add quits early and does not scale B, so the
    # beta scaling is applied via shift with alpha=0.0.
    # NOTE(review): `conjout` is not applied on this path -- confirm whether
    # `out` should also be conjugated here.
    if a.size == 0:
        shift(out, b_idx, alpha=0.0, beta=beta)
    else:
        add(a, out, a_idx, b_idx, alpha=alpha, beta=beta, conja=conja, conjb=conjout)
    return out


def contract_same_type(
    subscripts: str,
    a: npt.ArrayLike,
    b: npt.ArrayLike,
    alpha: scalar = 1.0,
    beta: scalar = 0.0,
    out: Optional[npt.ArrayLike] = None,
    conja: bool = False,
    conjb: bool = False,
    allow_partial_trace: bool = False,
) -> npt.ArrayLike:
    """
    Perform tensor contraction based on the provided subscripts.

    C (stored in `out` if provided) is computed as:
    C = alpha * einsum(subscripts, a, b) + beta * C.

    If TBLIS cannot handle the operands (mismatched or unsupported dtypes, or
    non-positive strides on non-empty operands), a warning is emitted and the
    contraction falls back to ``numpy.einsum``. The fallback honours `conja`,
    `conjb` and `out`, but only supports ``alpha == 1.0`` and ``beta == 0.0``.

    Parameters
    ----------
    subscripts : str
        Subscripts of the form ``"a_indices,b_indices->c_indices"``. Spaces are ignored.
    a : array_like
        First tensor operand.
    b : array_like
        Second tensor operand.
    alpha : scalar, optional
        Scaling factor for the product of `a` and `b`.
    beta : scalar, optional
        Scaling factor for the output tensor. Must be 0.0 if `out` is None.
    out : array_like, optional
        Output tensor to store the result.
    conja : bool, optional
        If True, conjugate the first tensor `a` before contraction. Alpha is not conjugated.
    conjb : bool, optional
        If True, conjugate the second tensor `b` before contraction. Scalars are not conjugated.
    allow_partial_trace : bool, optional
        If True, handle indices that appear in only one of `a`/`b` (and not in
        the output) by doing a partial trace before contraction.

    Returns
    -------
    ndarray
        The result tensor C (``out`` itself when it was provided).

    Raises
    ------
    ValueError
        If the subscripts are malformed, if `beta` is nonzero while `out` is
        None, if `out` has the wrong shape, if a partial trace is required but
        `allow_partial_trace` is False, or if the numpy fallback is needed with
        ``alpha != 1.0`` or ``beta != 0.0``.
    """
    # A freshly allocated output is uninitialized, so beta has nothing to scale.
    # Raised explicitly (an `assert` would vanish under `python -O`).
    if out is None and beta != 0.0:
        raise ValueError("beta must be 0.0 if out is None")

    a = np.asarray(a)
    b = np.asarray(b)
    if out is not None:
        out = np.asarray(out)
    scalar_type = _check_tblis_types(a, b, out=out)
    input_strides_ok, output_strides_ok = _check_strides(a, b, out=out)
    is_trivial = a.size == 0 or b.size == 0

    fallback = False
    if scalar_type is None:
        warnings.warn(
            "TBLIS only supports float32, float64, complex64, and complex128. "
            "Types do not match or unsupported type detected. "
            "Will attempt to fall back to numpy einsum.",
            stacklevel=2,
        )
        fallback = True
    if not output_strides_ok and not is_trivial:
        warnings.warn(
            f"Output tensor of shape {out.shape} has non-positive strides: {out.strides}. "
            "Will attempt to fall back to numpy einsum.",
            stacklevel=2,
        )
        fallback = True
    if not input_strides_ok and not is_trivial:
        # _check_strides reports on both inputs together, so name both.
        warnings.warn(
            "At least one input tensor has non-positive strides "
            f"(a: shape {a.shape}, strides {a.strides}; b: shape {b.shape}, strides {b.strides}). "
            "Will attempt to fall back to numpy einsum.",
            stacklevel=2,
        )
        fallback = True

    if fallback:
        if alpha != 1.0 or beta != 0.0:
            msg = "Cannot fall back to numpy einsum unless alpha = 1.0 and beta = 0.0"
            raise ValueError(msg)
        # einsum has no conjugation flags, so conjugate the operands explicitly.
        result = np.einsum(subscripts, np.conj(a) if conja else a, np.conj(b) if conjb else b)
        if out is None:
            return result
        if out.shape != result.shape:
            msg = f"Output shape {out.shape} does not match expected shape {result.shape} for subscripts '{subscripts}'"
            raise ValueError(msg)
        # Honour the documented contract: the result is stored in `out`.
        out[...] = result
        return out

    subscripts = subscripts.replace(" ", "")
    input_str, arrow, subscript_c = subscripts.partition("->")
    if not arrow or "->" in subscript_c or input_str.count(",") != 1:
        msg = f"Invalid subscripts '{subscripts}': expected the form 'a_indices,b_indices->c_indices'"
        raise ValueError(msg)
    subscript_a, subscript_b = input_str.split(",")

    idx_a = frozenset(subscript_a)
    idx_b = frozenset(subscript_b)
    idx_c = frozenset(subscript_c)
    idx_redundant_a = idx_a - idx_b - idx_c
    idx_redundant_b = idx_b - idx_a - idx_c
    idx_redundant_c = idx_c - idx_a - idx_b

    if idx_redundant_c:
        # Reachable from user input (e.g. "ij,jk->iz"), so report it as such.
        msg = (
            f"Invalid subscripts '{subscripts}': output indices "
            f"{''.join(sorted(idx_redundant_c))} do not appear in either input"
        )
        raise ValueError(msg)
    if not allow_partial_trace and (idx_redundant_a or idx_redundant_b):
        operands = " and ".join(name for name, idx in (("a", idx_redundant_a), ("b", idx_redundant_b)) if idx)
        msg = (
            f"Subscripts '{subscripts}' require partial trace on {operands}. "
            "Pass allow_partial_trace=True to enable."
        )
        raise ValueError(msg)

    c_shape = contraction_result_shape(subscripts, a.shape, b.shape)
    if out is None:
        out = np.empty(c_shape, dtype=scalar_type, order=get_default_array_order())
    elif out.shape != c_shape:
        msg = f"Output shape {out.shape} does not match expected shape {c_shape} for subscripts '{subscripts}'"
        raise ValueError(msg)

    if idx_redundant_a:
        # partial trace on a: sum out indices that appear only in a
        subscript_a_traced = "".join(i for i in subscript_a if i not in idx_redundant_a)
        a = transpose_add(f"{subscript_a}->{subscript_a_traced}", a)
        subscript_a = subscript_a_traced
    if idx_redundant_b:
        # partial trace on b: sum out indices that appear only in b
        subscript_b_traced = "".join(i for i in subscript_b if i not in idx_redundant_b)
        b = transpose_add(f"{subscript_b}->{subscript_b_traced}", b)
        subscript_b = subscript_b_traced

    # handle zero-sized input
    # if A or B is size zero, mult quits early and does not scale C.
    if is_trivial:
        shift(out, subscript_c, alpha=0.0, beta=beta)
    else:
        mult(a, b, out, subscript_a, subscript_b, subscript_c, alpha=alpha, beta=beta, conja=conja, conjb=conjb)
    return out


def contract(
    subscripts: str,
    a: npt.ArrayLike,
    b: npt.ArrayLike,
    alpha: scalar = 1.0,
    beta: scalar = 0.0,
    out: Optional[npt.ArrayLike] = None,
    conja: bool = False,
    conjb: bool = False,
    allow_partial_trace: bool = False,
    complex_real_contractions: bool = True,
) -> npt.ArrayLike:
    """
    Perform tensor contraction, including contractions between a complex and a real tensor.

    C (stored in `out` if provided) is computed as:
    C = alpha * einsum(subscripts, a, b) + beta * C.

    When both operands are real, or both are complex, this delegates to
    `contract_same_type`. If exactly one operand is complex (and
    `complex_real_contractions` is True), the real and imaginary parts of that
    operand are contracted separately with the real operand. The results are
    accumulated into the real and imaginary parts of the complex output.

    Parameters
    ----------
    subscripts : str
        Subscripts defining the contraction.
    a : array_like
        First tensor operand.
    b : array_like
        Second tensor operand.
    alpha : scalar, optional
        Scaling factor for the product of `a` and `b`.
    beta : scalar, optional
        Scaling factor for the output tensor. Must be 0.0 if `out` is None.
    out : array_like, optional
        Output tensor to store the result. Must be complex when contracting a
        complex tensor with a real tensor.
    conja : bool, optional
        If True, conjugate the first tensor `a` before contraction. Alpha is not conjugated.
    conjb : bool, optional
        If True, conjugate the second tensor `b` before contraction. Scalars are not conjugated.
    allow_partial_trace : bool, optional
        If True, handle redundant indices in subscripts for `a` and `b` by doing partial trace before contraction.
    complex_real_contractions : bool, default True
        If True, handle contractions between complex and real tensors by
        performing separate contractions for the real and imaginary parts of the
        complex tensor; alpha and beta must be real in this case. If False,
        delegate directly to `contract_same_type`.

    Returns
    -------
    ndarray
        The result tensor C (``out`` itself when it was provided or allocated here).

    Raises
    ------
    TypeError
        If `complex_real_contractions` is True and the operands (and `out`)
        do not share the same underlying floating point type.
    ValueError
        If `beta` is nonzero while `out` is None. In the complex/real case, also
        raised if `out` is not complex, has the wrong shape, or if `alpha` or
        `beta` has a nonzero imaginary part.

    Examples
    --------
    >>> import numpy as np
    >>> from pytblis import contract
    >>> A = np.random.rand(3, 4, 5)
    >>> B = np.random.rand(4, 5, 6).astype(np.complex128)
    >>> result = contract('ijk,jkl->il', A, B, alpha=2.0)
    """
    # An allocated output is uninitialized, so beta has nothing to scale. Raised
    # explicitly because `assert` is stripped under `python -O`.
    if out is None and beta != 0.0:
        raise ValueError("beta must be 0.0 if out is None")

    same_type_kwargs = {
        "alpha": alpha,
        "beta": beta,
        "conja": conja,
        "conjb": conjb,
        "allow_partial_trace": allow_partial_trace,
    }
    if not complex_real_contractions:
        return contract_same_type(subscripts, a, b, out=out, **same_type_kwargs)

    a = np.asarray(a)
    b = np.asarray(b)
    if out is not None:
        out = np.asarray(out)
        real_scalar_type = _check_tblis_types(a.real, b.real, out=out.real)
    else:
        real_scalar_type = _check_tblis_types(a.real, b.real)
    if real_scalar_type is None:
        raise TypeError("Input arrays and output array (if provided) must use the same IEEE floating point type.")

    a_is_complex = np.iscomplexobj(a)
    if a_is_complex == np.iscomplexobj(b):
        # Both complex or both real: TBLIS handles this directly.
        return contract_same_type(subscripts, a, b, out=out, **same_type_kwargs)

    # exactly one of a or b is complex, so the result must be complex
    if out is not None and not np.iscomplexobj(out):
        msg = "Output array must have complex dtype when contracting a complex tensor with a real tensor."
        raise ValueError(msg)
    if np.imag(alpha) != 0.0:
        raise ValueError("alpha must be real when contracting a complex tensor with a real tensor.")
    if np.imag(beta) != 0.0:
        raise ValueError("beta must be real when contracting a complex tensor with a real tensor.")

    result_type = np.result_type(real_scalar_type, 1j).type
    c_shape = contraction_result_shape(subscripts, a.shape, b.shape)
    if out is None:
        out = np.empty(c_shape, dtype=result_type, order=get_default_array_order())
    elif out.shape != c_shape:
        msg = f"Output shape {out.shape} does not match expected shape {c_shape} for subscripts '{subscripts}'"
        raise ValueError(msg)

    # Write the complex operand as x + i*y. Since the other operand is real and
    # alpha/beta are real, the contraction splits into two real contractions:
    #   Re(C) = alpha * (x . w) + beta * Re(C)
    #   Im(C) = alpha * s * (y . w) + beta * Im(C)
    # where s = -1 if the complex operand is conjugated (conj(x + iy) = x - iy), else +1.
    if a_is_complex:
        cplx, imag_fac = a, (-1.0 if conja else 1.0)
    else:
        cplx, imag_fac = b, (-1.0 if conjb else 1.0)

    def operands(part):
        # Substitute `part` (a real view of the complex operand) for that operand.
        return (part, b) if a_is_complex else (a, part)

    options = {"conja": conja, "conjb": conjb, "allow_partial_trace": allow_partial_trace}

    if c_shape:
        # out.real / out.imag are writable views into the complex output.
        contract(subscripts, *operands(cplx.real), alpha=alpha, beta=beta, out=out.real, **options)
        contract(subscripts, *operands(cplx.imag), alpha=alpha * imag_fac, beta=beta, out=out.imag, **options)
        return out

    # Scalar (0-d) output: compute both parts, combine them here, and write the
    # value into `out` so a caller-supplied buffer is actually updated.
    re_part = contract(subscripts, *operands(cplx.real), **options)
    im_part = contract(subscripts, *operands(cplx.imag), **options)
    value = alpha * re_part + 1j * alpha * imag_fac * im_part
    if beta != 0.0:
        # Only read `out` when it contributes; a freshly allocated buffer is uninitialized.
        value = value + beta * out
    out[...] = value
    return out


def ascontiguousarray(a):
    """Parallel transpose the input to C-contiguous layout.

    Arrays that are already C-contiguous are returned as-is, without a copy.
    Otherwise the data is copied into a new C-ordered array with TBLIS. The
    function falls back to ``numpy.ascontiguousarray``, with a warning, when
    TBLIS cannot handle the input. That happens for non-positive strides, an
    unsupported dtype, or more dimensions than available index labels.

    Parameters
    ----------
    a : array_like
        Input array.

    Returns
    -------
    ndarray
        Contiguous array.
    """
    a = np.asarray(a)
    if a.flags.c_contiguous:
        return a
    # _check_strides returns an (inputs_ok, output_ok) pair. The pair itself is
    # always truthy, so only its first entry may be tested here.
    input_strides_ok, _ = _check_strides(a)
    if not input_strides_ok:
        warnings.warn(
            f"Input tensor of shape {a.shape} has non-positive strides: {a.strides}. Falling back to numpy ascontiguousarray.",
            stacklevel=2,
        )
        return np.ascontiguousarray(a)
    if a.dtype.type not in _accepted_types:
        warnings.warn(
            "TBLIS only supports float32, float64, complex64, and complex128. Falling back to numpy ascontiguousarray.",
            stacklevel=2,
        )
        return np.ascontiguousarray(a)
    # One distinct label is needed per dimension; exactly len(_valid_labels) is still enough.
    if a.ndim > len(_valid_labels):
        warnings.warn(
            f"a.ndim is {a.ndim}, but only {len(_valid_labels)} labels are valid. "
            "Falling back to numpy ascontiguousarray.",
            stacklevel=2,
        )
        return np.ascontiguousarray(a)
    out = np.empty(a.shape, dtype=a.dtype, order="C")
    a_inds = "".join(_valid_labels[: a.ndim])
    # Identity "transpose" (same labels on both sides) copies a into out's layout.
    add(a, out, a_inds, a_inds, alpha=1.0, beta=0.0)
    return out


def asfortranarray(a):
    """Parallel transpose the input to F-contiguous layout.

    Arrays that are already F-contiguous are returned as-is, without a copy.
    Otherwise the data is copied into a new Fortran-ordered array with TBLIS.
    The function falls back to ``numpy.asfortranarray``, with a warning, when
    TBLIS cannot handle the input. That happens for non-positive strides, an
    unsupported dtype, or more dimensions than available index labels.

    Parameters
    ----------
    a : array_like
        Input array.

    Returns
    -------
    ndarray
        Fortran-order array.
    """
    a = np.asarray(a)
    if a.flags.f_contiguous:
        return a
    # _check_strides returns an (inputs_ok, output_ok) pair. The pair itself is
    # always truthy, so only its first entry may be tested here.
    input_strides_ok, _ = _check_strides(a)
    if not input_strides_ok:
        warnings.warn(
            f"Input tensor of shape {a.shape} has non-positive strides: {a.strides}. Falling back to numpy asfortranarray.",
            stacklevel=2,
        )
        return np.asfortranarray(a)
    if a.dtype.type not in _accepted_types:
        warnings.warn(
            "TBLIS only supports float32, float64, complex64, and complex128. Falling back to numpy asfortranarray.",
            stacklevel=2,
        )
        return np.asfortranarray(a)
    # One distinct label is needed per dimension; exactly len(_valid_labels) is still enough.
    if a.ndim > len(_valid_labels):
        warnings.warn(
            f"a.ndim is {a.ndim}, but only {len(_valid_labels)} labels are valid. "
            "Falling back to numpy asfortranarray.",
            stacklevel=2,
        )
        return np.asfortranarray(a)
    out = np.empty(a.shape, dtype=a.dtype, order="F")
    a_inds = "".join(_valid_labels[: a.ndim])
    # Identity "transpose" (same labels on both sides) copies a into out's layout.
    add(a, out, a_inds, a_inds, alpha=1.0, beta=0.0)
    return out