(tensors, name)
| 182 | |
| 183 | |
def nested_xla_mesh_reduce(tensors, name):
    """Recursively gather tensors across XLA devices with ``xm.mesh_reduce``.

    Lists/tuples and mappings are walked recursively, and the container type
    is rebuilt via ``type(tensors)(...)``. Each leaf is passed through
    ``atleast_1d`` and then reduced with ``torch.cat`` as the reducer.

    Args:
        tensors: A tensor, or an arbitrarily nested list/tuple/Mapping of them.
        name: Base rendezvous tag for ``xm.mesh_reduce``. Nested elements use
            ``f"{name}_{i}"``, where ``i`` is the element's enumeration index,
            so every leaf gets a distinct tag.

    Returns:
        The same nested structure with each leaf replaced by the result of
        ``xm.mesh_reduce``.

    Raises:
        ImportError: If ``is_torch_xla_available()`` returns False.
    """
    if not is_torch_xla_available():
        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")

    # Imported lazily so the module can be loaded without torch_xla installed.
    import torch_xla.core.xla_model as xm

    if isinstance(tensors, (list, tuple)):
        reduced_items = (
            nested_xla_mesh_reduce(item, f"{name}_{idx}") for idx, item in enumerate(tensors)
        )
        return type(tensors)(reduced_items)

    if isinstance(tensors, Mapping):
        # Suffix uses the positional index (not the key) so tags stay simple strings.
        reduced_map = {}
        for idx, (key, value) in enumerate(tensors.items()):
            reduced_map[key] = nested_xla_mesh_reduce(value, f"{name}_{idx}")
        return type(tensors)(reduced_map)

    # Leaf: ensure at least one dimension so torch.cat can join the gathered values.
    return xm.mesh_reduce(name, atleast_1d(tensors), torch.cat)
| 199 | |
| 200 | |
| 201 | def distributed_concat(tensor: Any, num_total_examples: int | None = None) -> Any: |
no test coverage detected