github.com/numpy/numpy

Function bmm_einsum

numpy/_core/einsumfunc.py:1144–1231

Perform arbitrary pairwise einsums using only ``matmul``, or ``multiply`` if no contracted indices are involved (plus maybe single term ``einsum`` to prepare the terms individually). The logic for each is cached based on the equation and array shape, and each step is only performed if necessary.

(eq, a, b, out=None, **kwargs)
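The summary above rests on one observation. Once the contracted axes of both operands sit next to each other and are fused by a reshape, a pairwise contraction becomes a single `matmul`. The following is a minimal sketch using public NumPy calls, not this function's own code. It assumes the contracted axes already appear in the same order in both operands, so no transpose is needed:

```python
import numpy as np

# "ijk,jkl->il": j and k are contracted, i and l are kept
I, J, K, L = 2, 3, 4, 5
a = np.arange(I * J * K, dtype=float).reshape(I, J, K)
b = np.arange(J * K * L, dtype=float).reshape(J, K, L)

# Fuse the contracted axes (j, k) into one axis of length J*K;
# a single 2D matmul then performs the whole contraction.
via_matmul = a.reshape(I, J * K) @ b.reshape(J * K, L)

reference = np.einsum("ijk,jkl->il", a, b)
print(np.allclose(via_matmul, reference))  # True
```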


```python
def bmm_einsum(eq, a, b, out=None, **kwargs):
    """Perform arbitrary pairwise einsums using only ``matmul``, or
    ``multiply`` if no contracted indices are involved (plus maybe single term
    ``einsum`` to prepare the terms individually). The logic for each is cached
    based on the equation and array shape, and each step is only performed if
    necessary.

    Parameters
    ----------
    eq : str
        The einsum equation.
    a : array_like
        The first array to contract.
    b : array_like
        The second array to contract.

    Returns
    -------
    array_like

    Notes
    -----
    A fuller description of this algorithm, and original source for this
    implementation, can be found at https://github.com/jcmgray/einsum_bmm.
    """
    (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        pure_multiplication,
    ) = _parse_eq_to_batch_matmul(eq, a.shape, b.shape)

    # n.b. one could special case various cases to call c_einsum directly here

    # need to handle `order` a little manually, since we do transpose
    # operations before and potentially after the ufunc calls
    output_order = _parse_output_order(
        kwargs.pop("order", "K"), a.flags.f_contiguous, b.flags.f_contiguous
    )

    # prepare left
    if eq_a is not None:
        # diagonals, sums, and transpose
        a = c_einsum(eq_a, a)
    if new_shape_a is not None:
        a = reshape(a, new_shape_a)

    # prepare right
    if eq_b is not None:
        # diagonals, sums, and transpose
        b = c_einsum(eq_b, b)
    if new_shape_b is not None:
        b = reshape(b, new_shape_b)

    if pure_multiplication:
        # ... (source view truncated here; lines 1202–1231 not shown)
```
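The source view stops at the `pure_multiplication` branch. The following simplified sketch follows the same overall strategy but is not NumPy's implementation. The names `plan_pairwise`, `_prepare` and `simple_bmm_einsum` are hypothetical. The sketch works in these steps:

- classify the indices once per equation and shapes;
- sum out and transpose each operand;
- fuse axes with a reshape;
- do one `matmul`, or a broadcast `multiply` when nothing is contracted;
- unfuse and permute the result.

It assumes an explicit `->` output, no `...`, and no index repeated within a single term, and it ignores `out` and `order`:

```python
from functools import lru_cache
from math import prod

import numpy as np


@lru_cache(maxsize=None)
def plan_pairwise(eq, shape_a, shape_b):
    """Hypothetical planner: classify indices once per (eq, shapes)."""
    inputs, out = eq.replace(" ", "").split("->")
    ta, tb = inputs.split(",")
    sizes = dict(zip(ta, shape_a))
    sizes.update(zip(tb, shape_b))
    batch = tuple(c for c in out if c in ta and c in tb)       # kept, in both
    left = tuple(c for c in out if c in ta and c not in tb)    # kept, a only
    right = tuple(c for c in out if c in tb and c not in ta)   # kept, b only
    contracted = tuple(c for c in ta if c in tb and c not in out)
    return ta, tb, out, sizes, batch, left, right, contracted


def _prepare(term, arr, keep):
    """Sum out indices not in `keep`, then transpose into the order of `keep`."""
    summed = tuple(i for i, c in enumerate(term) if c not in keep)
    if summed:
        arr = arr.sum(axis=summed)
    remaining = [c for c in term if c in keep]
    return arr.transpose([remaining.index(c) for c in keep])


def simple_bmm_einsum(eq, a, b):
    """Hypothetical pairwise einsum via transpose/reshape + matmul (or multiply)."""
    a, b = np.asarray(a), np.asarray(b)
    ta, tb, out, sizes, batch, left, right, contracted = plan_pairwise(
        eq, a.shape, b.shape
    )
    nb, nl, nc, nr = (
        prod(sizes[c] for c in group) for group in (batch, left, contracted, right)
    )
    # prepare left: order (batch, left, contracted), fused to 3D
    a3 = _prepare(ta, a, batch + left + contracted).reshape(nb, nl, nc)
    # prepare right: order (batch, contracted, right), fused to 3D
    b3 = _prepare(tb, b, batch + contracted + right).reshape(nb, nc, nr)
    if contracted:
        ab = np.matmul(a3, b3)    # (nb, nl, nc) @ (nb, nc, nr) -> (nb, nl, nr)
    else:
        ab = np.multiply(a3, b3)  # (nb, nl, 1) * (nb, 1, nr) broadcasts
    # unfuse the kept axes, then permute into the requested output order
    current = batch + left + right
    ab = ab.reshape([sizes[c] for c in current])
    return ab.transpose([current.index(c) for c in out])
```

Caching the plan with `functools.lru_cache` keyed on `(eq, shape_a, shape_b)` mirrors the docstring's point that the logic is cached per equation and array shape. The real code also handles diagonals in its single-term `c_einsum` preparation step, which this sketch omits.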

Callers (1)

einsum (function) · 0.85

Calls (6)

reshape (function) · 0.90
_parse_output_order (function) · 0.85
matmul (function) · 0.85
asanyarray (function) · 0.85
multiply (function) · 0.70

Tested by

no test coverage detected
