Skip to content

Commit

Permalink
python 3.6
Browse files Browse the repository at this point in the history
  • Loading branch information
sergeyk committed Jan 19, 2021
1 parent d914477 commit da0b75e
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 45 deletions.
5 changes: 0 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
data/downloaded
data/processed
data/interim
data/raw/emnist/matlab*
data/raw/fsdl_handwriting/pages
data/raw/iam/iamdb
data/raw/iam/iamdb.zip
data/raw/nltk

# Logs
training/logs
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: fsdl-text-recognizer-2021
channels:
- defaults
dependencies:
- python=3.8
- python=3.6 # Google Colab is still on Python 3.6
- cudatoolkit=10.1
- cudnn=7.6
- pip
Expand Down
2 changes: 1 addition & 1 deletion lab1/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ We show:
- Basic directory layout
- PyTorch MLP and LeNet models
- PyTorch-Lightning based training
- A single point of entry to running experiments: `python training/run_experiment.py`
- A single point of entry to running experiments: `python run_experiment.py`
- `python training/run_experiment.py --max_epochs=10 --gpus='0,1' --accelerator=ddp --num_workers=20 --model_class=lenet.LeNet`
- Logs in Tensorboard: `tensorboard --logdir=training/logs`

Expand Down
14 changes: 9 additions & 5 deletions lab1/text_recognizer/data/base_data_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Base DataModule class."""
from pathlib import Path
from typing import Dict
import argparse
import os

Expand All @@ -21,15 +22,18 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)


def _download_raw_dataset(metadata):
if os.path.exists(metadata["filename"]):
def _download_raw_dataset(metadata: Dict, dl_dirname: Path) -> Path:
dl_dirname.mkdir(parents=True, exist_ok=True)
filename = dl_dirname / metadata["filename"]
if filename.exists():
return
print(f"Downloading raw dataset from {metadata['url']}...")
util.download_url(metadata["url"], metadata["filename"])
print(f"Downloading raw dataset from {metadata['url']} to {filename}...")
util.download_url(metadata["url"], filename)
print("Computing SHA-256...")
sha256 = util.compute_sha256(metadata["filename"])
sha256 = util.compute_sha256(filename)
if sha256 != metadata["sha256"]:
raise ValueError("Downloaded data file SHA-256 does not match that listed in metadata document.")
return filename


BATCH_SIZE = 128
Expand Down
10 changes: 8 additions & 2 deletions lab1/text_recognizer/data/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from torchvision.datasets import MNIST
from torchvision import transforms

from text_recognizer.data import BaseDataModule
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info

DOWNLOADED_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded"


class MNISTDataModule(BaseDataModule):
Expand All @@ -16,7 +18,7 @@ class MNISTDataModule(BaseDataModule):

def __init__(self, args: argparse.Namespace) -> None:
super().__init__(args)
self.data_dir = "./data/downloaded"
self.data_dir = DOWNLOADED_DATA_DIRNAME
self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
self.dims = (1, 28, 28) # dims are returned when calling `.size()` on this object.
self.output_dims = (1,)
Expand All @@ -32,3 +34,7 @@ def setup(self, stage=None):
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.data_train, self.data_val = random_split(mnist_full, [55000, 5000])
self.data_test = MNIST(self.data_dir, train=False, transform=self.transform)


if __name__ == "__main__":
load_and_print_info(MNISTDataModule)
4 changes: 2 additions & 2 deletions lab1/text_recognizer/models/mlp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, Dict
import argparse
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -21,7 +21,7 @@ def __init__(
super().__init__()
self.args = vars(args) if args is not None else {}

input_dim = math.prod(data_config["input_dims"])
input_dim = np.prod(data_config["input_dims"])
num_classes = len(data_config["mapping"])

fc1_dim = self.args.get("fc1", FC1_DIM)
Expand Down
80 changes: 55 additions & 25 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ argon2-cffi==20.1.0
astroid==2.4.2
# via pylint
async-generator==1.10
# via nbclient
# via
# anyio
# nbclient
attrs==20.3.0
# via
# -c requirements/prod.txt
Expand Down Expand Up @@ -49,10 +51,19 @@ click==7.1.2
# wandb
configparser==5.0.1
# via wandb
contextvars==2.4
# via sniffio
cycler==0.10.0
# via matplotlib
dataclasses==0.8
# via
# -c requirements/prod.txt
# anyio
# black
decorator==4.4.2
# via ipython
# via
# ipython
# traitlets
defusedxml==0.6.0
# via nbconvert
docker-pycreds==0.4.0
Expand All @@ -67,7 +78,7 @@ entrypoints==0.3
# via nbconvert
gitdb==4.0.5
# via gitpython
gitpython==3.1.11
gitpython==3.1.12
# via
# bandit
# wandb
Expand All @@ -76,21 +87,30 @@ idna==2.10
# -c requirements/prod.txt
# anyio
# requests
immutables==0.14
# via contextvars
importlib-metadata==3.4.0
# via
# -c requirements/prod.txt
# jsonschema
# pluggy
# pytest
# stevedore
iniconfig==1.1.1
# via pytest
ipykernel==5.4.2
ipykernel==5.4.3
# via notebook
ipython-genutils==0.2.0
# via
# jupyter-server
# nbformat
# notebook
# traitlets
ipython==7.19.0
ipython==7.16.1
# via
# ipykernel
# jupyterlab
isort==5.6.4
isort==5.7.0
# via pylint
itermplot==0.331
# via -r requirements/dev.in
Expand All @@ -112,7 +132,7 @@ jsonschema==3.2.0
# via
# jupyterlab-server
# nbformat
jupyter-client==6.1.7
jupyter-client==6.1.11
# via
# ipykernel
# jupyter-server
Expand All @@ -126,16 +146,16 @@ jupyter-core==4.7.0
# nbconvert
# nbformat
# notebook
jupyter-server==1.1.3
jupyter-server==1.2.2
# via
# jupyterlab
# jupyterlab-server
# nbclassic
jupyterlab-pygments==0.1.2
# via nbconvert
jupyterlab-server==2.0.0
jupyterlab-server==2.1.2
# via jupyterlab
jupyterlab==3.0.0
jupyterlab==3.0.5
# via -r requirements/dev.in
kiwisolver==1.3.1
# via matplotlib
Expand All @@ -159,15 +179,15 @@ mypy-extensions==0.4.3
# mypy
mypy==0.790
# via -r requirements/dev.in
nbclassic==0.2.5
nbclassic==0.2.6
# via jupyterlab
nbclient==0.5.1
# via nbconvert
nbconvert==6.0.7
# via
# jupyter-server
# notebook
nbformat==5.0.8
nbformat==5.1.2
# via
# jupyter-server
# nbclient
Expand All @@ -177,9 +197,9 @@ nest-asyncio==1.4.3
# via nbclient
nltk==3.5
# via -r requirements/dev.in
notebook==6.1.6
notebook==6.2.0
# via nbclassic
numpy==1.19.4
numpy==1.19.5
# via
# -c requirements/prod.txt
# itermplot
Expand All @@ -199,13 +219,15 @@ parso==0.8.1
# via jedi
pathspec==0.8.1
# via black
pathtools==0.1.2
# via watchdog
pbr==5.5.1
# via stevedore
pexpect==4.8.0
# via ipython
pickleshare==0.7.5
# via ipython
pillow==8.0.1
pillow==8.1.0
# via
# -c requirements/prod.txt
# -r requirements/dev.in
Expand All @@ -218,7 +240,7 @@ prometheus-client==0.9.0
# notebook
promise==2.3
# via wandb
prompt-toolkit==3.0.8
prompt-toolkit==3.0.10
# via ipython
protobuf==3.14.0
# via
Expand All @@ -238,7 +260,7 @@ pycparser==2.20
# via cffi
pydocstyle==5.1.1
# via -r requirements/dev.in
pygments==2.7.3
pygments==2.7.4
# via
# ipython
# jupyterlab-pygments
Expand Down Expand Up @@ -267,7 +289,7 @@ pyyaml==5.3.1
# bandit
# dparse
# wandb
pyzmq==20.0.0
pyzmq==21.0.1
# via
# jupyter-client
# jupyter-server
Expand All @@ -282,7 +304,7 @@ requests==2.25.1
# jupyterlab-server
# safety
# wandb
safety==1.10.0
safety==1.10.3
# via -r requirements/dev.in
scipy==1.5.4
# via -r requirements/dev.in
Expand All @@ -308,6 +330,7 @@ six==1.15.0
# promise
# protobuf
# python-dateutil
# traitlets
# wandb
smmap==3.0.4
# via gitdb
Expand All @@ -319,7 +342,7 @@ stevedore==3.3.0
# via bandit
subprocess32==3.5.4
# via wandb
terminado==0.9.1
terminado==0.9.2
# via
# jupyter-server
# notebook
Expand All @@ -341,11 +364,11 @@ tornado==6.1
# jupyterlab
# notebook
# terminado
tqdm==4.55.0
tqdm==4.56.0
# via
# -c requirements/prod.txt
# nltk
traitlets==5.0.5
traitlets==4.3.3
# via
# ipykernel
# ipython
Expand All @@ -356,30 +379,37 @@ traitlets==5.0.5
# nbconvert
# nbformat
# notebook
typed-ast==1.4.1
typed-ast==1.4.2
# via
# astroid
# black
# mypy
typing-extensions==3.7.4.3
# via
# -c requirements/prod.txt
# anyio
# black
# importlib-metadata
# mypy
urllib3==1.26.2
# via
# -c requirements/prod.txt
# requests
# sentry-sdk
wandb==0.10.12
wandb==0.10.14
# via -r requirements/dev.in
watchdog==1.0.2
watchdog==0.10.4
# via wandb
wcwidth==0.2.5
# via prompt-toolkit
webencodings==0.5.1
# via bleach
wrapt==1.12.1
# via astroid
zipp==3.4.0
# via
# -c requirements/prod.txt
# importlib-metadata

# The following packages are considered to be unsafe in a requirements file:
# setuptools
Loading

0 comments on commit da0b75e

Please sign in to comment.