start tracking for the caller's stage
(self)
| 695 | break |
| 696 | |
| 697 | def start(self): |
| 698 | """start tracking for the caller's stage""" |
| 699 | if self.skip_memory_metrics: |
| 700 | return |
| 701 | |
| 702 | stage = self.derive_stage() |
| 703 | # deal with nested calls of eval during train - simply ignore those |
| 704 | if self.cur_stage is not None and self.cur_stage != stage: |
| 705 | return |
| 706 | |
| 707 | self.cur_stage = stage |
| 708 | |
| 709 | gc.collect() |
| 710 | |
| 711 | if self.torch is not None: |
| 712 | if torch.cuda.is_available(): |
| 713 | self.torch.cuda.reset_peak_memory_stats() |
| 714 | self.torch.cuda.empty_cache() |
| 715 | elif is_torch_mlu_available(): |
| 716 | self.torch.mlu.reset_peak_memory_stats() |
| 717 | self.torch.mlu.empty_cache() |
| 718 | elif is_torch_musa_available(): |
| 719 | self.torch.musa.reset_peak_memory_stats() |
| 720 | self.torch.musa.empty_cache() |
| 721 | elif is_torch_xpu_available(): |
| 722 | self.torch.xpu.reset_peak_memory_stats() |
| 723 | self.torch.xpu.empty_cache() |
| 724 | elif is_torch_npu_available(): |
| 725 | self.torch.npu.reset_peak_memory_stats() |
| 726 | self.torch.npu.empty_cache() |
| 727 | elif is_torch_hpu_available(): |
| 728 | self.torch.hpu.reset_peak_memory_stats() |
| 729 | # not available on hpu as it reserves all device memory for the current process |
| 730 | # self.torch.hpu.empty_cache() |
| 731 | elif is_torch_mps_available(): |
| 732 | self.torch.mps.empty_cache() |
| 733 | |
| 734 | # gpu |
| 735 | if self.torch is not None: |
| 736 | if torch.cuda.is_available(): |
| 737 | self.gpu_mem_used_at_start = self.torch.cuda.memory_allocated() |
| 738 | elif is_torch_mlu_available(): |
| 739 | self.gpu_mem_used_at_start = self.torch.mlu.memory_allocated() |
| 740 | elif is_torch_musa_available(): |
| 741 | self.gpu_mem_used_at_start = self.torch.musa.memory_allocated() |
| 742 | elif is_torch_xpu_available(): |
| 743 | self.gpu_mem_used_at_start = self.torch.xpu.memory_allocated() |
| 744 | elif is_torch_npu_available(): |
| 745 | self.gpu_mem_used_at_start = self.torch.npu.memory_allocated() |
| 746 | elif is_torch_hpu_available(): |
| 747 | self.gpu_mem_used_at_start = self.torch.hpu.memory_allocated() |
| 748 | elif is_torch_mps_available(): |
| 749 | self.gpu_mem_used_at_start = self.torch.mps.current_allocated_memory() |
| 750 | |
| 751 | # cpu |
| 752 | self.cpu_mem_used_at_start = self.cpu_mem_used() |
| 753 | |
| 754 | self.peak_monitoring = True |