Find the best order for three arrays and do the multiplication. For three arguments, `_multi_dot_three` is approximately 15 times faster than `_multi_dot_matrix_chain_order`.
(A, B, C, out=None)
| 2982 | |
| 2983 | |
def _multi_dot_three(A, B, C, out=None):
    """
    Multiply three arrays using the cheaper of the two parenthesizations.

    The product ``A @ B @ C`` can be evaluated either as ``(A @ B) @ C`` or
    as ``A @ (B @ C)``.  Both orders are costed from the operand shapes and
    the one with the smaller cost is executed.

    For three arguments `_multi_dot_three` is approximately 15 times faster
    than `_multi_dot_matrix_chain_order`.

    Parameters
    ----------
    A, B, C : array_like
        Operands of the chained product.  ``A.shape`` and ``C.shape`` are
        each unpacked into exactly two dimensions; the shape of `B` is not
        inspected here and is taken to be compatible with `A` and `C`.
    out : ndarray, optional
        Forwarded unchanged as the ``out`` argument of the final (outer)
        `dot` call only; the intermediate product is always freshly
        allocated.

    Returns
    -------
    output
        Whatever the outer `dot` call returns.

    """
    rows_a, inner_ab = A.shape
    inner_bc, cols_c = C.shape

    # (A B) C costs rows_a*inner_ab*inner_bc + rows_a*inner_bc*cols_c
    left_first_cost = rows_a * inner_bc * (inner_ab + cols_c)
    # A (B C) costs inner_ab*inner_bc*cols_c + rows_a*inner_ab*cols_c
    right_first_cost = inner_ab * cols_c * (rows_a + inner_bc)

    # Ties go to the right-first grouping, A (B C).
    if left_first_cost >= right_first_cost:
        return dot(A, dot(B, C), out=out)
    return dot(dot(A, B), C, out=out)
| 3003 | |
| 3004 | |
| 3005 | def _multi_dot_matrix_chain_order(arrays, return_costs=False): |