| 132 | raise ValueError(f"Unknown dataset: {name}") |
| 133 | |
class ImageNetDataset:
    """Wrapper around ``datasets.ImageFolder`` for one split of an ImageNet tree.

    The images are read from ``os.path.join(root, split)``. When no transform
    is supplied, a default pipeline is chosen from the split name: random
    crop/flip/colour-jitter augmentation for ``'train'``, and a deterministic
    resize + centre crop for every other split.
    """

    # Per-channel statistics handed to ``transforms.Normalize`` by both
    # default pipelines; kept in one place so the two cannot drift apart.
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, root='./data/imagenet', split='train', transform=None, download=False):
        """Open the image folder for ``split`` under ``root``.

        Args:
            root: Directory that contains one sub-directory per split.
            split: Name of the sub-directory to load (e.g. ``'train'``).
            transform: Callable applied to each image. ``None`` selects the
                default pipeline for ``split``.
            download: Accepted for interface compatibility only; it is never
                read, and the dataset is not downloaded by this class.

        Raises:
            Exception: Whatever ``datasets.ImageFolder`` raises is re-raised
                unchanged after a hint about the expected layout is printed.
        """
        if transform is None:
            transform = self._default_transform(split)

        try:
            self.dataset = datasets.ImageFolder(os.path.join(root, split), transform=transform)
        except Exception:
            # ``Exception`` rather than a bare ``except`` so that
            # KeyboardInterrupt/SystemExit pass through without the
            # misleading "not found" hint. The error itself is re-raised.
            print(f"ImageNet not found at {root}. Please download manually from https://image-net.org/")
            # Must be an f-string: otherwise the literal text "{root}" is shown.
            print(f"Expected structure: {root}/train/ and {root}/val/")
            raise

    @classmethod
    def _default_transform(cls, split):
        """Build the default preprocessing pipeline for ``split``."""
        normalize = transforms.Normalize(mean=list(cls._NORM_MEAN), std=list(cls._NORM_STD))
        if split == 'train':
            return transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(0.4, 0.4, 0.4),
                transforms.ToTensor(),
                normalize,
            ])
        return transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the wrapped dataset's item at ``idx``."""
        return self.dataset[idx]
| 165 | |
| 166 | class TinyImageNetDataset: |
| 167 | def __init__(self, root='./data', train=True, transform=None, download=True): |