()
| 13 | from src.python.neuralforge.optim.schedulers import CosineAnnealingWarmRestarts |
| 14 | |
def main():
    """Train a ResNet18 classifier on CIFAR-10 and report the best validation loss.

    Steps:
      1. Build a ``Config`` with the hyperparameters for this run.
      2. Fetch the CIFAR-10 train and test splits into ``./data`` via
         ``get_dataset`` (``download=True``); the test split is used for
         validation.
      3. Wrap both splits in loaders produced by ``DataLoaderBuilder``.
      4. Train with ``Trainer`` using cross-entropy, AdamW and
         ``CosineAnnealingWarmRestarts``.

    Progress and the final ``trainer.best_val_loss`` are printed to stdout.
    """
    print("Training ResNet18 on CIFAR-10")

    config = Config()
    config.batch_size = 128
    config.epochs = 100
    config.learning_rate = 0.001
    config.num_classes = 10
    config.image_size = 32
    config.model_name = "resnet18_cifar10"
    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("Downloading CIFAR-10 dataset...")
    train_dataset = get_dataset('cifar10', root='./data', train=True, download=True)
    # The CIFAR-10 test split (train=False) doubles as the validation set.
    val_dataset = get_dataset('cifar10', root='./data', train=False, download=True)

    print(f"Train: {len(train_dataset)} samples")
    print(f"Val: {len(val_dataset)} samples")

    loader_builder = DataLoaderBuilder(config)
    train_loader = loader_builder.build_train_loader(train_dataset)
    val_loader = loader_builder.build_val_loader(val_dataset)

    # Take the class count from the config so the classifier head cannot
    # drift out of sync with config.num_classes. CIFAR-10 images are RGB,
    # hence in_channels=3.
    model = ResNet18(num_classes=config.num_classes, in_channels=3)
    print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")

    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0.01)
    # NOTE(review): assuming this scheduler follows torch's warm-restart
    # semantics and is stepped once per epoch, T_0=10 / T_mult=2 gives cycles
    # ending at epochs 10, 30, 70 and 150. With config.epochs=100 the run stops
    # 30 epochs into the last cycle, so the LR has not annealed to its minimum.
    # Confirm this is intended.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        config=config,
        scheduler=scheduler,
    )

    print("Starting training...")
    trainer.train()

    print("\nTraining completed!")
    print(f"Best validation loss: {trainer.best_val_loss:.4f}")
    # NOTE(review): this path is a literal, not read from Trainer/config;
    # verify it matches where Trainer actually writes its checkpoint.
    print("Model saved to: ./models/best_model.pt")
    print("\nTest the model:")
    print(" python tests/test_model.py --dataset cifar10 --mode interactive")
# Launch training only when this file is executed as a script; importing the
# module (e.g. from tests) does not start a training run.
if __name__ == "__main__":
    main()
no test coverage detected