In [1]:
from typing import List

import torch
from torch import nn
from torchinfo import summary
import pytorch_lightning as pl

import utils.data as data

In [2]:
class ConvBlock(nn.Module):
    """VGG-style stage: 2 or 3 (3x3 conv + ReLU) layers followed by 2x2 max-pooling.

    Every conv uses padding=1, so the spatial size is unchanged through the convs.
    The pool (kernel 2, stride 2 by default, padding 1) then maps a side of length
    H to H // 2 + 1, e.g. 28 -> 15 -> 8 -> 5 -> 3 -> 2 across five stacked blocks.

    Args:
        in_channels: number of channels of the input tensor.
        out_configs: output channel counts of the successive convs; must contain
            exactly 2 or 3 entries.

    Raises:
        ValueError: if ``out_configs`` does not have 2 or 3 entries (previously a
            1-entry list crashed with IndexError and extra entries were ignored).
    """

    def __init__(self, in_channels: int, out_configs: List[int]):
        super().__init__()

        if len(out_configs) not in (2, 3):
            raise ValueError(
                f"out_configs must have 2 or 3 entries, got {len(out_configs)}"
            )

        self.conv1 = nn.Conv2d(in_channels, out_configs[0], (3, 3), padding=1)
        self.conv2 = nn.Conv2d(out_configs[0], out_configs[1], (3, 3), padding=1)

        # Optional third conv for the deeper stages. It is kept as a named attribute
        # (rather than a Sequential) so state_dict keys stay conv1/conv2/conv3.
        if len(out_configs) == 3:
            self.conv3 = nn.Conv2d(out_configs[1], out_configs[2], (3, 3), padding=1)
        else:
            self.conv3 = None

        # padding=1 rounds odd sizes up (H -> H // 2 + 1) instead of dropping the
        # last row/column. The classifier size in VGG16 relies on this.
        self.pool = nn.MaxPool2d((2, 2), padding=1)

    def forward(self, x):
        """Apply conv1/conv2 (and conv3 if present), each followed by ReLU, then pool."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))

        if self.conv3 is not None:
            x = torch.relu(self.conv3(x))

        return self.pool(x)

In [3]:
class VGG16(pl.LightningModule):
    """VGG-16 style classifier: 13 conv layers in 5 ConvBlocks plus 3 linear layers.

    Args:
        in_channels: number of input image channels.
        num_classes: size of the output logit vector.
        lr: Adam learning rate used by ``configure_optimizers`` (default 2e-4,
            the value previously hard-coded).

    NOTE(review): the recorded run below shows the training loss flat at ~2.3,
    which is about ln(10), i.e. chance level for 10 classes. This plain (no
    BatchNorm) deep stack may not be learning. TODO investigate (e.g. a lower
    LR, BatchNorm, or input normalization in utils.data).
    """

    def __init__(self, in_channels: int, num_classes: int, lr: float = 2e-4):
        super().__init__()
        self.lr = lr

        self.cb1 = ConvBlock(in_channels, [64, 64])
        self.cb2 = ConvBlock(64, [128, 128])
        self.cb3 = ConvBlock(128, [256, 256, 256])
        self.cb4 = ConvBlock(256, [512, 512, 512])
        self.cb5 = ConvBlock(512, [512, 512, 512])

        # Pins the final feature map to 2x2 so the classifier input (512 * 2 * 2)
        # does not depend on image resolution. For 28x28 inputs the conv stack
        # already yields 2x2, so this is an exact identity there. It has no
        # parameters, so state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((2, 2))

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * 2 * 2, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        )

        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        x = self.cb1(x)
        x = self.cb2(x)
        x = self.cb3(x)
        x = self.cb4(x)
        x = self.cb5(x)
        x = self.avgpool(x)

        return self.fc(x)

    def training_step(self, xb, batch_idx):
        """Compute, log (as ``train_loss``) and return the cross-entropy loss."""
        inp, labels = xb
        loss = self.loss(self(inp), labels)

        self.log("train_loss", loss)
        return loss

    def validation_step(self, xb, batch_idx):
        """Log ``val_loss`` and ``val_acc`` (fraction of correct argmax predictions)."""
        inp, labels = xb
        out = self(inp)

        labels_hat = torch.argmax(out, dim=1)
        # The mean of the match mask is the batch accuracy. It stays a tensor on
        # device, which avoids the per-batch .item() host sync.
        val_acc = (labels_hat == labels).float().mean()

        self.log("val_loss", self.loss(out, labels))
        self.log("val_acc", val_acc)

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

# Run: build the model, print its summary, and train it

In [4]:
# Run configuration. Inputs are assumed to be in_channels x 28 x 28 images
# (MNIST-style); TODO confirm against utils.data.
num_classes = 10
in_channels = 1
image_size = 28
epochs = 3
seed = 42

# Seed the Python, NumPy and torch RNGs so weight init and shuffling are
# reproducible across Restart & Run All.
pl.seed_everything(seed)

model = VGG16(in_channels, num_classes)

# Batch size 2 here is only a dummy batch for torchinfo's shape trace.
print(summary(model, input_size=(2, in_channels, image_size, image_size)))

trainer = pl.Trainer(
    default_root_dir="logs",
    # NOTE: `gpus=` is deprecated in PL 1.7 and removed in 2.0; switch to
    # accelerator=/devices= when upgrading.
    gpus=(1 if torch.cuda.is_available() else 0),
    max_epochs=epochs,
    # NOTE(review): fixed version=0 means repeated runs write into the same
    # TensorBoard directory.
    logger=pl.loggers.TensorBoardLogger("logs/", name="vgg", version=0),
)

# Pass loaders positionally. The keyword was renamed train_dataloader ->
# train_dataloaders in PL 1.4 (the old name was later removed), while the
# positional form works across versions.
trainer.fit(model, data.train_dl, data.val_dl)

███▉ | 978/1095 [06:28<00:46,  2.52it/s, loss=2.3, v_num=0]
Validating:  25%|██▌       | 40/157 [00:04<00:12,  9.54it/s][A
Epoch 1:  89%|████████▉ | 980/1095 [06:28<00:45,  2.52it/s, loss=2.3, v_num=0]
Validating:  27%|██▋       | 42/157 [00:04<00:12,  9.50it/s][A
Epoch 1:  90%|████████▉ | 982/1095 [06:28<00:44,  2.52it/s, loss=2.3, v_num=0]
Validating:  28%|██▊       | 44/157 [00:04<00:11,  9.54it/s][A
Epoch 1:  90%|████████▉ | 984/1095 [06:29<00:43,  2.53it/s, loss=2.3, v_num=0]
Validating:  29%|██▉       | 46/157 [00:05<00:11,  9.54it/s][A
Epoch 1:  90%|█████████ | 986/1095 [06:29<00:43,  2.53it/s, loss=2.3, v_num=0]
Validating:  31%|███       | 48/157 [00:05<00:11,  9.52it/s][A
Epoch 1:  90%|█████████ | 988/1095 [06:29<00:42,  2.54it/s, loss=2.3, v_num=0]
Validating:  32%|███▏      | 50/157 [00:05<00:11,  9.55it/s][A
Epoch 1:  90%|█████████ | 990/1095 [06:29<00:41,  2.54it/s, loss=2.3, v_num=0]
Validating:  33%|███▎      | 52/157 [00:05<00:11,  9.46it/s][A
Epoch 1:  91%|████

1