Skip to content
This repository was archived by the owner on Mar 19, 2024. It is now read-only.

Commit dd9971a

Browse files
iseesselfacebook-github-bot
authored andcommitted
Add augly transformation support (#442)
Summary: Pull Request resolved: #442 Add support for augly transformations. Similar to apex, I made augly install optional for users and didn't add it to requirements.txt -- let me know what you think about this. Reviewed By: prigoyal, QuentinDuval Differential Revision: D31462923 fbshipit-source-id: ce793f1adc432b3f1ea08acf4b3f66daa88215a8
1 parent 83d859f commit dd9971a

File tree

6 files changed

+134
-9
lines changed

6 files changed

+134
-9
lines changed

.circleci/config.yml

+19-7
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ cpu: &cpu
1616
environment:
1717
TERM: xterm
1818
machine:
19-
image: default
19+
image: ubuntu-1604:201903-01
2020
resource_class: medium
2121

2222
gpu: &gpu
@@ -37,8 +37,8 @@ install_python: &install_python
3737
working_directory: ~/
3838
command: |
3939
pyenv versions
40-
pyenv install 3.6.2
41-
pyenv global 3.6.2
40+
pyenv install -f 3.7.0
41+
pyenv global 3.7.0
4242
4343
update_gcc7: &update_gcc7
4444
- run:
@@ -107,6 +107,17 @@ install_vissl_dep: &install_vissl_dep
107107
# Update this since classy_vision seems to need it.
108108
pip install --progress-bar off --upgrade iopath
109109
110+
# Must install python3-magic as per documentation:
111+
# https://github.com/facebookresearch/AugLy#installation
112+
install_augly: &install_augly
113+
- run:
114+
name: Install augly
115+
working_directory: ~/vissl
116+
command: |
117+
pip install augly
118+
sudo apt-get update
119+
sudo apt-get install python3-magic
120+
110121
install_apex_gpu: &install_apex_gpu
111122
- run:
112123
name: Install Apex
@@ -153,17 +164,18 @@ jobs:
153164
# Cache the vissl_venv directory that contains dependencies
154165
- restore_cache:
155166
keys:
156-
- v5-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}
167+
- v6-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}
157168

158169
- <<: *install_vissl_dep
170+
- <<: *install_augly
159171
- <<: *install_classy_vision
160172
- <<: *install_apex_cpu
161173
- <<: *pip_list
162174

163175
- save_cache:
164176
paths:
165177
- ~/vissl_venv
166-
key: v5-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}
178+
key: v6-cpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}
167179

168180
- <<: *install_vissl
169181

@@ -196,7 +208,7 @@ jobs:
196208
# Download and cache dependencies
197209
- restore_cache:
198210
keys:
199-
- v5-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }}
211+
- v6-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }}
200212

201213
- <<: *install_vissl_dep
202214
- <<: *install_classy_vision
@@ -211,7 +223,7 @@ jobs:
211223
- save_cache:
212224
paths:
213225
- ~/vissl_venv
214-
key: v5-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }}
226+
key: v6-gpu-dependencies-{{ checksum "requirements.txt" }}-{{ checksum "setup.py" }}-{{ checksum "docker/common/install_apex.sh" }}
215227

216228
- <<: *install_vissl
217229

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# @package _global_
2+
config:
3+
DATA:
4+
TRAIN:
5+
TRANSFORMS:
6+
- name: ImgReplicatePil
7+
num_times: 2
8+
- name: RandomResizedCrop
9+
size: 224
10+
- name: RandomHorizontalFlip
11+
p: 0.5
12+
- name: ImgPilColorDistortion
13+
strength: 1.0
14+
- name: ImgPilGaussianBlur
15+
p: 0.5
16+
radius_min: 0.1
17+
radius_max: 2.0
18+
- name: Blur
19+
transform_type: "augly"
20+
radius: 2.0
21+
p: 1.0
22+
- name: ToTensor
23+
- name: Normalize
24+
mean: [0.485, 0.456, 0.406]
25+
std: [0.229, 0.224, 0.225]

tests/test_transforms.py

+18
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
from vissl.data.ssl_transforms.img_pil_to_multicrop import ImgPilToMultiCrop
1313
from vissl.data.ssl_transforms.img_pil_to_tensor import ImgToTensor
1414
from vissl.data.ssl_transforms.mnist_img_pil_to_rgb_mode import MNISTImgPil2RGB
15+
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict
16+
from vissl.utils.test_utils import (
17+
in_temporary_directory,
18+
run_integration_test,
19+
)
1520

1621

1722
RAND_TENSOR = (torch.rand((224, 224, 3)) * 255).to(dtype=torch.uint8)
@@ -77,3 +82,16 @@ def test_img_pil_to_multicrop(self):
7782
self.assertEqual((224, 224), crop.size)
7883
for crop in crops[2:]:
7984
self.assertEqual((96, 96), crop.size)
85+
86+
def test_augly_transforms(self):
87+
cfg = compose_hydra_configuration(
88+
[
89+
"config=test/cpu_test/test_cpu_resnet_simclr.yaml",
90+
"+config/test/transforms=augly_transforms_example",
91+
],
92+
)
93+
args, config = convert_to_attrdict(cfg)
94+
95+
with in_temporary_directory() as _:
96+
# Test that the training runs with an augly transformation.
97+
run_integration_test(config)

