Format a torch tensor in a readable way so that it can be shown 👀 in the test report.
(t, indent_level=0, sci_mode=None)
| 3887 | |
| 3888 | |
| 3889 | def _format_tensor(t, indent_level=0, sci_mode=None): |
| 3890 | class="st">""class="st">"Format torch&class="cm">#x27;s tensor in a pretty way to be shown 👀 in the test report."class="st">"" |
| 3891 | |
| 3892 | class="cm"># `torch.testing.assert_close` could accept python int/float numbers. |
| 3893 | if not isinstance(t, torch.Tensor): |
| 3894 | t = torch.tensor(t) |
| 3895 | |
| 3896 | class="cm"># Simply make the processing below simpler (not to handle both cases) |
| 3897 | is_scalar = False |
| 3898 | if t.ndim == 0: |
| 3899 | t = torch.tensor([t]) |
| 3900 | is_scalar = True |
| 3901 | |
| 3902 | class="cm"># For scalar or one-dimensional tensor, keep it as one-line. If there is only one element along any dimension except |
| 3903 | class="cm"># the last one, we also keep it as one-line. |
| 3904 | if t.ndim <= 1 or set(t.shape[0:-1]) == {1}: |
| 3905 | class="cm"># Use `detach` to remove `grad_fn=<...>`, and use `to(class="st">"cpu")` to remove `device=class="st">'...'` |
| 3906 | t = t.detach().to(class="st">"cpu") |
| 3907 | |
| 3908 | class="cm"># We work directly with the string representation instead the tensor itself |
| 3909 | t_str = str(t) |
| 3910 | |
| 3911 | class="cm"># remove `tensor( ... )` so keep only the content |
| 3912 | t_str = t_str.replace(class="st">"tensor(", class="st">"").replace(class="st">")", class="st">"") |
| 3913 | |
| 3914 | class="cm"># Sometimes there are extra spaces between `[` and the first digit of the first value (for alignment). |
| 3915 | class="cm"># For example `[[ 0.06, -0.51], [-0.76, -0.49]]`. It may have multiple consecutive spaces. |
| 3916 | class="cm"># Let's remove such extra spaces. |
| 3917 | while class="st">"[ " in t_str: |
| 3918 | t_str = t_str.replace(class="st">"[ ", class="st">"[") |
| 3919 | |
| 3920 | class="cm"># Put everything in a single line. We replace `\n` by a space ` ` so we still keep `,\n` as `, `. |
| 3921 | t_str = t_str.replace(class="st">"\n", class="st">" ") |
| 3922 | |
| 3923 | class="cm"># Remove repeated spaces (introduced by the previous step) |
| 3924 | while class="st">" " in t_str: |
| 3925 | t_str = t_str.replace(class="st">" ", class="st">" ") |
| 3926 | |
| 3927 | class="cm"># remove leading `[` and `]` for scalar tensor |
| 3928 | if is_scalar: |
| 3929 | t_str = t_str[1:-1] |
| 3930 | |
| 3931 | t_str = class="st">" " * 4 * indent_level + t_str |
| 3932 | |
| 3933 | return t_str |
| 3934 | |
| 3935 | class="cm"># Otherwise, we separate the representations of each element along an outer dimension by new lines (after a `,`). |
| 3936 | class="cm"># The representation of each element is obtained by calling this function recursively with current `indent_level`. |
| 3937 | else: |
| 3938 | t_str = str(t) |
| 3939 | |
| 3940 | class="cm"># (For the recursive calls should receive this value) |
| 3941 | if sci_mode is None: |
| 3942 | sci_mode = class="st">"e+" in t_str or class="st">"e-" in t_str |
| 3943 | |
| 3944 | class="cm"># Use the original content to determine the scientific mode to use. This is required as the representation of |
| 3945 | class="cm"># t[index] (computed below) maybe have different format regarding scientific notation. |
| 3946 | torch.set_printoptions(sci_mode=sci_mode) |
no test coverage detected