Split an array into multiple sub-arrays. Please refer to the ``split`` documentation. The only difference between these functions is that ``array_split`` allows `indices_or_sections` to be an integer that does *not* equally divide the axis. For an array of length l that should be split into n sections, it returns l % n sub-arrays of size l//n + 1 and the rest of size l//n.
array_split(ary, indices_or_sections, axis=0)
| 718 | |
@array_function_dispatch(_array_split_dispatcher)
def array_split(ary, indices_or_sections, axis=0):
    """
    Split an array into multiple sub-arrays.

    Please refer to the ``split`` documentation.  The only difference
    between these functions is that ``array_split`` allows
    `indices_or_sections` to be an integer that does *not* equally
    divide the axis. For an array of length l that should be split
    into n sections, it returns l % n sub-arrays of size l//n + 1
    and the rest of size l//n.

    See Also
    --------
    split : Split array into multiple sub-arrays of equal size.

    Examples
    --------
    >>> import numpy as np
    >>> x = np.arange(8.0)
    >>> np.array_split(x, 3)
    [array([0., 1., 2.]), array([3., 4., 5.]), array([6., 7.])]

    >>> x = np.arange(9)
    >>> np.array_split(x, 4)
    [array([0, 1, 2]), array([3, 4]), array([5, 6]), array([7, 8])]

    Array-likes such as nested lists are split along `axis` too, not
    along their first dimension:

    >>> np.array_split([[1, 2, 3, 4]], 2, axis=1)
    [array([[1, 2]]), array([[3, 4]])]

    """
    # Move the split axis to the front once.  ``swapaxes`` also converts
    # array-likes (e.g. nested lists) to arrays, so the length along the
    # split axis can always be read from ``shape[0]``.  Falling back to
    # ``len(ary)`` for inputs without ``.shape`` would measure the first
    # dimension instead, which is wrong whenever ``axis`` is not 0.
    sary = _nx.swapaxes(ary, axis, 0)
    Ntotal = sary.shape[0]
    try:
        # handle array case: explicit split indices along the axis.
        Nsections = len(indices_or_sections) + 1
        div_points = [0] + list(indices_or_sections) + [Ntotal]
    except TypeError:
        # indices_or_sections is a scalar, not an array.
        Nsections = int(indices_or_sections)
        if Nsections <= 0:
            raise ValueError('number sections must be larger than 0.') from None
        # The first `extras` sections get one extra element so that
        # uneven divisions are spread across the leading sub-arrays.
        Neach_section, extras = divmod(Ntotal, Nsections)
        section_sizes = ([0] +
                         extras * [Neach_section + 1] +
                         (Nsections - extras) * [Neach_section])
        div_points = _nx.array(section_sizes, dtype=_nx.intp).cumsum()

    # Consecutive division points delimit each piece; swap the axis back
    # so every sub-array keeps the input's axis order.
    return [_nx.swapaxes(sary[div_points[i]:div_points[i + 1]], axis, 0)
            for i in range(Nsections)]
| 774 | |
| 775 | |
| 776 | def _split_dispatcher(ary, indices_or_sections, axis=None): |
searching dependent graphs…