```python
def _concatenate_shapes(shapes, axis):
    """Given array shapes, return the resulting shape and slice prefixes.

    These help in nested concatenation.

    Returns
    -------
    shape: tuple of int
        This tuple satisfies::

            shape, _ = _concatenate_shapes([arr.shape for arr in arrs], axis)
            shape == concatenate(arrs, axis).shape

    slice_prefixes: list of tuple of (slice(start, end), )
        For a list of arrays being concatenated, each entry is the slice
        along ``axis`` of the larger (result) array that the corresponding
        input occupies.

        For example, the following holds::

            ret = concatenate([a, b, c], axis)
            _, (sl_a, sl_b, sl_c) = _concatenate_shapes(
                [a.shape, b.shape, c.shape], axis)

            ret[(slice(None),) * axis + sl_a] == a
            ret[(slice(None),) * axis + sl_b] == b
            ret[(slice(None),) * axis + sl_c] == c

        These are called slice prefixes since they are used in the recursive
        blocking algorithm to compute the left-most slices during the
        recursion. Therefore, they must be prepended to the rest of the
        slice that was computed deeper in the recursion.

        These are returned as 1-tuples so that they can be concatenated
        directly onto an existing slice tuple, without wrapping each slice
        in a new tuple every time.

    """
    # Cache a result that will be reused.
    shape_at_axis = [shape[axis] for shape in shapes]

    # Take a shape, any shape: all inputs must agree off-axis, so the
    # first one serves as the reference.
    first_shape = shapes[0]
    first_shape_pre = first_shape[:axis]
    first_shape_post = first_shape[axis + 1:]

    if any(shape[:axis] != first_shape_pre or
           shape[axis + 1:] != first_shape_post for shape in shapes):
        raise ValueError(
            f'Mismatched array shapes in block along axis {axis}.')

    shape = (first_shape_pre + (sum(shape_at_axis),) + first_shape_post)

    # Running totals along ``axis`` are the end offsets of each input;
    # prepending 0 yields the matching start offsets.
    offsets_at_axis = _accumulate(shape_at_axis)
    slice_prefixes = [(slice(start, end),)
                      for start, end in zip([0] + offsets_at_axis,
                                            offsets_at_axis)]
    return shape, slice_prefixes


def _block_info_recursion(arrays, max_depth, result_ndim, depth=0):
```
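The two invariants stated in the docstring (the predicted shape, and each slice prefix selecting its input from the concatenated result) can be checked directly. The sketch below is a minimal standalone re-implementation under the name `concatenate_shapes_sketch`, a hypothetical helper rather than a NumPy API, since `_concatenate_shapes` is private. It uses `itertools.accumulate` in place of the module's `_accumulate` helper, assuming that helper returns running totals, and verifies the result against `np.concatenate`.

```python
import itertools

import numpy as np


def concatenate_shapes_sketch(shapes, axis):
    """Hypothetical stand-in for the private ``_concatenate_shapes``."""
    shape_at_axis = [shape[axis] for shape in shapes]
    first_pre = shapes[0][:axis]
    first_post = shapes[0][axis + 1:]

    # Every input must agree with the first on all axes except ``axis``.
    for shape in shapes:
        if shape[:axis] != first_pre or shape[axis + 1:] != first_post:
            raise ValueError(
                f'Mismatched array shapes in block along axis {axis}.')

    shape = first_pre + (sum(shape_at_axis),) + first_post

    # Running totals give each input's end offset along ``axis``;
    # prepending 0 gives the matching start offsets.
    ends = list(itertools.accumulate(shape_at_axis))
    slice_prefixes = [(slice(start, end),)
                      for start, end in zip([0] + ends, ends)]
    return shape, slice_prefixes


a = np.zeros((2, 1, 4))
b = np.ones((2, 3, 4))
c = np.full((2, 2, 4), 7.0)
axis = 1

ret = np.concatenate([a, b, c], axis)
shape, (sl_a, sl_b, sl_c) = concatenate_shapes_sketch(
    [a.shape, b.shape, c.shape], axis)

# First invariant: the predicted shape matches the real result.
assert shape == ret.shape  # (2, 6, 4)

# Second invariant: full slices on the leading axes, then the prefix.
lead = (slice(None),) * axis
assert np.array_equal(ret[lead + sl_a], a)
assert np.array_equal(ret[lead + sl_b], b)
assert np.array_equal(ret[lead + sl_c], c)
print(shape, sl_a, sl_b, sl_c)
```

Because each prefix is already a 1-tuple, `lead + sl_a` is a single tuple concatenation that yields a complete index for `ret`.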
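The docstring's point about prepending slice prefixes during recursion can be illustrated with a two-level case. The following is a minimal, non-recursive sketch assuming a 2-D result assembled from a nested list of 2-D arrays; `slice_prefixes` and `block_2d_sketch` are hypothetical names, not NumPy functions, and the recursive blocking algorithm the docstring refers to handles arbitrary nesting depth. The inner level computes a column slice for each block within its row, the outer level computes a row slice for each row, and each block's full index is the outer prefix concatenated in front of the inner slice tuple.

```python
import itertools

import numpy as np


def slice_prefixes(sizes):
    """Hypothetical helper: one 1-tuple slice per block along a single axis."""
    ends = list(itertools.accumulate(sizes))
    return [(slice(start, end),) for start, end in zip([0] + ends, ends)]


def block_2d_sketch(rows):
    """Assemble a nested list of 2-D arrays, in the spirit of np.block.

    Not NumPy's implementation: a fixed two-level version of the idea.
    """
    inner, heights, widths = [], [], []
    # Inner level: the blocks of each row are joined along axis 1 (columns).
    for row in rows:
        height = row[0].shape[0]
        if any(arr.shape[0] != height for arr in row):
            raise ValueError('blocks in a row must share their height')
        heights.append(height)
        widths.append(sum(arr.shape[1] for arr in row))
        inner.append(slice_prefixes([arr.shape[1] for arr in row]))
    if len(set(widths)) != 1:
        raise ValueError('all rows must have the same total width')

    # Outer level: the rows are joined along axis 0.
    outer = slice_prefixes(heights)

    blocks = [arr for row in rows for arr in row]
    result = np.empty((sum(heights), widths[0]), dtype=np.result_type(*blocks))
    for row, row_prefix, col_slices in zip(rows, outer, inner):
        for arr, col_slice in zip(row, col_slices):
            # Prepend the outer prefix: (rows,) + (cols,) -> (rows, cols).
            result[row_prefix + col_slice] = arr
    return result


A = np.ones((2, 2))
B = np.zeros((2, 3))
C = np.full((1, 4), 2.0)
D = np.full((1, 1), 3.0)
assert np.array_equal(block_2d_sketch([[A, B], [C, D]]),
                      np.block([[A, B], [C, D]]))
```

Filling a preallocated array through computed slices, rather than calling `np.concatenate` once per nesting level, avoids building intermediate arrays; the 1-tuple prefixes keep each index assembly to a plain `row_prefix + col_slice`.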