SSL mnist #1368

Merged: 4 commits, Dec 6, 2021
6 changes: 3 additions & 3 deletions CHANGELOG.md
@@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- MNIST dataset for SSL benchmark ([#1368](https://github.com/catalyst-team/catalyst/pull/1368))

### Changed

@@ -20,14 +20,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
-


## [21.11] - 2021-11-30

### Added

- MultiVAE RecSys example ([#1340](https://github.com/catalyst-team/catalyst/pull/1340))`
- MultiVAE RecSys example ([#1340](https://github.com/catalyst-team/catalyst/pull/1340))
- Returned `resume` support - resolved [#1193](https://github.com/catalyst-team/catalyst/issues/1193) ([#1349](https://github.com/catalyst-team/catalyst/pull/1349))
- Smoothing dice loss to contrib ([#1344](https://github.com/catalyst-team/catalyst/pull/1344))
- `profile` flag for `runner.train` ([#1348](https://github.com/catalyst-team/catalyst/pull/1348))
8 changes: 8 additions & 0 deletions CITATION
@@ -0,0 +1,8 @@
@misc{catalyst,
author = {Kolesnikov, Sergey},
title = {Catalyst - Accelerated deep learning R&D},
year = {2018},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/catalyst-team/catalyst}},
}
38 changes: 28 additions & 10 deletions README.md
@@ -40,14 +40,23 @@ so you can create something new rather than write yet another train loop.
- Part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/)

<details>
<summary>Catalyst at PyTorch Ecosystem Day</summary>
<summary>Catalyst at PyTorch Ecosystem Day 2021</summary>
<p>

[![Catalyst poster](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/Catalyst-PTED21.png)](https://github.com/catalyst-team/catalyst)

</p>
</details>

<details>
<summary>Catalyst at PyTorch Developer Day 2021</summary>
<p>

[![Catalyst poster](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/Catalyst-PTDD21.png)](https://github.com/catalyst-team/catalyst)

</p>
</details>

----

## Getting started
@@ -213,6 +222,7 @@ best practices for your deep learning research and development.

### Documentation
- [master](https://catalyst-team.github.io/catalyst/)
- [21.11](https://catalyst-team.github.io/catalyst/v21.11/index.html)
- [21.10](https://catalyst-team.github.io/catalyst/v21.10/index.html)
- [21.09](https://catalyst-team.github.io/catalyst/v21.09/index.html)
- [21.08](https://catalyst-team.github.io/catalyst/v21.08/index.html)
@@ -221,15 +231,23 @@ best practices for your deep learning research and development.
- [21.05](https://catalyst-team.github.io/catalyst/v21.05/index.html) ([Catalyst — A PyTorch Framework for Accelerated Deep Learning R&D](https://medium.com/pytorch/catalyst-a-pytorch-framework-for-accelerated-deep-learning-r-d-ad9621e4ca88?source=friends_link&sk=885b4409aecab505db0a63b06f19dcef))
- [21.04/21.04.1](https://catalyst-team.github.io/catalyst/v21.04/index.html), [21.04.2](https://catalyst-team.github.io/catalyst/v21.04.2/index.html)
- [21.03](https://catalyst-team.github.io/catalyst/v21.03/index.html), [21.03.1/21.03.2](https://catalyst-team.github.io/catalyst/v21.03.1/index.html)
- [20.12](https://catalyst-team.github.io/catalyst/v20.12/index.html)
- [20.11](https://catalyst-team.github.io/catalyst/v20.11/index.html)
- [20.10](https://catalyst-team.github.io/catalyst/v20.10/index.html)
- [20.09](https://catalyst-team.github.io/catalyst/v20.09/index.html)
- [20.08.2](https://catalyst-team.github.io/catalyst/v20.08.2/index.html)
- [20.07](https://catalyst-team.github.io/catalyst/v20.07/index.html) ([dev blog: 20.07 release](https://medium.com/pytorch/catalyst-dev-blog-20-07-release-fb489cd23e14?source=friends_link&sk=7ab92169658fe9a9e1c44068f28cc36c))
- [20.06](https://catalyst-team.github.io/catalyst/v20.06/index.html)
- [20.05](https://catalyst-team.github.io/catalyst/v20.05/index.html), [20.05.1](https://catalyst-team.github.io/catalyst/v20.05.1/index.html)
- [20.04](https://catalyst-team.github.io/catalyst/v20.04/index.html), [20.04.1](https://catalyst-team.github.io/catalyst/v20.04.1/index.html), [20.04.2](https://catalyst-team.github.io/catalyst/v20.04.2/index.html)
- <details>
<summary>2020 edition</summary>
<p>

- [20.12](https://catalyst-team.github.io/catalyst/v20.12/index.html)
- [20.11](https://catalyst-team.github.io/catalyst/v20.11/index.html)
- [20.10](https://catalyst-team.github.io/catalyst/v20.10/index.html)
- [20.09](https://catalyst-team.github.io/catalyst/v20.09/index.html)
- [20.08.2](https://catalyst-team.github.io/catalyst/v20.08.2/index.html)
- [20.07](https://catalyst-team.github.io/catalyst/v20.07/index.html) ([dev blog: 20.07 release](https://medium.com/pytorch/catalyst-dev-blog-20-07-release-fb489cd23e14?source=friends_link&sk=7ab92169658fe9a9e1c44068f28cc36c))
- [20.06](https://catalyst-team.github.io/catalyst/v20.06/index.html)
- [20.05](https://catalyst-team.github.io/catalyst/v20.05/index.html), [20.05.1](https://catalyst-team.github.io/catalyst/v20.05.1/index.html)
- [20.04](https://catalyst-team.github.io/catalyst/v20.04/index.html), [20.04.1](https://catalyst-team.github.io/catalyst/v20.04.1/index.html), [20.04.2](https://catalyst-team.github.io/catalyst/v20.04.2/index.html)

</p>
</details>


### Minimal Examples

4 changes: 3 additions & 1 deletion examples/self_supervised/barlow_twins.py
@@ -24,7 +24,9 @@

# create model and optimizer
model = get_contrastive_model(
in_size=DATASETS[args.dataset]["in_size"], feature_dim=args.feature_dim
in_size=DATASETS[args.dataset]["in_size"],
in_channels=DATASETS[args.dataset]["in_channels"],
feature_dim=args.feature_dim,
)
optimizer = optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-6)

8 changes: 6 additions & 2 deletions examples/self_supervised/byol.py
@@ -20,10 +20,14 @@
model = nn.ModuleDict(
{
"online": get_contrastive_model(
in_size=DATASETS[args.dataset]["in_size"], feature_dim=args.feature_dim
in_size=DATASETS[args.dataset]["in_size"],
in_channels=DATASETS[args.dataset]["in_channels"],
feature_dim=args.feature_dim,
),
"target": get_contrastive_model(
in_size=DATASETS[args.dataset]["in_size"], feature_dim=args.feature_dim
in_size=DATASETS[args.dataset]["in_size"],
in_channels=DATASETS[args.dataset]["in_channels"],
feature_dim=args.feature_dim,
),
}
)
20 changes: 20 additions & 0 deletions examples/self_supervised/check.sh
@@ -0,0 +1,20 @@
#!/usr/bin/env bash

# pip install catalyst[cv]==21.11
# pip install catalyst[ml]==21.11

export NUM_EPOCHS=20
export BATCH_SIZE=256
export LEARNING_RATE=0.001

for DATASET in "MNIST"; do
for METHOD in "barlow_twins" "byol" "simCLR" "supervised_contrastive"; do
python $METHOD.py \
--dataset $DATASET \
--logdir="./logs/$DATASET/$METHOD" \
--batch-size=$BATCH_SIZE \
--epochs=$NUM_EPOCHS \
--learning-rate=$LEARNING_RATE \
--verbose
done
done
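The nested loops in `check.sh` run every method once per dataset. As a rough sketch of the same grid in Python, the helper below builds the argument lists without executing anything; `build_commands` is a hypothetical name, while the flags and default values are copied from the script above.

```python
import itertools
from typing import List, Sequence


def build_commands(
    datasets: Sequence[str],
    methods: Sequence[str],
    epochs: int = 20,
    batch_size: int = 256,
    learning_rate: float = 0.001,
) -> List[List[str]]:
    """Return one argv list per (dataset, method) pair, dataset-major order."""
    commands = []
    for dataset, method in itertools.product(datasets, methods):
        commands.append(
            [
                "python",
                f"{method}.py",
                "--dataset",
                dataset,
                f"--logdir=./logs/{dataset}/{method}",
                f"--batch-size={batch_size}",
                f"--epochs={epochs}",
                f"--learning-rate={learning_rate}",
                "--verbose",
            ]
        )
    return commands


if __name__ == "__main__":
    methods = ["barlow_twins", "byol", "simCLR", "supervised_contrastive"]
    for argv in build_commands(["MNIST"], methods):
        print(" ".join(argv))
```

Each returned list could be passed to `subprocess.run` if actually launching the runs is desired; the sketch stops at printing so it stays side-effect free.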
23 changes: 21 additions & 2 deletions examples/self_supervised/common.py
@@ -137,6 +137,21 @@ def conv_block(in_channels, out_channels, pool=False):
return nn.Sequential(*layers)


def resnet_mnist(in_size: int, in_channels: int, out_features: int, size: int = 16):
sz, sz2, sz4 = size, size * 2, size * 4
out_size = (((in_size // 16) * 16) ** 2 * 4) // size
return nn.Sequential(
conv_block(in_channels, sz),
conv_block(sz, sz2, pool=True),
ResidualBlock(nn.Sequential(conv_block(sz2, sz2), conv_block(sz2, sz2))),
conv_block(sz2, sz4, pool=True),
ResidualBlock(nn.Sequential(conv_block(sz4, sz4), conv_block(sz4, sz4))),
nn.Sequential(
nn.MaxPool2d(4), nn.Flatten(), nn.Dropout(0.2), nn.Linear(out_size, out_features)
),
)
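The `out_size` expression in `resnet_mnist` can be sanity-checked by following the spatial size through the layers. This is a minimal sketch under an assumption not visible in this diff: that `conv_block` keeps height and width unchanged and that `pool=True` appends a 2x2 max-pool. `flattened_features` and `formula_out_size` are hypothetical helper names.

```python
def flattened_features(in_size: int, size: int = 16) -> int:
    """Number of features reaching nn.Linear, tracked layer by layer.

    Assumes conv_block preserves the spatial size and pool=True halves it
    (floor division); this is an assumption about code outside the diff.
    """
    side = in_size
    side //= 2  # conv_block(sz, sz2, pool=True)
    side //= 2  # conv_block(sz2, sz4, pool=True)
    side //= 4  # nn.MaxPool2d(4)
    channels = size * 4  # sz4 channels after the last conv_block
    return channels * side * side


def formula_out_size(in_size: int, size: int = 16) -> int:
    """The expression used in resnet_mnist, copied from the diff."""
    return (((in_size // 16) * 16) ** 2 * 4) // size


if __name__ == "__main__":
    for in_size in (28, 32, 96):
        print(in_size, flattened_features(in_size), formula_out_size(in_size))
```

Under these assumptions the two functions agree for the default `size=16` (for example, both give 64 for 28x28 MNIST images). For other values of `size` they diverge, so the expression appears to be tied to the default width.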


def resnet9(in_size: int, in_channels: int, out_features: int, size: int = 16):
sz, sz2, sz4, sz8 = size, size * 2, size * 4, size * 8
assert in_size >= 32, "The graph is not valid for images with resolution lower than 32x32."
@@ -155,20 +170,24 @@ def resnet9(in_size: int, in_channels: int, out_features: int, size: int = 16):


def get_contrastive_model(
in_size: int, feature_dim: int, encoder_dim: int = 512, hidden_dim: int = 512
in_size: int, in_channels: int, feature_dim: int, encoder_dim: int = 512, hidden_dim: int = 512
) -> ContrastiveModel:
"""Init contrastive model based on parsed parameters.

Args:
in_size: size of an image (in_size x in_size)
in_channels: number of channels in an image
feature_dim: dimensionality of contrastive projection
encoder_dim: dimensionality of encoder output
hidden_dim: dimensionality of encoder-contrastive projection

Returns:
ContrastiveModel instance
"""
encoder = resnet9(in_size=in_size, in_channels=3, out_features=encoder_dim)
try:
encoder = resnet9(in_size=in_size, in_channels=in_channels, out_features=encoder_dim)
except AssertionError:  # resnet9 asserts in_size >= 32; fall back for smaller images
encoder = resnet_mnist(in_size=in_size, in_channels=in_channels, out_features=encoder_dim)
projection_head = nn.Sequential(
nn.Linear(encoder_dim, hidden_dim, bias=False),
nn.ReLU(inplace=True),
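The `try`/`except` in `get_contrastive_model` falls back to `resnet_mnist` when `resnet9` rejects the input size. One alternative, sketched here as an illustration rather than as the PR's approach, is to choose the encoder explicitly from `in_size`. `select_encoder` and `MIN_RESNET9_SIZE` are hypothetical names; the threshold 32 is taken from the assert in `resnet9`.

```python
from typing import Callable, Dict

MIN_RESNET9_SIZE = 32  # from the assert inside resnet9


def select_encoder(in_size: int, builders: Dict[str, Callable]) -> Callable:
    """Pick an encoder builder by input resolution instead of try/except."""
    key = "resnet9" if in_size >= MIN_RESNET9_SIZE else "resnet_mnist"
    return builders[key]


if __name__ == "__main__":
    # Stand-in builders so the sketch runs without torch; in the example
    # code these would be the resnet9 and resnet_mnist functions.
    stubs = {
        "resnet9": lambda **kwargs: ("resnet9", kwargs),
        "resnet_mnist": lambda **kwargs: ("resnet_mnist", kwargs),
    }
    build = select_encoder(28, stubs)
    print(build(in_size=28, in_channels=1, out_features=512))
```

An explicit check keeps unrelated errors raised while building the encoder from being silently swallowed by the fallback.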
93 changes: 54 additions & 39 deletions examples/self_supervised/datasets.py
@@ -1,86 +1,101 @@
import torchvision
from torchvision.datasets import CIFAR10, CIFAR100, STL10
from torchvision import datasets, transforms

DATASETS = {
"MNIST": {
"dataset": datasets.MNIST,
"in_size": 28,
"in_channels": 1,
"train_transform": transforms.Compose(
[
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
"valid_transform": transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)),
]
),
},
"CIFAR-10": {
"dataset": CIFAR10,
"dataset": datasets.CIFAR10,
"in_size": 32,
"train_transform": torchvision.transforms.Compose(
"in_channels": 3,
"train_transform": transforms.Compose(
[
torchvision.transforms.RandomApply(
transforms.RandomApply(
[
torchvision.transforms.ColorJitter(
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
)
],
p=0.8,
),
torchvision.transforms.RandomGrayscale(p=0.1),
torchvision.transforms.RandomHorizontalFlip(p=0.5),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282)),
transforms.RandomGrayscale(p=0.1),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
),
"valid_transform": torchvision.transforms.Compose(
"valid_transform": transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
[0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
),
},
"CIFAR-100": {
"dataset": CIFAR100,
"dataset": datasets.CIFAR100,
"in_size": 32,
"train_transform": torchvision.transforms.Compose(
"in_channels": 3,
"train_transform": transforms.Compose(
[
torchvision.transforms.RandomApply(
transforms.RandomApply(
[
torchvision.transforms.ColorJitter(
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
)
],
p=0.8,
),
torchvision.transforms.RandomGrayscale(p=0.1),
torchvision.transforms.RandomHorizontalFlip(p=0.5),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.480, 0.448, 0.398), (0.277, 0.269, 0.282)),
transforms.RandomGrayscale(p=0.1),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
),
"valid_transform": torchvision.transforms.Compose(
"valid_transform": transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
[0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]
),
},
"STL10": {
"dataset": STL10,
"dataset": datasets.STL10,
"in_size": 96,
"train_transform": torchvision.transforms.Compose(
"train_transform": transforms.Compose(
[
torchvision.transforms.RandomApply(
transforms.RandomApply(
[
torchvision.transforms.ColorJitter(
transforms.ColorJitter(
brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1
)
],
p=0.8,
),
torchvision.transforms.RandomGrayscale(p=0.1),
torchvision.transforms.RandomHorizontalFlip(p=0.5),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27)),
transforms.RandomGrayscale(p=0.1),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor(),
transforms.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27)),
]
),
"valid_transform": torchvision.transforms.Compose(
"valid_transform": transforms.Compose(
[
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27)),
transforms.ToTensor(),
transforms.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27)),
]
),
},
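The training scripts in this PR index `DATASETS[args.dataset]["in_channels"]`, so every entry needs that key. In the hunk as shown here, the `"STL10"` entry has no `"in_channels"` line. A small check like the following sketch could surface such gaps early; `REQUIRED_KEYS` and `missing_keys` are hypothetical names, and the key list is simply the set of keys visible in the diff.

```python
REQUIRED_KEYS = ("dataset", "in_size", "in_channels", "train_transform", "valid_transform")


def missing_keys(datasets: dict) -> dict:
    """Map each dataset name to the required keys its entry lacks."""
    report = {}
    for name, entry in datasets.items():
        absent = [key for key in REQUIRED_KEYS if key not in entry]
        if absent:
            report[name] = absent
    return report


if __name__ == "__main__":
    # Stand-in entries mirroring only the key names visible in the diff;
    # the values are placeholders, not the real torchvision objects.
    example = {
        "MNIST": {
            "dataset": None,
            "in_size": 28,
            "in_channels": 1,
            "train_transform": None,
            "valid_transform": None,
        },
        "STL10": {
            "dataset": None,
            "in_size": 96,
            "train_transform": None,
            "valid_transform": None,
        },
    }
    print(missing_keys(example))
```

Running the same function on the real `DATASETS` dictionary would report any entry that a script could fail on with a `KeyError`.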
4 changes: 2 additions & 2 deletions examples/self_supervised/run.sh
@@ -3,8 +3,8 @@
# pip install catalyst[cv]==21.11
# pip install catalyst[ml]==21.11

export NUM_EPOCHS=1
export BATCH_SIZE=32
export NUM_EPOCHS=20
export BATCH_SIZE=256
export LEARNING_RATE=0.001

for DATASET in "CIFAR-10" "CIFAR-100" "STL10"; do
4 changes: 3 additions & 1 deletion examples/self_supervised/simCLR.py
@@ -18,7 +18,9 @@

# create model and optimizer
model = get_contrastive_model(
in_size=DATASETS[args.dataset]["in_size"], feature_dim=args.feature_dim
in_size=DATASETS[args.dataset]["in_size"],
in_channels=DATASETS[args.dataset]["in_channels"],
feature_dim=args.feature_dim,
)
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

4 changes: 3 additions & 1 deletion examples/self_supervised/supervised_contrastive.py
@@ -24,7 +24,9 @@ def concat(*tensors):

# create model and optimizer
model = get_contrastive_model(
in_size=DATASETS[args.dataset]["in_size"], feature_dim=args.feature_dim
in_size=DATASETS[args.dataset]["in_size"],
in_channels=DATASETS[args.dataset]["in_channels"],
feature_dim=args.feature_dim,
)
optimizer = Adam(model.parameters(), lr=args.learning_rate)

2 changes: 1 addition & 1 deletion requirements/requirements-cv.txt
@@ -1,5 +1,5 @@
imageio>=2.5.0
opencv-python-headless>=4.1.1.26
scikit-image>=0.16.1
scikit-image>=0.16.1,<0.19.0
torchvision>=0.4.1
Pillow>=6.1 # torchvision fix (https://github.com/python-pillow/Pillow/issues/4130)