hub / github.com/huggingface/transformers / nested_xla_mesh_reduce

Function nested_xla_mesh_reduce

src/transformers/trainer_pt_utils.py:184–198
(tensors, name)

Source from the content-addressed store, hash-verified

def nested_xla_mesh_reduce(tensors, name):
    if is_torch_xla_available():
        import torch_xla.core.xla_model as xm

        if isinstance(tensors, (list, tuple)):
            return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
        if isinstance(tensors, Mapping):
            return type(tensors)(
                {k: nested_xla_mesh_reduce(t, f"{name}_{i}") for i, (k, t) in enumerate(tensors.items())}
            )

        tensors = atleast_1d(tensors)
        return xm.mesh_reduce(name, tensors, torch.cat)
    else:
        raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")


def distributed_concat(tensor: Any, num_total_examples: int | None = None) -> Any:
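
The recursion in `nested_xla_mesh_reduce` can be followed without an XLA device. The sketch below is a minimal illustration, not the library's code: `FakeTensor`, `fake_cat`, `fake_mesh_reduce`, and `nested_reduce_sketch` are hypothetical stand-ins, the `atleast_1d` step is omitted, and the "mesh" is simulated by assuming every replica contributes the same value. It shows that each leaf gets its own reduction name built from positional indices (for mappings too, where the index of the item is used rather than its key), and that the container types are preserved.

```python
from collections.abc import Mapping


class FakeTensor:
    """Hypothetical stand-in for a torch.Tensor: just wraps a list of values."""

    def __init__(self, values):
        self.values = list(values)


def fake_cat(chunks):
    """Stand-in for torch.cat: joins the per-replica chunks end to end."""
    return FakeTensor(v for chunk in chunks for v in chunk.values)


def fake_mesh_reduce(name, tensor, reduce_fn, seen_names, world_size=2):
    """Stand-in for xm.mesh_reduce: pretends every replica sent the same tensor."""
    seen_names.append(name)  # record the name used for this leaf
    return reduce_fn([tensor] * world_size)


def nested_reduce_sketch(data, name, seen_names):
    # Lists and tuples: recurse per element, suffixing the name with the position.
    if isinstance(data, (list, tuple)):
        return type(data)(nested_reduce_sketch(t, f"{name}_{i}", seen_names) for i, t in enumerate(data))
    # Mappings: keys are kept, but the name suffix is the item index, not the key.
    if isinstance(data, Mapping):
        return type(data)(
            {k: nested_reduce_sketch(t, f"{name}_{i}", seen_names) for i, (k, t) in enumerate(data.items())}
        )
    # Leaf: a single reduction under a name unique to this position in the structure.
    return fake_mesh_reduce(name, data, fake_cat, seen_names)


seen = []
batch = {"logits": FakeTensor([1, 2]), "labels": (FakeTensor([3]), FakeTensor([4]))}
result = nested_reduce_sketch(batch, "eval", seen)

print(seen)                        # ['eval_0', 'eval_1_0', 'eval_1_1']
print(result["logits"].values)     # [1, 2, 1, 2]
print(result["labels"][1].values)  # [4, 4]
```

In this sketch the nested tuple under `"labels"` comes back as a tuple and the outer dict as a dict, mirroring the `type(tensors)(...)` calls in the original function.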

Callers 1

nested_gather · Function · 0.85

Calls 3

is_torch_xla_available · Function · 0.85
atleast_1d · Function · 0.85
items · Method · 0.45

Tested by

no test coverage detected