Perform arbitrary pairwise einsums using only ``matmul``, or ``multiply`` if no contracted indices are involved (plus possibly single-term ``einsum`` calls to prepare each term individually). The logic for each contraction is cached based on the equation and the array shapes, and each step is only performed if necessary.

`bmm_einsum(eq, a, b, out=None, **kwargs)`
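The summary rests on the fact that a pairwise contraction becomes a plain matrix product once its axes are grouped appropriately. The quick NumPy check below illustrates two such cases. The arrays and equations are illustrative choices, not taken from the source.

```python
import numpy as np

rng = np.random.default_rng(0)

# 'b' appears in both inputs and in the output: it is a batch axis that
# matmul broadcasts over. 'j' is contracted.
a = rng.random((4, 2, 3))
b = rng.random((4, 3, 5))
assert np.allclose(np.einsum("bij,bjk->bik", a, b), a @ b)

# Two contracted indices (j, k) fuse into a single axis of size 3 * 4 when
# both operands store them in the same order, so one 2D matmul suffices.
x = rng.random((2, 3, 4))
y = rng.random((3, 4, 6))
fused = x.reshape(2, 3 * 4) @ y.reshape(3 * 4, 6)
assert np.allclose(fused, np.einsum("ijk,jkl->il", x, y))
```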
```python
def bmm_einsum(eq, a, b, out=None, **kwargs):
    """Perform arbitrary pairwise einsums using only ``matmul``, or
    ``multiply`` if no contracted indices are involved (plus possibly
    single-term ``einsum`` calls to prepare each term individually). The
    logic for each contraction is cached based on the equation and the array
    shapes, and each step is only performed if necessary.

    Parameters
    ----------
    eq : str
        The einsum equation.
    a : array_like
        The first array to contract.
    b : array_like
        The second array to contract.

    Returns
    -------
    array_like
        The result of the contraction.

    Notes
    -----
    A fuller description of this algorithm, and the original source for this
    implementation, can be found at https://github.com/jcmgray/einsum_bmm.
    """
    (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        pure_multiplication,
    ) = _parse_eq_to_batch_matmul(eq, a.shape, b.shape)

    # n.b. one could special-case various cases to call c_einsum directly here

    # need to handle `order` a little manually, since we do transpose
    # operations before and potentially after the ufunc calls
    output_order = _parse_output_order(
        kwargs.pop("order", "K"), a.flags.f_contiguous, b.flags.f_contiguous
    )

    # prepare left
    if eq_a is not None:
        # diagonals, sums, and transpose
        a = c_einsum(eq_a, a)
    if new_shape_a is not None:
        a = reshape(a, new_shape_a)

    # prepare right
    if eq_b is not None:
        # diagonals, sums, and transpose
        b = c_einsum(eq_b, b)
    if new_shape_b is not None:
        b = reshape(b, new_shape_b)

    if pure_multiplication:
        # (the listing continues beyond this excerpt)
```
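The preparation and matmul steps can be illustrated with a self-contained NumPy sketch. `bmm_einsum_sketch` is a hypothetical helper written for this illustration. It is a simplified stand-in, not the code behind `_parse_eq_to_batch_matmul`, and it assumes an explicit `->` output, no ellipsis, and no index repeated within a single term (so there are no diagonals to extract). It also does no caching. Indices that appear in only one operand and not in the output are summed away first with a single-term `np.einsum`, roughly the role `eq_a`/`eq_b` play above. The remaining axes are grouped into batch, kept-left, contracted, and kept-right blocks, fused by reshaping, multiplied with one batched `matmul`, and finally unfused into the output order.

```python
import math

import numpy as np


def bmm_einsum_sketch(eq, a, b):
    """Hypothetical helper: evaluate a pairwise einsum with one batched matmul.

    Assumes an explicit '->' output, no '...', and no index repeated
    within a single term (i.e. no diagonals to extract).
    """
    a, b = np.asarray(a), np.asarray(b)
    lhs, out = eq.replace(" ", "").split("->")
    ia, ib = lhs.split(",")

    # Indices found in only one operand and not in the output can be
    # summed away up front with a single-term einsum.
    ka = "".join(i for i in ia if i in ib or i in out)
    kb = "".join(i for i in ib if i in ia or i in out)
    if ka != ia:
        a = np.einsum(f"{ia}->{ka}", a)
    if kb != ib:
        b = np.einsum(f"{ib}->{kb}", b)
    ia, ib = ka, kb
    sizes = {**dict(zip(ia, a.shape)), **dict(zip(ib, b.shape))}

    # Classify every remaining index.
    batch = [i for i in out if i in ia and i in ib]      # kept, in both
    left = [i for i in out if i in ia and i not in ib]   # kept, only in a
    right = [i for i in out if i in ib and i not in ia]  # kept, only in b
    con = [i for i in ia if i in ib and i not in out]    # contracted

    def fused(ix):
        # Size of a group of axes once flattened into one (1 if empty).
        return math.prod(sizes[i] for i in ix)

    # Transpose to (batch, left, con) and (batch, con, right), then fuse
    # each group into a single axis so that one matmul does the contraction.
    a = a.transpose([ia.index(i) for i in batch + left + con])
    b = b.transpose([ib.index(i) for i in batch + con + right])
    a = a.reshape(fused(batch), fused(left), fused(con))
    b = b.reshape(fused(batch), fused(con), fused(right))
    ab = np.matmul(a, b)

    # Unfuse the groups and permute into the requested output order.
    cur = batch + left + right
    ab = ab.reshape([sizes[i] for i in cur])
    return ab.transpose([cur.index(i) for i in out])


rng = np.random.default_rng(0)
x, y = rng.random((2, 3, 4)), rng.random((4, 3, 5))
print(np.allclose(bmm_einsum_sketch("abc,cbd->ad", x, y),
                  np.einsum("abc,cbd->ad", x, y)))  # True
```

Fusing all contracted indices into one axis is what lets a single `matmul` call stand in for an arbitrary pairwise contraction. The transposes before and after it are the data movement that the real function only performs when necessary.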
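As the docstring notes, when no indices are contracted the pairwise einsum reduces to an elementwise ``multiply``. The sketch below shows one way to set that up with NumPy broadcasting. `multiply_einsum_sketch` is a hypothetical name, and the helper is not the code in the `pure_multiplication` branch. It assumes an explicit output, no ellipsis, no index repeated within a term, and that every input index appears in the output.

```python
import numpy as np


def multiply_einsum_sketch(eq, a, b):
    """Hypothetical helper for the no-contraction case.

    Assumes an explicit '->' output, no '...', no repeated index within a
    term, and that every input index appears in the output.
    """
    a, b = np.asarray(a), np.asarray(b)
    lhs, out = eq.replace(" ", "").split("->")
    ia, ib = lhs.split(",")

    def align(x, ix):
        # Put x's axes in output order, then insert size-1 axes for the
        # output indices x lacks, so broadcasting lines the operands up.
        present = [i for i in out if i in ix]
        x = x.transpose([ix.index(i) for i in present])
        return x.reshape(
            [x.shape[present.index(i)] if i in ix else 1 for i in out]
        )

    return np.multiply(align(a, ia), align(b, ib))


rng = np.random.default_rng(0)
m, n = rng.random((2, 3)), rng.random((3, 4))
print(np.allclose(multiply_einsum_sketch("ij,jk->kji", m, n),
                  np.einsum("ij,jk->kji", m, n)))  # True
```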