Get the sequence parallel size.
Signature: get_sp_size(self) -> int
| 2397 | return self._train_batch_size * args.gradient_accumulation_steps * dp_world_size |
| 2398 | |
| 2399 | def get_sp_size(self) -> int: |
| 2400 | """Get the sequence parallel size""" |
| 2401 | if getattr(self.accelerator, "parallelism_config", None) is None: |
| 2402 | return 1 |
| 2403 | else: |
| 2404 | pc = self.accelerator.parallelism_config |
| 2405 | return pc.sp_size |
| 2406 | |
| 2407 | def get_cp_size(self) -> int: |
| 2408 | """Get the context parallel size""" |
no outgoing calls
no test coverage detected