| 32 | return out |
| 33 | |
class ResNet(nn.Module):
    """ResNet-style image classifier assembled from a user-supplied residual block.

    Layout: a 7x7 stride-2 stem (conv -> BN -> ReLU -> 3x3 stride-2 max-pool),
    four residual stages with 64/128/256/512 base channels (stages 2-4 use
    stride 2), adaptive average pooling to 1x1, and a linear classification head.

    Args:
        block: Residual block class. It is constructed as
            ``block(in_channels, out_channels, stride, downsample)`` for the
            first block of each stage and ``block(in_channels, out_channels)``
            for the remaining ones, and must output
            ``out_channels * block.expansion`` channels. ``expansion`` is read
            from the class and defaults to 1 when the attribute is absent.
        layers: Number of blocks in each of the four stages; only the first
            four entries are used.
        num_classes: Number of output logits.
        in_channels: Channel count of the input images.
    """

    def __init__(self, block, layers, num_classes=1000, in_channels=3):
        super().__init__()
        # Running channel count fed into the next stage; _make_layer updates it.
        self.in_channels = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Size the head from the last stage's actual output (512 * expansion)
        # instead of hard-coding 512, so expanding blocks get a matching head.
        self.fc = nn.Linear(self.in_channels, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one residual stage and return it as an ``nn.Sequential``.

        The first block receives ``stride`` and, whenever the stride is not 1
        or the channel count changes, a 1x1 conv + BatchNorm projection for its
        shortcut. ``self.in_channels`` is updated to the stage's output channel
        count. At least one block is always created, even if ``blocks < 1``.

        Args:
            block: Residual block class (see the class docstring).
            out_channels: Base channel width of the stage.
            blocks: Number of blocks in the stage.
            stride: Stride applied by the first block of the stage.
        """
        expansion = getattr(block, "expansion", 1)
        stage_channels = out_channels * expansion

        downsample = None
        if stride != 1 or self.in_channels != stage_channels:
            # Projection shortcut so the identity path matches the block output.
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, stage_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(stage_channels),
            )

        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = stage_channels
        layers.extend(block(stage_channels, out_channels) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch, presumably of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, num_classes).
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        # Keep the batch dimension, collapse channels x 1 x 1 into features.
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
| 85 | |
| 86 | class EfficientNetBlock(nn.Module): |
| 87 | def __init__(self, in_channels, out_channels, kernel_size, stride, expand_ratio, se_ratio=0.25): |