Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
(tensors)
| 173 | |
| 174 | |
def nested_detach(tensors):
    """Detach `tensors` (even if it's a nested list/tuple/dict of tensors).

    Recursively walks lists, tuples (including namedtuples) and mappings,
    rebuilding each container with its original type and calling
    ``.detach()`` on every ``torch.Tensor`` leaf. Any other leaf value is
    returned unchanged.

    Args:
        tensors: A tensor, or an arbitrarily nested list/tuple/mapping whose
            leaves may be tensors or other values.

    Returns:
        A structure with the same nesting and container types, in which every
        tensor leaf has been replaced by its detached counterpart.
    """
    if isinstance(tensors, (list, tuple)):
        detached = (nested_detach(t) for t in tensors)
        # Namedtuples take their fields positionally, so handing the generator
        # to the constructor would bind it to the first field (or raise for
        # missing fields). `_make` is the namedtuple API that accepts an iterable.
        if isinstance(tensors, tuple) and hasattr(type(tensors), "_make"):
            return type(tensors)._make(detached)
        return type(tensors)(detached)
    elif isinstance(tensors, Mapping):
        # Rebuild with the original mapping type; assumes its constructor
        # accepts a plain dict of the detached items.
        return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
    return tensors.detach() if isinstance(tensors, torch.Tensor) else tensors
| 182 | |
| 183 | |
| 184 | def nested_xla_mesh_reduce(tensors, name): |
no test coverage detected