Commit b2b4b51

seba-1511, nightlessbaron, and ewinapun authored
Add Lightning Bolts for MAML, ProtoNet, ANIL, MetaOptNet. (#202)
* Add l2l.nn.PrototypicalClassifier implementation. Co-authored-by: Varad Pimpalkhute <[email protected]>
* Version bump.
* Add SVMClassifier and tests. Co-authored-by: Varad Pimpalkhute <[email protected]>
* Move MetaOptNet test to CPU.
* Fix ResNet12 for new API.
* Add l2l.data.EpisodicBatcher. Co-authored-by: Ewina Pun <[email protected]>
* Add PyTorch Lightning implementations of MAML, ANIL, ProtoNet, and MetaOptNet. Co-authored-by: Varad Pimpalkhute <[email protected]> Co-authored-by: Ewina Pun <[email protected]>
* Add tests for the Lightning algorithms. Co-authored-by: Varad Pimpalkhute <[email protected]>
* Update Travis settings.
* Update docs and CHANGELOG.
* Add training script to use the Lightning modules. Co-authored-by: Varad Pimpalkhute <[email protected]>
* Fix memory leak during validation and test.
* Ensure train mode for ProtoNet and MetaOptNet.
* Move EpisodicBatcher to utils.
* Minor updates.
* Fix merge conflict in _version.py.
* Change the behaviour of the episodic batcher.
* Add docs.
* Add progress bar and testing callbacks.
* Fix linting.
* Fix tests.
* Update tests to PY3.7+.
* Update tests to PY3.7 only.
* Update tests.
* Move Lightning tests to notravis.

Co-authored-by: Varad Pimpalkhute <[email protected]>
Co-authored-by: Ewina Pun <[email protected]>
1 parent 365d15a · commit b2b4b51

28 files changed: +1822 −4 lines
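Taken together, the pieces introduced here are used as in the examples/vision/lightning/main.py script included in this commit. Below is a stripped-down sketch of that flow, condensed from main.py and the maml-cifarfs target of the example Makefile; the LightningMAML keyword arguments simply mirror the CLI flags that main.py forwards to the wrapper (--adaptation_steps, --adaptation_lr, --lr), and max_epochs is only a placeholder, not a recommended value.

# Minimal sketch of the training flow added in this commit, condensed from
# examples/vision/lightning/main.py; hyperparameters follow the maml-cifarfs
# Makefile target, and max_epochs is only a placeholder.
import learn2learn as l2l
import pytorch_lightning as pl
from learn2learn.algorithms import LightningMAML
from learn2learn.utils.lightning import EpisodicBatcher

pl.seed_everything(42)

# Few-shot tasksets: 5 ways, 5 shots + 5 queries per class.
tasksets = l2l.vision.benchmarks.get_tasksets(
    name="cifarfs",
    train_ways=5,
    train_samples=10,
    test_ways=5,
    test_samples=10,
    root="~/data",
    data_augmentation="normalize",
)

# Backbone and meta-learning wrapper (keywords mirror main.py's CLI flags).
model = l2l.vision.models.CNN4(output_size=5, hidden_size=64, embedding_size=64 * 4)
maml = LightningMAML(model, adaptation_steps=5, adaptation_lr=0.05, lr=0.001)

# EpisodicBatcher exposes the tasksets to the Trainer as a LightningDataModule.
episodes = EpisodicBatcher(
    tasksets.train,
    tasksets.validation,
    tasksets.test,
    epoch_length=16 * 10,  # meta_batch_size * 10, as in main.py
)

trainer = pl.Trainer(max_epochs=10, accumulate_grad_batches=16)
trainer.fit(model=maml, datamodule=episodes)
trainer.test(ckpt_path="best")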

.github/workflows/python_unittest.yaml (+4 −1)

@@ -9,7 +9,7 @@ jobs:
     strategy:
       matrix:
         os: [ubuntu-18.04, ubuntu-latest, macos-latest]
-        python: ['3.6', '3.7']
+        python: ['3.7', '3.8']
         pytorch: ['1.3.0', '1.4.0', '1.5.0', '1.6.0', '1.7.0']
         include:
           - pytorch: '1.3.0'
@@ -22,6 +22,9 @@ jobs:
             torchvision: '0.7.0'
           - pytorch: '1.7.0'
             torchvision: '0.8.0'
+        exclude:
+          - pytorch: '1.3.0'
+            python: '3.8'
 
     steps:
       - name: Clone Repository

.gitignore (+2)

@@ -112,3 +112,5 @@ learn2learn/data/*.so
 alltests.txt
 docs/MUJOCO_LOG.TXT
 token.pickle
+lightning_logs/**
+examples/vision/lightning/lightning_logs/**

CHANGELOG.md (+4)

@@ -10,6 +10,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+* PyTorch Lightning interface to MAML, ANIL, ProtoNet, MetaOptNet.
+* Automatic batcher for Lightning: `l2l.data.EpisodicBatcher`.
+* `l2l.nn.PrototypicalClassifier` and `l2l.nn.SVMClassifier`.
+
 ### Changed
 
 ### Fixed
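For context on the `l2l.nn.PrototypicalClassifier` entry above: the prototypical classification rule it is named after fits in a few lines of plain PyTorch. The sketch below only illustrates that rule; it does not assume anything about the shipped class's constructor or arguments, which are not part of this diff.

# Concept sketch of the prototypical classification rule in plain PyTorch;
# the shipped l2l.nn.PrototypicalClassifier is not reproduced here and its
# exact API is not assumed.
import torch


def prototypical_scores(support, support_labels, queries):
    """Score queries by negative squared distance to class prototypes.

    support:        (n_support, d) support-set embeddings
    support_labels: (n_support,)   integer class labels
    queries:        (n_query, d)   embeddings to classify
    returns:        (n_query, n_classes) scores, higher means closer
    """
    classes = torch.unique(support_labels)
    # Prototype of each class = mean of its support embeddings.
    prototypes = torch.stack([
        support[support_labels == c].mean(dim=0) for c in classes
    ])
    # Negative squared Euclidean distance acts as the classification score.
    return -torch.cdist(queries, prototypes).pow(2)


# Tiny example: a 2-way, 2-shot episode with 3-dimensional embeddings.
support = torch.randn(4, 3)
labels = torch.tensor([0, 0, 1, 1])
queries = torch.randn(5, 3)
predictions = prototypical_scores(support, labels, queries).argmax(dim=1)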

Makefile (+1 −1)

@@ -45,7 +45,7 @@ alltests:
 
 docs:
 	rm -f docs/mkdocs.yml
-	python scripts/compile_paper_list.py
+	#python scripts/compile_paper_list.py
 	cd docs && pydocmd build && pydocmd serve
 
 docs-deploy:

docs/pydocmd.yml (+6)

@@ -30,6 +30,10 @@ generate:
     - learn2learn.algorithms.MAML++
     - learn2learn.algorithms.MetaSGD++
     - learn2learn.algorithms.GBML++
+    - learn2learn.algorithms.LightningMAML
+    - learn2learn.algorithms.LightningANIL
+    - learn2learn.algorithms.LightningPrototypicalNetworks
+    - learn2learn.algorithms.LightningMetaOptNet
 - docs/learn2learn.gym.md:
   - learn2learn.gym++:
     - learn2learn.gym.MetaEnv
@@ -60,6 +64,8 @@ generate:
     - learn2learn.nn.Lambda
     - learn2learn.nn.Flatten
     - learn2learn.nn.Scale
+    - learn2learn.nn.PrototypicalClassifier
+    - learn2learn.nn.SVClassifier
     - learn2learn.nn.KroneckerLinear
     - learn2learn.nn.KroneckerRNN
     - learn2learn.nn.KroneckerLSTM

examples/vision/lightning/Makefile (new file, +139 lines)

.PHONY: *

proto-cifarfs:
	CUDA_VISIBLE_DEVICES=$(GPU) python main.py \
		--dataset=cifarfs \
		--max_epochs=1000 \
		--min_epochs=1000 \
		--algorithm=protonet \
		--lr=0.0016 \
		--meta_batch_size=16 \
		--train_shots=5 \
		--train_ways=15 \
		--train_queries=15 \
		--test_shots=5 \
		--test_ways=5 \
		--test_queries=5 \
		--seed=42

maml-cifarfs:
	CUDA_VISIBLE_DEVICES=$(GPU) python main.py \
		--dataset=cifarfs \
		--max_epochs=25000 \
		--min_epochs=25000 \
		--algorithm=maml \
		--adaptation_steps=5 \
		--adaptation_lr=0.05 \
		--lr=0.001 \
		--meta_batch_size=16 \
		--train_shots=5 \
		--train_ways=5 \
		--train_queries=5 \
		--test_shots=5 \
		--test_ways=5 \
		--test_queries=5 \
		--seed=42

anil-cifarfs:
	CUDA_VISIBLE_DEVICES=$(GPU) python main.py \
		--dataset=cifarfs \
		--max_epochs=2000 \
		--min_epochs=2000 \
		--algorithm=anil \
		--adaptation_steps=5 \
		--adaptation_lr=0.05 \
		--lr=0.001 \
		--meta_batch_size=16 \
		--train_shots=5 \
		--train_ways=5 \
		--train_queries=5 \
		--test_shots=5 \
		--test_ways=5 \
		--test_queries=5 \
		--seed=42

metaoptnet-cifarfs:
	CUDA_VISIBLE_DEVICES=$(GPU) python main.py \
		--dataset=cifarfs \
		--max_epochs=10000 \
		--min_epochs=10000 \
		--algorithm=metaoptnet \
		--train_shots=5 \
		--train_ways=5 \
		--train_queries=15 \
		--test_shots=5 \
		--test_ways=5 \
		--test_queries=5 \
		--seed=42

proto-mi:
	CUDA_VISIBLE_DEVICES=$(GPU) python main.py \
		--dataset='mini-imagenet' \
		--max_epochs=10000 \
		--min_epochs=10000 \
		--algorithm=protonet \
		--distance_metric='euclidean' \
		--meta_batch_size=8 \
		--lr=0.005 \
		--train_shots=5 \
		--train_ways=20 \
		--train_queries=15 \
		--test_shots=5 \
		--test_ways=5 \
		--test_queries=5 \
		--data_parallel \
		--seed=42

maml-mi:
	CUDA_VISIBLE_DEVICES=$(GPU) python main.py \
		--dataset=mini-imagenet \
		--max_epochs=35000 \
		--min_epochs=35000 \
		--algorithm=maml \
		--adaptation_steps=5 \
		--adaptation_lr=0.02 \
		--lr=0.0003 \
		--meta_batch_size=16 \
		--train_shots=5 \
		--train_ways=5 \
		--train_queries=5 \
		--test_shots=5 \
		--test_ways=5 \
		--test_queries=5 \
		--data_parallel \
		--seed=42

anil-mi:
	CUDA_VISIBLE_DEVICES=$(GPU) python main.py \
		--dataset=mini-imagenet \
		--lr=0.001 \
		--meta_batch_size=16 \
		--max_epochs=25000 \
		--min_epochs=25000 \
		--algorithm=anil \
		--adaptation_lr=0.1 \
		--adaptation_steps=5 \
		--train_shots=5 \
		--train_ways=5 \
		--train_queries=5 \
		--test_shots=5 \
		--test_ways=5 \
		--test_queries=5 \
		--seed=42

metaoptnet-mi:
	CUDA_VISIBLE_DEVICES=$(GPU) python main.py \
		--dataset=mini-imagenet \
		--lr=3e-4 \
		--max_epochs=40000 \
		--min_epochs=40000 \
		--algorithm=metaoptnet \
		--train_shots=15 \
		--train_ways=5 \
		--train_queries=5 \
		--test_shots=5 \
		--test_ways=5 \
		--test_queries=5 \
		--data_parallel \
		--seed=42
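Every target forwards the GPU make variable to CUDA_VISIBLE_DEVICES, so a typical invocation from examples/vision/lightning/ is `make maml-cifarfs GPU=0`; the remaining flags are the hyperparameters each run passes to main.py.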

examples/vision/lightning/main.py (new file, +96 lines)

#!/usr/bin/env python3

"""
Example for running few-shot algorithms with the PyTorch Lightning wrappers.
"""

import learn2learn as l2l
import pytorch_lightning as pl
from argparse import ArgumentParser
from learn2learn.algorithms import (
    LightningPrototypicalNetworks,
    LightningMetaOptNet,
    LightningMAML,
    LightningANIL,
)
from learn2learn.utils.lightning import EpisodicBatcher


def main():
    parser = ArgumentParser(conflict_handler="resolve", add_help=True)
    # add model and trainer specific args
    parser = LightningPrototypicalNetworks.add_model_specific_args(parser)
    parser = LightningMetaOptNet.add_model_specific_args(parser)
    parser = LightningMAML.add_model_specific_args(parser)
    parser = LightningANIL.add_model_specific_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)

    # add script-specific args
    parser.add_argument("--algorithm", type=str, default="protonet")
    parser.add_argument("--dataset", type=str, default="mini-imagenet")
    parser.add_argument("--root", type=str, default="~/data")
    parser.add_argument("--meta_batch_size", type=int, default=16)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    dict_args = vars(args)

    pl.seed_everything(args.seed)

    # Create tasksets using the benchmark interface
    if False and args.dataset in ["mini-imagenet", "tiered-imagenet"]:
        data_augmentation = "lee2019"
    else:
        data_augmentation = "normalize"
    tasksets = l2l.vision.benchmarks.get_tasksets(
        name=args.dataset,
        train_samples=args.train_queries + args.train_shots,
        train_ways=args.train_ways,
        test_samples=args.test_queries + args.test_shots,
        test_ways=args.test_ways,
        root=args.root,
        data_augmentation=data_augmentation,
    )
    episodic_data = EpisodicBatcher(
        tasksets.train,
        tasksets.validation,
        tasksets.test,
        epoch_length=args.meta_batch_size * 10,
    )

    # init model
    if args.dataset in ["mini-imagenet", "tiered-imagenet"]:
        model = l2l.vision.models.ResNet12(output_size=args.train_ways)
    else:  # CIFAR-FS, FC100
        model = l2l.vision.models.CNN4(
            output_size=args.train_ways,
            hidden_size=64,
            embedding_size=64*4,
        )
    features = model.features
    classifier = model.classifier

    # init algorithm
    if args.algorithm == "protonet":
        algorithm = LightningPrototypicalNetworks(features=features, **dict_args)
    elif args.algorithm == "maml":
        algorithm = LightningMAML(model, **dict_args)
    elif args.algorithm == "anil":
        algorithm = LightningANIL(features, classifier, **dict_args)
    elif args.algorithm == "metaoptnet":
        algorithm = LightningMetaOptNet(features, **dict_args)

    trainer = pl.Trainer.from_argparse_args(
        args,
        gpus=1,
        accumulate_grad_batches=args.meta_batch_size,
        callbacks=[
            l2l.utils.lightning.TrackTestAccuracyCallback(),
            l2l.utils.lightning.NoLeaveProgressBar(),
        ],
    )
    trainer.fit(model=algorithm, datamodule=episodic_data)
    trainer.test(ckpt_path="best")


if __name__ == "__main__":
    main()
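In the script above, EpisodicBatcher is what lets a stock pl.Trainer consume learn2learn tasksets as a LightningDataModule. The snippet below is a hypothetical stand-in, not the shipped implementation: the class and helper names are illustrative, and it assumes the tasksets expose TaskDataset's sample() method.

# Hypothetical sketch of an episodic datamodule (NOT the shipped
# l2l.utils.lightning.EpisodicBatcher): it serves `epoch_length` freshly
# sampled episodes per epoch, which is the role EpisodicBatcher plays above.
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class EpisodeDataset(Dataset):
    """Each __getitem__ draws one episode from a learn2learn taskset."""

    def __init__(self, taskset, epoch_length):
        self.taskset = taskset
        self.epoch_length = epoch_length

    def __len__(self):
        return self.epoch_length

    def __getitem__(self, index):
        # Assumes the taskset provides sample(), as l2l.data.TaskDataset does.
        return self.taskset.sample()


class MinimalEpisodicBatcher(pl.LightningDataModule):
    """Exposes train/validation/test tasksets to a pl.Trainer."""

    def __init__(self, train_tasks, validation_tasks, test_tasks, epoch_length=100):
        super().__init__()
        self.train_tasks = train_tasks
        self.validation_tasks = validation_tasks
        self.test_tasks = test_tasks
        self.epoch_length = epoch_length

    def train_dataloader(self):
        # batch_size=None: each item is already a full episode, so no collation.
        return DataLoader(EpisodeDataset(self.train_tasks, self.epoch_length), batch_size=None)

    def val_dataloader(self):
        return DataLoader(EpisodeDataset(self.validation_tasks, self.epoch_length), batch_size=None)

    def test_dataloader(self):
        return DataLoader(EpisodeDataset(self.test_tasks, self.epoch_length), batch_size=None)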

learn2learn/algorithms/__init__.py (+7)

@@ -7,3 +7,10 @@
 from .maml import MAML, maml_update
 from .meta_sgd import MetaSGD, meta_sgd_update
 from .gbml import GBML
+from .lightning import (
+    LightningEpisodicModule,
+    LightningMAML,
+    LightningANIL,
+    LightningPrototypicalNetworks,
+    LightningMetaOptNet,
+)
learn2learn/algorithms/lightning/__init__.py (new file, +12 lines)

#!/usr/bin/env python3

r"""
Standardized implementations of few-shot learning algorithms,
compatible with PyTorch Lightning.
"""

from .lightning_episodic_module import LightningEpisodicModule
from .lightning_maml import LightningMAML
from .lightning_anil import LightningANIL
from .lightning_protonet import LightningPrototypicalNetworks
from .lightning_metaoptnet import LightningMetaOptNet
