{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "pythonjvsc74a57bd01811777d830f1030e31060b656d737abbce627438427bf59fb8b24dc91025654",
   "display_name": "Python 3.8.5 64-bit ('.env': venv)"
  },
  "metadata": {
   "interpreter": {
    "hash": "1811777d830f1030e31060b656d737abbce627438427bf59fb8b24dc91025654"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from torchinfo import summary\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "import utils.data as data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvBlock(nn.Module):\n",
    "    def __init__(self, in_channels: int, out_configs: List[int]):\n",
    "        super(ConvBlock, self).__init__()\n",
    "\n",
    "        self.conv1 = nn.Conv2d(in_channels, out_configs[0], (3, 3), padding=1)\n",
    "        self.conv2 = nn.Conv2d(out_configs[0], out_configs[1], (3, 3), padding=1)\n",
    "\n",
    "        if len(out_configs) == 3:\n",
    "            self.conv3 = nn.Conv2d(out_configs[1], out_configs[2], (3, 3), padding=1)\n",
    "        else:\n",
    "            self.conv3 = None\n",
    "\n",
    "        self.pool = nn.MaxPool2d((2, 2), padding=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.relu(self.conv1(x))\n",
    "        x = torch.relu(self.conv2(x))\n",
    "\n",
    "        if self.conv3 is not None:\n",
    "            x = torch.relu(self.conv3(x))\n",
    "\n",
    "        out = self.pool(x)\n",
    "\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VGG16(pl.LightningModule):\n",
    "    def __init__(self, in_channels: int, num_classes: int):\n",
    "        super(VGG16, self).__init__()\n",
    "\n",
    "        self.cb1 = ConvBlock(in_channels, [64, 64])\n",
    "        self.cb2 = ConvBlock(64, [128, 128])\n",
    "        self.cb3 = ConvBlock(128, [256, 256, 256])\n",
    "        self.cb4 = ConvBlock(256, [512, 512, 512])\n",
    "        self.cb5 = ConvBlock(512, [512, 512, 512])\n",
    "\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(512 * 2 * 2, 4096),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(4096, 4096),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(4096, num_classes),\n",
    "        )\n",
    "\n",
    "        self.loss = nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.cb1(x)\n",
    "        x = self.cb2(x)\n",
    "        x = self.cb3(x)\n",
    "        x = self.cb4(x)\n",
    "        x = self.cb5(x)\n",
    "\n",
    "        return self.fc(x)\n",
    "\n",
    "    def training_step(self, xb, batch_idx):\n",
    "        inp, labels = xb\n",
    "        out = self(inp)\n",
    "\n",
    "        return self.loss(out, labels)\n",
    "\n",
    "    def validation_step(self, xb, batch_idx):\n",
    "        inp, labels = xb\n",
    "        out = self(inp)\n",
    "\n",
    "        labels_hat = torch.argmax(out, dim=1)\n",
    "        val_acc = torch.sum(labels == labels_hat).item() / (len(labels) * 1.0)\n",
    "\n",
    "        self.log(\"val_loss\", self.loss(out, labels))\n",
    "        self.log(\"val_acc\", val_acc)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.Adam(self.parameters(), lr=2e-4)"
   ]
  },
  {
   "source": [
    "# run"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": [
     "outputPrepend"
    ]
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": [
      "███▉ | 978/1095 [06:28<00:46,  2.52it/s, loss=2.3, v_num=0]\n",
      "Validating:  25%|██▌       | 40/157 [00:04<00:12,  9.54it/s]\u001b[A\n",
      "Epoch 1:  89%|████████▉ | 980/1095 [06:28<00:45,  2.52it/s, loss=2.3, v_num=0]\n",
      "Validating:  27%|██▋       | 42/157 [00:04<00:12,  9.50it/s]\u001b[A\n",
      "Epoch 1:  90%|████████▉ | 982/1095 [06:28<00:44,  2.52it/s, loss=2.3, v_num=0]\n",
      "Validating:  28%|██▊       | 44/157 [00:04<00:11,  9.54it/s]\u001b[A\n",
      "Epoch 1:  90%|████████▉ | 984/1095 [06:29<00:43,  2.53it/s, loss=2.3, v_num=0]\n",
      "Validating:  29%|██▉       | 46/157 [00:05<00:11,  9.54it/s]\u001b[A\n",
      "Epoch 1:  90%|█████████ | 986/1095 [06:29<00:43,  2.53it/s, loss=2.3, v_num=0]\n",
      "Validating:  31%|███       | 48/157 [00:05<00:11,  9.52it/s]\u001b[A\n",
      "Epoch 1:  90%|█████████ | 988/1095 [06:29<00:42,  2.54it/s, loss=2.3, v_num=0]\n",
      "Validating:  32%|███▏      | 50/157 [00:05<00:11,  9.55it/s]\u001b[A\n",
      "Epoch 1:  90%|█████████ | 990/1095 [06:29<00:41,  2.54it/s, loss=2.3, v_num=0]\n",
      "Validating:  33%|███▎      | 52/157 [00:05<00:11,  9.46it/s]\u001b[A\n",
      "Epoch 1:  91%|█████████ | 992/1095 [06:29<00:40,  2.54it/s, loss=2.3, v_num=0]\n",
      "Validating:  34%|███▍      | 54/157 [00:05<00:10,  9.52it/s]\u001b[A\n",
      "Epoch 1:  91%|█████████ | 994/1095 [06:30<00:39,  2.55it/s, loss=2.3, v_num=0]\n",
      "Validating:  36%|███▌      | 56/157 [00:06<00:10,  9.50it/s]\u001b[A\n",
      "Epoch 1:  91%|█████████ | 996/1095 [06:30<00:38,  2.55it/s, loss=2.3, v_num=0]\n",
      "Validating:  37%|███▋      | 58/157 [00:06<00:10,  9.42it/s]\u001b[A\n",
      "Epoch 1:  91%|█████████ | 998/1095 [06:30<00:37,  2.56it/s, loss=2.3, v_num=0]\n",
      "Validating:  38%|███▊      | 60/157 [00:06<00:10,  9.35it/s]\u001b[A\n",
      "Epoch 1:  91%|█████████▏| 1000/1095 [06:30<00:37,  2.56it/s, loss=2.3, v_num=0]\n",
      "Validating:  39%|███▉      | 62/157 [00:06<00:10,  9.44it/s]\u001b[A\n",
      "Epoch 1:  92%|█████████▏| 1002/1095 [06:31<00:36,  2.56it/s, loss=2.3, v_num=0]\n",
      "Validating:  41%|████      | 64/157 [00:07<00:09,  9.45it/s]\u001b[A\n",
      "Epoch 1:  92%|█████████▏| 1004/1095 [06:31<00:35,  2.57it/s, loss=2.3, v_num=0]\n",
      "Validating:  42%|████▏     | 66/157 [00:07<00:09,  9.48it/s]\u001b[A\n",
      "Epoch 1:  92%|█████████▏| 1006/1095 [06:31<00:34,  2.57it/s, loss=2.3, v_num=0]\n",
      "Validating:  43%|████▎     | 68/157 [00:07<00:09,  9.51it/s]\u001b[A\n",
      "Epoch 1:  92%|█████████▏| 1008/1095 [06:31<00:33,  2.57it/s, loss=2.3, v_num=0]\n",
      "Validating:  45%|████▍     | 70/157 [00:07<00:09,  9.44it/s]\u001b[A\n",
      "Epoch 1:  92%|█████████▏| 1010/1095 [06:31<00:32,  2.58it/s, loss=2.3, v_num=0]\n",
      "Validating:  46%|████▌     | 72/157 [00:07<00:08,  9.51it/s]\u001b[A\n",
      "Epoch 1:  92%|█████████▏| 1012/1095 [06:32<00:32,  2.58it/s, loss=2.3, v_num=0]\n",
      "Validating:  47%|████▋     | 74/157 [00:08<00:08,  9.52it/s]\u001b[A\n",
      "Epoch 1:  93%|█████████▎| 1014/1095 [06:32<00:31,  2.58it/s, loss=2.3, v_num=0]\n",
      "Validating:  48%|████▊     | 76/157 [00:08<00:08,  9.52it/s]\u001b[A\n",
      "Epoch 1:  93%|█████████▎| 1016/1095 [06:32<00:30,  2.59it/s, loss=2.3, v_num=0]\n",
      "Validating:  50%|████▉     | 78/157 [00:08<00:08,  9.55it/s]\u001b[A\n",
      "Epoch 1:  93%|█████████▎| 1018/1095 [06:32<00:29,  2.59it/s, loss=2.3, v_num=0]\n",
      "Validating:  51%|█████     | 80/157 [00:08<00:08,  9.50it/s]\u001b[A\n",
      "Epoch 1:  93%|█████████▎| 1020/1095 [06:32<00:28,  2.60it/s, loss=2.3, v_num=0]\n",
      "Validating:  52%|█████▏    | 82/157 [00:08<00:07,  9.55it/s]\u001b[A\n",
      "Epoch 1:  93%|█████████▎| 1022/1095 [06:33<00:28,  2.60it/s, loss=2.3, v_num=0]\n",
      "Validating:  54%|█████▎    | 84/157 [00:09<00:07,  9.57it/s]\u001b[A\n",
      "Epoch 1:  94%|█████████▎| 1024/1095 [06:33<00:27,  2.60it/s, loss=2.3, v_num=0]\n",
      "Validating:  55%|█████▍    | 86/157 [00:09<00:07,  9.56it/s]\u001b[A\n",
      "Epoch 1:  94%|█████████▎| 1026/1095 [06:33<00:26,  2.61it/s, loss=2.3, v_num=0]\n",
      "Validating:  56%|█████▌    | 88/157 [00:09<00:07,  9.56it/s]\u001b[A\n",
      "Epoch 1:  94%|█████████▍| 1028/1095 [06:33<00:25,  2.61it/s, loss=2.3, v_num=0]\n",
      "Validating:  57%|█████▋    | 90/157 [00:09<00:07,  9.48it/s]\u001b[A\n",
      "Epoch 1:  94%|█████████▍| 1030/1095 [06:33<00:24,  2.61it/s, loss=2.3, v_num=0]\n",
      "Validating:  59%|█████▊    | 92/157 [00:09<00:06,  9.51it/s]\u001b[A\n",
      "Epoch 1:  94%|█████████▍| 1032/1095 [06:34<00:24,  2.62it/s, loss=2.3, v_num=0]\n",
      "Validating:  60%|█████▉    | 94/157 [00:10<00:06,  9.54it/s]\u001b[A\n",
      "Epoch 1:  94%|█████████▍| 1034/1095 [06:34<00:23,  2.62it/s, loss=2.3, v_num=0]\n",
      "Validating:  61%|██████    | 96/157 [00:10<00:06,  9.55it/s]\u001b[A\n",
      "Epoch 1:  95%|█████████▍| 1036/1095 [06:34<00:22,  2.63it/s, loss=2.3, v_num=0]\n",
      "Validating:  62%|██████▏   | 98/157 [00:10<00:06,  9.56it/s]\u001b[A\n",
      "Epoch 1:  95%|█████████▍| 1038/1095 [06:34<00:21,  2.63it/s, loss=2.3, v_num=0]\n",
      "Validating:  64%|██████▎   | 100/157 [00:10<00:05,  9.55it/s]\u001b[A\n",
      "Epoch 1:  95%|█████████▍| 1040/1095 [06:35<00:20,  2.63it/s, loss=2.3, v_num=0]\n",
      "Validating:  65%|██████▍   | 102/157 [00:11<00:05,  9.55it/s]\u001b[A\n",
      "Epoch 1:  95%|█████████▌| 1042/1095 [06:35<00:20,  2.64it/s, loss=2.3, v_num=0]\n",
      "Validating:  66%|██████▌   | 104/157 [00:11<00:05,  9.58it/s]\u001b[A\n",
      "Epoch 1:  95%|█████████▌| 1044/1095 [06:35<00:19,  2.64it/s, loss=2.3, v_num=0]\n",
      "Validating:  68%|██████▊   | 106/157 [00:11<00:05,  9.57it/s]\u001b[A\n",
      "Epoch 1:  96%|█████████▌| 1046/1095 [06:35<00:18,  2.64it/s, loss=2.3, v_num=0]\n",
      "Validating:  69%|██████▉   | 108/157 [00:11<00:05,  9.57it/s]\u001b[A\n",
      "Epoch 1:  96%|█████████▌| 1048/1095 [06:35<00:17,  2.65it/s, loss=2.3, v_num=0]\n",
      "Validating:  70%|███████   | 110/157 [00:11<00:04,  9.55it/s]\u001b[A\n",
      "Epoch 1:  96%|█████████▌| 1050/1095 [06:36<00:16,  2.65it/s, loss=2.3, v_num=0]\n",
      "Validating:  71%|███████▏  | 112/157 [00:12<00:04,  9.47it/s]\u001b[A\n",
      "Epoch 1:  96%|█████████▌| 1052/1095 [06:36<00:16,  2.65it/s, loss=2.3, v_num=0]\n",
      "Validating:  73%|███████▎  | 114/157 [00:12<00:04,  9.53it/s]\u001b[A\n",
      "Epoch 1:  96%|█████████▋| 1054/1095 [06:36<00:15,  2.66it/s, loss=2.3, v_num=0]\n",
      "Validating:  74%|███████▍  | 116/157 [00:12<00:04,  9.52it/s]\u001b[A\n",
      "Epoch 1:  96%|█████████▋| 1056/1095 [06:36<00:14,  2.66it/s, loss=2.3, v_num=0]\n",
      "Validating:  75%|███████▌  | 118/157 [00:12<00:04,  9.52it/s]\u001b[A\n",
      "Epoch 1:  97%|█████████▋| 1058/1095 [06:36<00:13,  2.67it/s, loss=2.3, v_num=0]\n",
      "Validating:  76%|███████▋  | 120/157 [00:12<00:03,  9.53it/s]\u001b[A\n",
      "Epoch 1:  97%|█████████▋| 1060/1095 [06:37<00:13,  2.67it/s, loss=2.3, v_num=0]\n",
      "Validating:  78%|███████▊  | 122/157 [00:13<00:03,  9.51it/s]\u001b[A\n",
      "Epoch 1:  97%|█████████▋| 1062/1095 [06:37<00:12,  2.67it/s, loss=2.3, v_num=0]\n",
      "Validating:  79%|███████▉  | 124/157 [00:13<00:03,  9.56it/s]\u001b[A\n",
      "Epoch 1:  97%|█████████▋| 1064/1095 [06:37<00:11,  2.68it/s, loss=2.3, v_num=0]\n",
      "Validating:  80%|████████  | 126/157 [00:13<00:03,  9.56it/s]\u001b[A\n",
      "Epoch 1:  97%|█████████▋| 1066/1095 [06:37<00:10,  2.68it/s, loss=2.3, v_num=0]\n",
      "Validating:  82%|████████▏ | 128/157 [00:13<00:03,  9.56it/s]\u001b[A\n",
      "Epoch 1:  98%|█████████▊| 1068/1095 [06:37<00:10,  2.68it/s, loss=2.3, v_num=0]\n",
      "Validating:  83%|████████▎ | 130/157 [00:13<00:02,  9.47it/s]\u001b[A\n",
      "Epoch 1:  98%|█████████▊| 1070/1095 [06:38<00:09,  2.69it/s, loss=2.3, v_num=0]\n",
      "Validating:  84%|████████▍ | 132/157 [00:14<00:02,  9.42it/s]\u001b[A\n",
      "Epoch 1:  98%|█████████▊| 1072/1095 [06:38<00:08,  2.69it/s, loss=2.3, v_num=0]\n",
      "Validating:  85%|████████▌ | 134/157 [00:14<00:02,  9.38it/s]\u001b[A\n",
      "Epoch 1:  98%|█████████▊| 1074/1095 [06:38<00:07,  2.69it/s, loss=2.3, v_num=0]\n",
      "Validating:  87%|████████▋ | 136/157 [00:14<00:02,  9.40it/s]\u001b[A\n",
      "Epoch 1:  98%|█████████▊| 1076/1095 [06:38<00:07,  2.70it/s, loss=2.3, v_num=0]\n",
      "Validating:  88%|████████▊ | 138/157 [00:14<00:02,  9.47it/s]\u001b[A\n",
      "Epoch 1:  98%|█████████▊| 1078/1095 [06:39<00:06,  2.70it/s, loss=2.3, v_num=0]\n",
      "Validating:  89%|████████▉ | 140/157 [00:15<00:01,  9.44it/s]\u001b[A\n",
      "Epoch 1:  99%|█████████▊| 1080/1095 [06:39<00:05,  2.71it/s, loss=2.3, v_num=0]\n",
      "Validating:  90%|█████████ | 142/157 [00:15<00:01,  9.49it/s]\u001b[A\n",
      "Epoch 1:  99%|█████████▉| 1082/1095 [06:39<00:04,  2.71it/s, loss=2.3, v_num=0]\n",
      "Validating:  92%|█████████▏| 144/157 [00:15<00:01,  9.54it/s]\u001b[A\n",
      "Epoch 1:  99%|█████████▉| 1084/1095 [06:39<00:04,  2.71it/s, loss=2.3, v_num=0]\n",
      "Validating:  93%|█████████▎| 146/157 [00:15<00:01,  9.52it/s]\u001b[A\n",
      "Epoch 1:  99%|█████████▉| 1086/1095 [06:39<00:03,  2.72it/s, loss=2.3, v_num=0]\n",
      "Validating:  94%|█████████▍| 148/157 [00:15<00:00,  9.54it/s]\u001b[A\n",
      "Epoch 1:  99%|█████████▉| 1088/1095 [06:40<00:02,  2.72it/s, loss=2.3, v_num=0]\n",
      "Validating:  96%|█████████▌| 150/157 [00:16<00:00,  9.49it/s]\u001b[A\n",
      "Epoch 1: 100%|█████████▉| 1090/1095 [06:40<00:01,  2.72it/s, loss=2.3, v_num=0]\n",
      "Validating:  97%|█████████▋| 152/157 [00:16<00:00,  9.45it/s]\u001b[A\n",
      "Epoch 1: 100%|█████████▉| 1092/1095 [06:40<00:01,  2.73it/s, loss=2.3, v_num=0]\n",
      "Validating:  98%|█████████▊| 154/157 [00:16<00:00,  9.50it/s]\u001b[A\n",
      "Epoch 1: 100%|█████████▉| 1094/1095 [06:40<00:00,  2.73it/s, loss=2.3, v_num=0]\n",
      "Epoch 1: 100%|██████████| 1095/1095 [06:40<00:00,  2.73it/s, loss=2.3, v_num=0]\n",
      "Epoch 2:  86%|████████▌ | 938/1095 [06:06<01:01,  2.56it/s, loss=2.3, v_num=0]\n",
      "Validating: 0it [00:00, ?it/s]\u001b[A\n",
      "Validating:   0%|          | 0/157 [00:00<?, ?it/s]\u001b[A\n",
      "Epoch 2:  86%|████████▌ | 940/1095 [06:07<01:00,  2.56it/s, loss=2.3, v_num=0]\n",
      "Validating:   1%|▏         | 2/157 [00:00<00:37,  4.15it/s]\u001b[A\n",
      "Epoch 2:  86%|████████▌ | 942/1095 [06:07<00:59,  2.57it/s, loss=2.3, v_num=0]\n",
      "Validating:   3%|▎         | 4/157 [00:00<00:22,  6.69it/s]\u001b[A\n",
      "Epoch 2:  86%|████████▌ | 944/1095 [06:07<00:58,  2.57it/s, loss=2.3, v_num=0]\n",
      "Validating:   4%|▍         | 6/157 [00:00<00:18,  8.10it/s]\u001b[A\n",
      "Epoch 2:  86%|████████▋ | 946/1095 [06:07<00:57,  2.57it/s, loss=2.3, v_num=0]\n",
      "Validating:   5%|▌         | 8/157 [00:01<00:17,  8.61it/s]\u001b[A\n",
      "Epoch 2:  87%|████████▋ | 948/1095 [06:07<00:57,  2.58it/s, loss=2.3, v_num=0]\n",
      "Validating:   6%|▋         | 10/157 [00:01<00:16,  9.04it/s]\u001b[A\n",
      "Epoch 2:  87%|████████▋ | 950/1095 [06:08<00:56,  2.58it/s, loss=2.3, v_num=0]\n",
      "Validating:   8%|▊         | 12/157 [00:01<00:15,  9.27it/s]\u001b[A\n",
      "Epoch 2:  87%|████████▋ | 952/1095 [06:08<00:55,  2.58it/s, loss=2.3, v_num=0]\n",
      "Validating:   9%|▉         | 14/157 [00:01<00:15,  9.43it/s]\u001b[A\n",
      "Epoch 2:  87%|████████▋ | 954/1095 [06:08<00:54,  2.59it/s, loss=2.3, v_num=0]\n",
      "Validating:  10%|█         | 16/157 [00:02<00:14,  9.47it/s]\u001b[A\n",
      "Epoch 2:  87%|████████▋ | 956/1095 [06:08<00:53,  2.59it/s, loss=2.3, v_num=0]\n",
      "Validating:  11%|█▏        | 18/157 [00:02<00:14,  9.41it/s]\u001b[A\n",
      "Epoch 2:  87%|████████▋ | 958/1095 [06:08<00:52,  2.60it/s, loss=2.3, v_num=0]\n",
      "Validating:  13%|█▎        | 20/157 [00:02<00:14,  9.43it/s]\u001b[A\n",
      "Epoch 2:  88%|████████▊ | 960/1095 [06:09<00:51,  2.60it/s, loss=2.3, v_num=0]\n",
      "Validating:  14%|█▍        | 22/157 [00:02<00:14,  9.41it/s]\u001b[A\n",
      "Epoch 2:  88%|████████▊ | 962/1095 [06:09<00:51,  2.60it/s, loss=2.3, v_num=0]\n",
      "Validating:  15%|█▌        | 24/157 [00:02<00:14,  9.40it/s]\u001b[A\n",
      "Epoch 2:  88%|████████▊ | 964/1095 [06:09<00:50,  2.61it/s, loss=2.3, v_num=0]\n",
      "Validating:  17%|█▋        | 26/157 [00:03<00:13,  9.38it/s]\u001b[A\n",
      "Epoch 2:  88%|████████▊ | 966/1095 [06:09<00:49,  2.61it/s, loss=2.3, v_num=0]\n",
      "Validating:  18%|█▊        | 28/157 [00:03<00:13,  9.42it/s]\u001b[A\n",
      "Epoch 2:  88%|████████▊ | 968/1095 [06:09<00:48,  2.62it/s, loss=2.3, v_num=0]\n",
      "Validating:  19%|█▉        | 30/157 [00:03<00:13,  9.44it/s]\u001b[A\n",
      "Epoch 2:  89%|████████▊ | 970/1095 [06:10<00:47,  2.62it/s, loss=2.3, v_num=0]\n",
      "Validating:  20%|██        | 32/157 [00:03<00:13,  9.42it/s]\u001b[A\n",
      "Epoch 2:  89%|████████▉ | 972/1095 [06:10<00:46,  2.62it/s, loss=2.3, v_num=0]\n",
      "Validating:  22%|██▏       | 34/157 [00:03<00:13,  9.39it/s]\u001b[A\n",
      "Epoch 2:  89%|████████▉ | 974/1095 [06:10<00:46,  2.63it/s, loss=2.3, v_num=0]\n",
      "Validating:  23%|██▎       | 36/157 [00:04<00:12,  9.43it/s]\u001b[A\n",
      "Epoch 2:  89%|████████▉ | 976/1095 [06:10<00:45,  2.63it/s, loss=2.3, v_num=0]\n",
      "Validating:  24%|██▍       | 38/157 [00:04<00:12,  9.42it/s]\u001b[A\n",
      "Epoch 2:  89%|████████▉ | 978/1095 [06:11<00:44,  2.64it/s, loss=2.3, v_num=0]\n",
      "Validating:  25%|██▌       | 40/157 [00:04<00:12,  9.42it/s]\u001b[A\n",
      "Epoch 2:  89%|████████▉ | 980/1095 [06:11<00:43,  2.64it/s, loss=2.3, v_num=0]\n",
      "Validating:  27%|██▋       | 42/157 [00:04<00:12,  9.39it/s]\u001b[A\n",
      "Epoch 2:  90%|████████▉ | 982/1095 [06:11<00:42,  2.64it/s, loss=2.3, v_num=0]\n",
      "Validating:  28%|██▊       | 44/157 [00:04<00:12,  9.41it/s]\u001b[A\n",
      "Epoch 2:  90%|████████▉ | 984/1095 [06:11<00:41,  2.65it/s, loss=2.3, v_num=0]\n",
      "Validating:  29%|██▉       | 46/157 [00:05<00:11,  9.42it/s]\u001b[A\n",
      "Epoch 2:  90%|█████████ | 986/1095 [06:11<00:41,  2.65it/s, loss=2.3, v_num=0]\n",
      "Validating:  31%|███       | 48/157 [00:05<00:11,  9.41it/s]\u001b[A\n",
      "Epoch 2:  90%|█████████ | 988/1095 [06:12<00:40,  2.66it/s, loss=2.3, v_num=0]\n",
      "Validating:  32%|███▏      | 50/157 [00:05<00:11,  9.40it/s]\u001b[A\n",
      "Epoch 2:  90%|█████████ | 990/1095 [06:12<00:39,  2.66it/s, loss=2.3, v_num=0]\n",
      "Validating:  33%|███▎      | 52/157 [00:05<00:11,  9.41it/s]\u001b[A\n",
      "Epoch 2:  91%|█████████ | 992/1095 [06:12<00:38,  2.66it/s, loss=2.3, v_num=0]\n",
      "Validating:  34%|███▍      | 54/157 [00:06<00:10,  9.44it/s]\u001b[A\n",
      "Epoch 2:  91%|█████████ | 994/1095 [06:12<00:37,  2.67it/s, loss=2.3, v_num=0]\n",
      "Validating:  36%|███▌      | 56/157 [00:06<00:10,  9.40it/s]\u001b[A\n",
      "Epoch 2:  91%|█████████ | 996/1095 [06:12<00:37,  2.67it/s, loss=2.3, v_num=0]\n",
      "Validating:  37%|███▋      | 58/157 [00:06<00:10,  9.41it/s]\u001b[A\n",
      "Epoch 2:  91%|█████████ | 998/1095 [06:13<00:36,  2.67it/s, loss=2.3, v_num=0]\n",
      "Validating:  38%|███▊      | 60/157 [00:06<00:10,  9.42it/s]\u001b[A\n",
      "Epoch 2:  91%|█████████▏| 1000/1095 [06:13<00:35,  2.68it/s, loss=2.3, v_num=0]\n",
      "Validating:  39%|███▉      | 62/157 [00:06<00:10,  9.43it/s]\u001b[A\n",
      "Epoch 2:  92%|█████████▏| 1002/1095 [06:13<00:34,  2.68it/s, loss=2.3, v_num=0]\n",
      "Validating:  41%|████      | 64/157 [00:07<00:09,  9.43it/s]\u001b[A\n",
      "Epoch 2:  92%|█████████▏| 1004/1095 [06:13<00:33,  2.69it/s, loss=2.3, v_num=0]\n",
      "Validating:  42%|████▏     | 66/157 [00:07<00:09,  9.39it/s]\u001b[A\n",
      "Epoch 2:  92%|█████████▏| 1006/1095 [06:14<00:33,  2.69it/s, loss=2.3, v_num=0]\n",
      "Validating:  43%|████▎     | 68/157 [00:07<00:09,  9.40it/s]\u001b[A\n",
      "Epoch 2:  92%|█████████▏| 1008/1095 [06:14<00:32,  2.69it/s, loss=2.3, v_num=0]\n",
      "Validating:  45%|████▍     | 70/157 [00:07<00:09,  9.41it/s]\u001b[A\n",
      "Epoch 2:  92%|█████████▏| 1010/1095 [06:14<00:31,  2.70it/s, loss=2.3, v_num=0]\n",
      "Validating:  46%|████▌     | 72/157 [00:07<00:09,  9.40it/s]\u001b[A\n",
      "Epoch 2:  92%|█████████▏| 1012/1095 [06:14<00:30,  2.70it/s, loss=2.3, v_num=0]\n",
      "Validating:  47%|████▋     | 74/157 [00:08<00:08,  9.37it/s]\u001b[A\n",
      "Epoch 2:  93%|█████████▎| 1014/1095 [06:14<00:29,  2.70it/s, loss=2.3, v_num=0]\n",
      "Validating:  48%|████▊     | 76/157 [00:08<00:08,  9.42it/s]\u001b[A\n",
      "Epoch 2:  93%|█████████▎| 1016/1095 [06:15<00:29,  2.71it/s, loss=2.3, v_num=0]\n",
      "Validating:  50%|████▉     | 78/157 [00:08<00:08,  9.43it/s]\u001b[A\n",
      "Epoch 2:  93%|█████████▎| 1018/1095 [06:15<00:28,  2.71it/s, loss=2.3, v_num=0]\n",
      "Validating:  51%|█████     | 80/157 [00:08<00:08,  9.43it/s]\u001b[A\n",
      "Epoch 2:  93%|█████████▎| 1020/1095 [06:15<00:27,  2.72it/s, loss=2.3, v_num=0]\n",
      "Validating:  52%|█████▏    | 82/157 [00:09<00:08,  9.21it/s]\u001b[A\n",
      "Epoch 2:  93%|█████████▎| 1022/1095 [06:15<00:26,  2.72it/s, loss=2.3, v_num=0]\n",
      "Validating:  54%|█████▎    | 84/157 [00:09<00:07,  9.30it/s]\u001b[A\n",
      "Epoch 2:  94%|█████████▎| 1024/1095 [06:15<00:26,  2.72it/s, loss=2.3, v_num=0]\n",
      "Validating:  55%|█████▍    | 86/157 [00:09<00:07,  9.35it/s]\u001b[A\n",
      "Epoch 2:  94%|█████████▎| 1026/1095 [06:16<00:25,  2.73it/s, loss=2.3, v_num=0]\n",
      "Validating:  56%|█████▌    | 88/157 [00:09<00:07,  9.24it/s]\u001b[A\n",
      "Epoch 2:  94%|█████████▍| 1028/1095 [06:16<00:24,  2.73it/s, loss=2.3, v_num=0]\n",
      "Validating:  57%|█████▋    | 90/157 [00:09<00:07,  9.26it/s]\u001b[A\n",
      "Epoch 2:  94%|█████████▍| 1030/1095 [06:16<00:23,  2.74it/s, loss=2.3, v_num=0]\n",
      "Validating:  59%|█████▊    | 92/157 [00:10<00:07,  9.28it/s]\u001b[A\n",
      "Epoch 2:  94%|█████████▍| 1032/1095 [06:16<00:23,  2.74it/s, loss=2.3, v_num=0]\n",
      "Validating:  60%|█████▉    | 94/157 [00:10<00:06,  9.28it/s]\u001b[A\n",
      "Epoch 2:  94%|█████████▍| 1034/1095 [06:17<00:22,  2.74it/s, loss=2.3, v_num=0]\n",
      "Validating:  61%|██████    | 96/157 [00:10<00:06,  9.30it/s]\u001b[A\n",
      "Epoch 2:  95%|█████████▍| 1036/1095 [06:17<00:21,  2.75it/s, loss=2.3, v_num=0]\n",
      "Validating:  62%|██████▏   | 98/157 [00:10<00:06,  9.36it/s]\u001b[A\n",
      "Epoch 2:  95%|█████████▍| 1038/1095 [06:17<00:20,  2.75it/s, loss=2.3, v_num=0]\n",
      "Validating:  64%|██████▎   | 100/157 [00:10<00:06,  9.34it/s]\u001b[A\n",
      "Epoch 2:  95%|█████████▍| 1040/1095 [06:17<00:19,  2.75it/s, loss=2.3, v_num=0]\n",
      "Validating:  65%|██████▍   | 102/157 [00:11<00:05,  9.34it/s]\u001b[A\n",
      "Epoch 2:  95%|█████████▌| 1042/1095 [06:17<00:19,  2.76it/s, loss=2.3, v_num=0]\n",
      "Validating:  66%|██████▌   | 104/157 [00:11<00:05,  9.36it/s]\u001b[A\n",
      "Epoch 2:  95%|█████████▌| 1044/1095 [06:18<00:18,  2.76it/s, loss=2.3, v_num=0]\n",
      "Validating:  68%|██████▊   | 106/157 [00:11<00:05,  9.32it/s]\u001b[A\n",
      "Epoch 2:  96%|█████████▌| 1046/1095 [06:18<00:17,  2.76it/s, loss=2.3, v_num=0]\n",
      "Validating:  69%|██████▉   | 108/157 [00:11<00:05,  9.30it/s]\u001b[A\n",
      "Epoch 2:  96%|█████████▌| 1048/1095 [06:18<00:16,  2.77it/s, loss=2.3, v_num=0]\n",
      "Validating:  70%|███████   | 110/157 [00:12<00:05,  9.30it/s]\u001b[A\n",
      "Epoch 2:  96%|█████████▌| 1050/1095 [06:18<00:16,  2.77it/s, loss=2.3, v_num=0]\n",
      "Validating:  71%|███████▏  | 112/157 [00:12<00:04,  9.28it/s]\u001b[A\n",
      "Epoch 2:  96%|█████████▌| 1052/1095 [06:18<00:15,  2.78it/s, loss=2.3, v_num=0]\n",
      "Validating:  73%|███████▎  | 114/157 [00:12<00:04,  9.31it/s]\u001b[A\n",
      "Epoch 2:  96%|█████████▋| 1054/1095 [06:19<00:14,  2.78it/s, loss=2.3, v_num=0]\n",
      "Validating:  74%|███████▍  | 116/157 [00:12<00:04,  9.36it/s]\u001b[A\n",
      "Epoch 2:  96%|█████████▋| 1056/1095 [06:19<00:14,  2.78it/s, loss=2.3, v_num=0]\n",
      "Validating:  75%|███████▌  | 118/157 [00:12<00:04,  9.37it/s]\u001b[A\n",
      "Epoch 2:  97%|█████████▋| 1058/1095 [06:19<00:13,  2.79it/s, loss=2.3, v_num=0]\n",
      "Validating:  76%|███████▋  | 120/157 [00:13<00:03,  9.35it/s]\u001b[A\n",
      "Epoch 2:  97%|█████████▋| 1060/1095 [06:19<00:12,  2.79it/s, loss=2.3, v_num=0]\n",
      "Validating:  78%|███████▊  | 122/157 [00:13<00:03,  9.35it/s]\u001b[A\n",
      "Epoch 2:  97%|█████████▋| 1062/1095 [06:20<00:11,  2.79it/s, loss=2.3, v_num=0]\n",
      "Validating:  79%|███████▉  | 124/157 [00:13<00:03,  9.39it/s]\u001b[A\n",
      "Epoch 2:  97%|█████████▋| 1064/1095 [06:20<00:11,  2.80it/s, loss=2.3, v_num=0]\n",
      "Validating:  80%|████████  | 126/157 [00:13<00:03,  9.41it/s]\u001b[A\n",
      "Epoch 2:  97%|█████████▋| 1066/1095 [06:20<00:10,  2.80it/s, loss=2.3, v_num=0]\n",
      "Validating:  82%|████████▏ | 128/157 [00:13<00:03,  9.35it/s]\u001b[A\n",
      "Epoch 2:  98%|█████████▊| 1068/1095 [06:20<00:09,  2.81it/s, loss=2.3, v_num=0]\n",
      "Validating:  83%|████████▎ | 130/157 [00:14<00:02,  9.36it/s]\u001b[A\n",
      "Epoch 2:  98%|█████████▊| 1070/1095 [06:20<00:08,  2.81it/s, loss=2.3, v_num=0]\n",
      "Validating:  84%|████████▍ | 132/157 [00:14<00:02,  9.38it/s]\u001b[A\n",
      "Epoch 2:  98%|█████████▊| 1072/1095 [06:21<00:08,  2.81it/s, loss=2.3, v_num=0]\n",
      "Validating:  85%|████████▌ | 134/157 [00:14<00:02,  9.39it/s]\u001b[A\n",
      "Epoch 2:  98%|█████████▊| 1074/1095 [06:21<00:07,  2.82it/s, loss=2.3, v_num=0]\n",
      "Validating:  87%|████████▋ | 136/157 [00:14<00:02,  9.38it/s]\u001b[A\n",
      "Epoch 2:  98%|█████████▊| 1076/1095 [06:21<00:06,  2.82it/s, loss=2.3, v_num=0]\n",
      "Validating:  88%|████████▊ | 138/157 [00:15<00:02,  9.40it/s]\u001b[A\n",
      "Epoch 2:  98%|█████████▊| 1078/1095 [06:21<00:06,  2.82it/s, loss=2.3, v_num=0]\n",
      "Validating:  89%|████████▉ | 140/157 [00:15<00:01,  9.37it/s]\u001b[A\n",
      "Epoch 2:  99%|█████████▊| 1080/1095 [06:21<00:05,  2.83it/s, loss=2.3, v_num=0]\n",
      "Validating:  90%|█████████ | 142/157 [00:15<00:01,  9.37it/s]\u001b[A\n",
      "Epoch 2:  99%|█████████▉| 1082/1095 [06:22<00:04,  2.83it/s, loss=2.3, v_num=0]\n",
      "Validating:  92%|█████████▏| 144/157 [00:15<00:01,  9.33it/s]\u001b[A\n",
      "Epoch 2:  99%|█████████▉| 1084/1095 [06:22<00:03,  2.84it/s, loss=2.3, v_num=0]\n",
      "Validating:  93%|█████████▎| 146/157 [00:15<00:01,  9.36it/s]\u001b[A\n",
      "Epoch 2:  99%|█████████▉| 1086/1095 [06:22<00:03,  2.84it/s, loss=2.3, v_num=0]\n",
      "Validating:  94%|█████████▍| 148/157 [00:16<00:00,  9.39it/s]\u001b[A\n",
      "Epoch 2:  99%|█████████▉| 1088/1095 [06:22<00:02,  2.84it/s, loss=2.3, v_num=0]\n",
      "Validating:  96%|█████████▌| 150/157 [00:16<00:00,  9.42it/s]\u001b[A\n",
      "Epoch 2: 100%|█████████▉| 1090/1095 [06:22<00:01,  2.85it/s, loss=2.3, v_num=0]\n",
      "Validating:  97%|█████████▋| 152/157 [00:16<00:00,  9.45it/s]\u001b[A\n",
      "Epoch 2: 100%|█████████▉| 1092/1095 [06:23<00:01,  2.85it/s, loss=2.3, v_num=0]\n",
      "Validating:  98%|█████████▊| 154/157 [00:16<00:00,  9.43it/s]\u001b[A\n",
      "Epoch 2: 100%|█████████▉| 1094/1095 [06:23<00:00,  2.85it/s, loss=2.3, v_num=0]\n",
      "Epoch 2: 100%|██████████| 1095/1095 [06:23<00:00,  2.85it/s, loss=2.3, v_num=0]\n",
      "Epoch 2: 100%|██████████| 1095/1095 [06:24<00:00,  2.84it/s, loss=2.3, v_num=0]\n"
     ]
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "metadata": {},
     "execution_count": 4
    }
   ],
   "source": [
    "num_classes = 10\n",
    "in_channels = 1\n",
    "epochs = 3\n",
    "model = VGG16(in_channels, num_classes)\n",
    "\n",
    "print(summary(model, input_size=(2, in_channels, 28, 28)))\n",
    "\n",
    "trainer = pl.Trainer(\n",
    "    default_root_dir=\"logs\",\n",
    "    gpus=(1 if torch.cuda.is_available() else 0),\n",
    "    max_epochs=epochs,\n",
    "    logger=pl.loggers.TensorBoardLogger(\"logs/\", name=\"vgg\", version=0),\n",
    ")\n",
    "\n",
    "trainer.fit(model, train_dataloader=data.train_dl, val_dataloaders=data.val_dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}