vissl/data/ssl_transforms/ssl_transforms_wrapper.py

+40-2
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@
55

66
from typing import Any, Dict
77

8-
from classy_vision.dataset.transforms import build_transform, register_transform
8+
from classy_vision.dataset.transforms import (
9+
build_transform as build_classy_transform,
10+
register_transform,
11+
)
912
from classy_vision.dataset.transforms.classy_transform import ClassyTransform
13+
from vissl.utils.misc import is_augly_available
1014

15+
if is_augly_available():
16+
import augly.image as imaugs # NOQA
1117

1218
# Below the transforms that require passing the labels as well. This is specifc
1319
# to SSL only where we automatically generate the labels for training. All other
@@ -108,7 +114,7 @@ def __init__(
108114
"""
109115
self.indices = set(indices)
110116
self.name = args["name"]
111-
self.transform = build_transform(args)
117+
self.transform = self._build_transform(args)
112118
self.transform_receives_entire_batch = transform_receives_entire_batch
113119
self.transforms_with_labels = transform_types["TRANSFORMS_WITH_LABELS"]
114120
self.transforms_with_copies = transform_types["TRANSFORMS_WITH_COPIES"]
@@ -117,6 +123,38 @@ def __init__(
117123
]
118124
self.transforms_with_grouping = transform_types["TRANSFORMS_WITH_GROUPING"]
119125

126+
def _build_transform(self, args):
127+
if "transform_type" not in args:
128+
# Default to classy transform.
129+
return build_classy_transform(args)
130+
elif args["transform_type"] == "augly":
131+
# Build augly transform.
132+
return self._build_augly_transform(args)
133+
else:
134+
raise RuntimeError(
135+
f"Transform type: { args.transform_type } is not supported"
136+
)
137+
138+
def _build_augly_transform(self, args):
139+
assert is_augly_available(), "Please pip install augly."
140+
141+
# the name should be available in augly.image
142+
# if users specify the transform name in snake case,
143+
# we need to convert it to title case.
144+
name = args["name"]
145+
146+
if not hasattr(imaugs, name):
147+
# Try converting name to title case.
148+
name = name.title().replace("_", "")
149+
150+
assert hasattr(imaugs, name), f"{name} isn't a registered tranform for augly."
151+
152+
# Delete superfluous keys.
153+
del args["name"]
154+
del args["transform_type"]
155+
156+
return getattr(imaugs, name)(**args)
157+
120158
def _is_transform_with_labels(self):
121159
"""
122160
_TRANSFORMS_WITH_LABELS = ["ImgRotatePil", "ShuffleImgPatches"]

vissl/utils/hydra_config.py

+12
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from omegaconf import DictConfig, OmegaConf
1313
from vissl.config import AttrDict, check_cfg_version
1414
from vissl.utils.io import save_file
15+
from vissl.utils.misc import is_augly_available
1516

1617

1718
def save_attrdict_to_disk(cfg: AttrDict):
@@ -462,6 +463,16 @@ def infer_losses_config(cfg):
462463
return cfg
463464

464465

466+
def assert_transforms(cfg):
467+
for transforms in [cfg.DATA.TRAIN.TRANSFORMS, cfg.DATA.TEST.TRANSFORMS]:
468+
for transform in transforms:
469+
if "transform_type" in transform:
470+
assert transform["transform_type"] in [None, "augly"]
471+
472+
if transform["transform_type"] == "augly":
473+
assert is_augly_available(), "Please pip install augly."
474+
475+
465476
def infer_and_assert_hydra_config(cfg):
466477
"""
467478
Infer values of few parameters in the config file using the value of other config parameters
@@ -480,6 +491,7 @@ def infer_and_assert_hydra_config(cfg):
480491
"""
481492
cfg = infer_losses_config(cfg)
482493
cfg = infer_learning_rate(cfg)
494+
assert_transforms(cfg)
483495

484496
# pass the seed to cfg["MODEL"] so that model init on different nodes can
485497
# use the same seed.

vissl/utils/misc.py

+20
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import os
99
import random
10+
import sys
1011
import tempfile
1112
import time
1213
from functools import partial, wraps
@@ -80,6 +81,25 @@ def is_apex_available():
8081
return apex_available
8182

8283

84+
def is_augly_available():
85+
"""
86+
Check if apex is available with simple python imports.
87+
"""
88+
try:
89+
assert sys.version_info >= (
90+
3,
91+
7,
92+
0,
93+
), "Please upgrade your python version to 3.7 or higher to use Augly."
94+
95+
import augly.image # NOQA
96+
97+
augly_available = True
98+
except ImportError:
99+
augly_available = False
100+
return augly_available
101+
102+
83103
def find_free_tcp_port():
84104
"""
85105
Find the free port that can be used for Rendezvous on the local machine.

0 commit comments

Comments
 (0)