(tensor)
| 1104 | return torch.cat([t.cpu() for t in all_tensors], dim=0) |
| 1105 | |
def smp_nested_concat(tensor):
    """Recursively concatenate the leaves of a nested structure and move them to CPU.

    Lists, tuples and dicts are walked recursively and rebuilt with the same
    container type as the input. Any other object is treated as a leaf, on
    which `.detach()`, `.concat()` and `.cpu()` are called in that order.

    Args:
        tensor: A leaf object exposing `detach()`, or an arbitrarily nested
            list/tuple/dict of such objects.

    Returns:
        The same nesting structure, with every leaf replaced by the result of
        `leaf.detach().concat().cpu()`.
    """
    # Sequences: rebuild the same sequence type from the converted items.
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(smp_nested_concat(item) for item in tensor)

    # Mappings: convert every value, keep the keys, rebuild the same dict type.
    if isinstance(tensor, dict):
        converted = {}
        for key, value in tensor.items():
            converted[key] = smp_nested_concat(value)
        return type(tensor)(converted)

    # Leaf case. No isinstance check against StepOutput is done here: it doesn't
    # seem possible because StepOutput lives in `smp.step`, which is also the
    # name of the decorator, so Python is confused.
    detached = tensor.detach()
    concatenated = detached.concat()
    return concatenated.cpu()
| 1114 | |
| 1115 | |
| 1116 | @dataclass |
no test coverage detected