MCPcopy Index your code
hub / github.com/geekcomputers/Python / train

Method train

ML/src/python/neuralforge/trainer.py:137–189  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

135 return {'loss': avg_loss, 'accuracy': accuracy}
136
def train(self):
    """Run the full training loop for ``self.config.epochs`` epochs.

    Each epoch calls ``self.train_epoch()`` and ``self.validate()``, steps the
    LR scheduler (when one is set), logs a one-line summary, records the
    epoch's numbers through ``self.metrics.update`` and writes checkpoints:

    * ``checkpoint_epoch_<n>.pt`` every ``config.checkpoint_freq`` epochs
      (skipped entirely when that frequency is not a positive number);
    * ``best_model.pt`` whenever the validation loss beats
      ``self.best_val_loss``.

    After the last epoch it saves ``final_model.pt`` and writes the collected
    metrics to ``<config.log_dir>/metrics.json``.

    ``train_epoch`` must return a dict with ``'loss'`` and ``'accuracy'``.
    ``validate`` may return a dict without ``'loss'`` (presumably when no
    validation set is configured -- TODO confirm). In that case:

    * the train loss drives ``ReduceLROnPlateau``;
    * 0 is logged/recorded for the validation numbers;
    * no best-model checkpoint is written.
    """
    self.logger.info("Starting training...")
    # perf_counter is monotonic: wall-clock adjustments cannot skew durations.
    start_time = time.perf_counter()

    checkpoint_freq = self.config.checkpoint_freq
    # A zero/None frequency would crash the modulo below (ZeroDivisionError /
    # TypeError) and a negative one is meaningless; treat both as "disabled".
    periodic_checkpoints = bool(checkpoint_freq) and checkpoint_freq > 0

    for epoch in range(self.config.epochs):
        self.current_epoch = epoch
        epoch_start = time.perf_counter()

        train_metrics = self.train_epoch()
        # Normalize a missing result to an empty dict so every lookup below is safe.
        val_metrics = self.validate() or {}
        # None when validation produced no loss value.
        val_loss = val_metrics.get('loss')

        if self.scheduler is not None:
            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Plateau schedulers need a monitored metric; fall back to the train loss.
                self.scheduler.step(val_loss if val_loss is not None else train_metrics['loss'])
            else:
                self.scheduler.step()

        current_lr = self.optimizer.param_groups[0]['lr']
        epoch_time = time.perf_counter() - epoch_start

        self.logger.info(
            f"Epoch {epoch + 1}/{self.config.epochs} | "
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"Train Acc: {train_metrics['accuracy']:.2f}% | "
            f"Val Loss: {val_metrics.get('loss', 0):.4f} | "
            f"Val Acc: {val_metrics.get('accuracy', 0):.2f}% | "
            f"LR: {current_lr:.6f} | "
            f"Time: {epoch_time:.2f}s"
        )

        self.metrics.update({
            'epoch': epoch + 1,
            'train_loss': train_metrics['loss'],
            'train_acc': train_metrics['accuracy'],
            'val_loss': val_metrics.get('loss', 0),
            'val_acc': val_metrics.get('accuracy', 0),
            'lr': current_lr,
            'time': epoch_time
        })

        if periodic_checkpoints and (epoch + 1) % checkpoint_freq == 0:
            self.save_checkpoint(f'checkpoint_epoch_{epoch + 1}.pt')

        # Only a real validation loss can promote a new best model.
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.save_checkpoint('best_model.pt')
            self.logger.info(f"New best model saved with val_loss: {self.best_val_loss:.4f}")

    total_time = time.perf_counter() - start_time
    self.logger.info(f"Training completed in {total_time / 3600:.2f} hours")

    self.save_checkpoint('final_model.pt')
    self.metrics.save(os.path.join(self.config.log_dir, 'metrics.json'))
190
191 def save_checkpoint(self, filename: str):
192 checkpoint_path = os.path.join(self.config.model_dir, filename)

Callers 7

main (function) · 0.95
main (function) · 0.95
main (function) · 0.95
main (function) · 0.95
train_epoch (method) · 0.45
_quick_evaluate (method) · 0.45
_full_evaluate (method) · 0.45

Calls 9

train_epoch (method) · 0.95
validate (method) · 0.95
save_checkpoint (method) · 0.95
info (method) · 0.80
time (method) · 0.80
step (method) · 0.45
get (method) · 0.45
update (method) · 0.45
save (method) · 0.45

Tested by

no test coverage detected