MCPcopy Index your code
hub / github.com/geekcomputers/Python / main

Function main

ML/examples/train_cifar10.py:15–62  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

13from src.python.neuralforge.optim.schedulers import CosineAnnealingWarmRestarts
14
def main() -> None:
    """Train a ResNet18 classifier on CIFAR-10 and print a short report.

    Steps:
      1. Build a ``Config`` holding the run's hyperparameters and pick
         ``'cuda'`` when ``torch.cuda.is_available()``, else ``'cpu'``.
      2. Ask ``get_dataset`` for the CIFAR-10 train and validation splits
         under ``./data`` (with ``download=True``).
      3. Wrap both splits in loaders via ``DataLoaderBuilder(config)``.
      4. Train a ResNet18 with cross-entropy loss, AdamW and
         ``CosineAnnealingWarmRestarts`` through the project ``Trainer``.
      5. Print the best validation loss recorded by the trainer.

    Returns:
        None. All progress is reported on stdout.
    """
    print("Training ResNet18 on CIFAR-10")

    config = Config()
    config.batch_size = 128
    config.epochs = 100
    config.learning_rate = 0.001
    config.num_classes = 10
    config.image_size = 32
    config.model_name = "resnet18_cifar10"
    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print("Downloading CIFAR-10 dataset...")
    train_dataset = get_dataset('cifar10', root='./data', train=True, download=True)
    val_dataset = get_dataset('cifar10', root='./data', train=False, download=True)

    print(f"Train: {len(train_dataset)} samples")
    print(f"Val: {len(val_dataset)} samples")

    loader_builder = DataLoaderBuilder(config)
    train_loader = loader_builder.build_train_loader(train_dataset)
    val_loader = loader_builder.build_val_loader(val_dataset)

    # Take the head size from the config so the model and the rest of the
    # pipeline cannot silently disagree on the number of classes.
    model = ResNet18(num_classes=config.num_classes, in_channels=3)
    print(f"Model: {sum(p.numel() for p in model.parameters()):,} parameters")

    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0.01)
    # NOTE(review): with T_0=10 and T_mult=2 the restart boundaries land at
    # 10, 30, 70 and 150 scheduler steps. If the scheduler is stepped once per
    # epoch (not visible here; confirm in Trainer), a 100-epoch run ends
    # partway through the 4th cycle rather than at its LR minimum.
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        config=config,
        scheduler=scheduler,
    )

    print("Starting training...")
    trainer.train()

    print("\nTraining completed!")
    print(f"Best validation loss: {trainer.best_val_loss:.4f}")
    # NOTE(review): this path is hard-coded and not derived from `config` or
    # `trainer`; confirm it matches where Trainer actually writes checkpoints.
    print("Model saved to: ./models/best_model.pt")
    print("\nTest the model:")
    print(" python tests/test_model.py --dataset cifar10 --mode interactive")
if __name__ == "__main__":
    # Start training only when this file is executed directly, not when it
    # is imported as a module.
    main()

Callers 1

train_cifar10.py · File · 0.70

Calls 10

build_train_loader · Method · 0.95
build_val_loader · Method · 0.95
train · Method · 0.95
Config · Class · 0.90
get_dataset · Function · 0.90
DataLoaderBuilder · Class · 0.90
ResNet18 · Function · 0.90
AdamW · Class · 0.90
Trainer · Class · 0.90

Tested by

no test coverage detected