Skip to content

Commit f2f7c83

Browse files
authored
Adds docs and tests. (#45)
* implements detach module using the structure of clone module * detach/clone progress * tests for clone and detach, tests should be moved to different folder but I was encountering import errors * Force detach_module to work in-place. * Move tests to proper folder. * line addition and renaming vars * testing if gradient update works * have one optimizer_step function * update getting_started to contain intro to l2l and meta-learning * update mkdocs and pydocmd config files to restore equation rendering and fix equations in getting_started tutorial * Add new implementation of Ant/HalfCheetah envs. * Add HumanoidForwardBackward. * Add HumanoidDirection env. * Add particles, remove unused stuff. * Add environment License * Clean up RL examples. * Remove mkdocs.yml from source. * Add docs for clone and detach module. * Fix docs. * Add detach implementation. * Add magic box doc, remove top-level import algos. * Add docs for MAML. * Add docs for MetaSGD. * Started doc on task generator. * Add documentation for l2l.data. * Fix equations in docs. * Add travis * Add more docs. * Add integration tests. * Update README with Travis. * Add requirements.txt * Fix task generator out-of-range sampling. * Remove tqdm from tests * Add pandas dependency. * Add requests dependency. * Better camera for Ant and HalfCheetah envs. * Removed unused environments. * Minor fix promp. * Clean up readme. * Clean up readme. * Add demo gif. * Update readme. * Improve docs.
1 parent 4de9032 commit f2f7c83

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+742
-511
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,4 @@ venv.bak/
106106
# mypy
107107
.mypy_cache/
108108
data/*
109+
docs/mkdocs.yml

.travis.yml

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
language: python
2+
3+
os:
4+
- linux
5+
6+
python:
7+
- 3.6
8+
# - 3.7
9+
- "3.6-dev"
10+
- "3.7-dev"
11+
12+
# matrix:
13+
# include:
14+
# - os: osx
15+
# language: generic
16+
# env: PYTHON=3.5.0
17+
# - os: osx
18+
# language: generic
19+
# env: PYTHON=3.6.0
20+
# - os: osx
21+
# language: generic
22+
# env: PYTHON=3.7.0
23+
24+
before_install: |
25+
if [ "$TRAVIS_OS_NAME" == "osx" ]; then
26+
brew update
27+
# Per the `pyenv homebrew recommendations <https://github.com/yyuu/pyenv/wiki#suggested-build-environment>`_.
28+
brew install openssl readline
29+
# See https://docs.travis-ci.com/user/osx-ci-environment/#A-note-on-upgrading-packages.
30+
# I didn't do this above because it works and I'm lazy.
31+
brew outdated pyenv || brew upgrade pyenv
32+
# virtualenv doesn't work without pyenv knowledge. venv in Python 3.3
33+
# doesn't provide Pip by default. So, use `pyenv-virtualenv <https://github.com/yyuu/pyenv-virtualenv/blob/master/README.md>`_.
34+
brew install pyenv-virtualenv
35+
pyenv install $PYTHON
36+
# I would expect something like ``pyenv init; pyenv local $PYTHON`` or
37+
# ``pyenv shell $PYTHON`` would work, but ``pyenv init`` doesn't seem to
38+
# modify the Bash environment. ??? So, I hand-set the variables instead.
39+
export PYENV_VERSION=$PYTHON
40+
export PATH="/Users/travis/.pyenv/shims:${PATH}"
41+
pyenv-virtualenv venv
42+
source venv/bin/activate
43+
# A manual check that the correct version of Python is running.
44+
python --version
45+
fi
46+
47+
install:
48+
- pip install -U pip && pip install --progress-bar off -r requirements.txt && pip install pycodestyle
49+
50+
script:
51+
- make tests

README.md

+27-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,30 @@
1-
<p align="center"><img src="https://raw.githubusercontent.com/seba-1511/learn2learn/gh-pages/assets/img/l2l-full.png" height="150px" /></p>
1+
<p align="center"><img src="https://raw.githubusercontent.com/seba-1511/learn2learn/gh-pages/assets/img/l2l-full.png" height="120px" /></p>
22

33
--------------------------------------------------------------------------------
44

5+
[![Build Status](https://travis-ci.com/learnables/learn2learn.svg?branch=master)](https://travis-ci.com/learnables/learn2learn)
6+
57
learn2learn is a PyTorch library for meta-learning implementations.
6-
It was developed during the [first PyTorch Hackathon](http://pytorchmpk.devpost.com/).
8+
9+
The goal of meta-learning is to enable agents to *learn how to learn*.
10+
That is, we would like our agents to become better learners as they solve more and more tasks.
11+
For example, the animation below shows an agent that learns to run after a only one parameter update.
12+
13+
<p align="center"><img src="assets/img/halfcheetah.gif" height="250px" /></p>
14+
15+
**Features**
16+
17+
learn2learn provides high- and low-level utilities for meta-learning.
18+
The high-level utilities allow arbitrary users to take advantage of exisiting meta-learning algorithms.
19+
The low-level utilities enable researchers to develop new and better meta-learning algorithms.
20+
21+
Some features of learn2learn include:
22+
23+
* Modular API: implement your own training loops with our low-level utilities.
24+
* Provides various meta-learning algorithms (e.g. MAML, FOMAML, MetaSGD, ProtoNets, DiCE)
25+
* Task generator with unified API, compatible with torchvision, torchtext, torchaudio, and cherry.
26+
* Provides standardized meta-learning tasks for vision (Omniglot, mini-ImageNet), reinforcement learning (Particles, Mujoco), and even text (news classification).
27+
* 100% compatible with PyTorch -- use your own modules, datasets, or libraries!
728

829
# Installation
930

@@ -14,7 +35,7 @@ pip install learn2learn
1435
# API Demo
1536

1637
The following is an example of using the high-level MAML implementation on MNIST.
17-
For more algorithms and lower-level utilities, please refer to [the documentation](http://learn2learn.net/docs/learn2learn/) or the [examples](https://github.com/learnables/learn2learn/tree/master/examples).
38+
For more algorithms and lower-level utilities, please refer to the [documentation](http://learn2learn.net/docs/learn2learn/) or the [examples](https://github.com/learnables/learn2learn/tree/master/examples).
1839

1940
~~~python
2041
import learn2learn as l2l
@@ -27,7 +48,7 @@ task_generator = l2l.data.TaskGenerator(mnist,
2748
classes=[0, 1, 4, 6, 8, 9],
2849
tasks=10)
2950
model = Net()
30-
maml = l2l.MAML(model, lr=1e-3, first_order=False)
51+
maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False)
3152
opt = optim.Adam(maml.parameters(), lr=4e-3)
3253

3354
for iteration in range(num_iterations):
@@ -41,7 +62,7 @@ for iteration in range(num_iterations):
4162

4263
# Compute evaluation loss
4364
evaluation_task = task_generator.sample(shots=1,
44-
classes=adaptation_task.sampled_classes)
65+
task=adaptation_task.sampled_task)
4566
evaluation_error = compute_loss(evaluation_task)
4667

4768
# Meta-update the model parameters
@@ -50,6 +71,6 @@ for iteration in range(num_iterations):
5071
opt.step()
5172
~~~
5273

53-
# Acknowledgements
74+
### Acknowledgements
5475

5576
1. The RL environments are adapted from Tristan Deleu's [implementations](https://github.com/tristandeleu/pytorch-maml-rl) and from the ProMP [repository](https://github.com/jonasrothfuss/ProMP/). Both shared with permission, under the MIT License.

docs/pydocmd.yml

+39-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,37 @@ site_name: "learn2learn"
66
# documented. Higher indentation leads to smaller header size.
77
generate:
88
- docs/learn2learn.md:
9-
- learn2learn+
9+
- learn2learn.clone_module
10+
- learn2learn.detach_module
11+
- learn2learn.magic_box
12+
- docs/learn2learn.data.md:
13+
- learn2learn.data.MetaDataset
14+
- learn2learn.data.TaskGenerator++
15+
- docs/learn2learn.algorithms.md:
16+
- learn2learn.algorithms.MAML++
17+
- learn2learn.algorithms.maml_update
18+
- learn2learn.algorithms.MetaSGD++
19+
- learn2learn.algorithms.meta_sgd_update
20+
- docs/learn2learn.gym.md:
21+
- learn2learn.gym++
22+
- learn2learn.gym.envs.mujoco
23+
- learn2learn.gym.envs.mujoco.HalfCheetahForwardBackwardEnv
24+
- learn2learn.gym.envs.mujoco.AntForwardBackwardEnv
25+
- learn2learn.gym.envs.mujoco.AntDirectionEnv
26+
- learn2learn.gym.envs.mujoco.HumanoidForwardBackwardEnv
27+
- learn2learn.gym.envs.mujoco.HumanoidDirectionEnv
28+
- learn2learn.gym.envs.particles
29+
- learn2learn.gym.envs.particles.Particles2DEnv
30+
- docs/learn2learn.vision.md:
31+
- learn2learn.vision.models
32+
- learn2learn.vision.models.OmniglotFC
33+
- learn2learn.vision.models.OmniglotCNN
34+
- learn2learn.vision.datasets
35+
- learn2learn.vision.datasets.FullOmniglot
36+
- learn2learn.vision.transforms
37+
- learn2learn.vision.transforms.RandomDiscreteRotation
38+
- docs/learn2learn.text.md:
39+
- learn2learn.text.datasets.NewsClassification
1040

1141
# MkDocs pages configuration. The `<<` operator is sugar added by pydocmd
1242
# that allows you to use an external Markdown file (eg. your project's README)
@@ -18,8 +48,13 @@ pages:
1848
- Getting Started: tutorials/getting_started.md
1949
- Documentation:
2050
- learn2learn: docs/learn2learn.md
21-
- Examples: https://github.com/seba-1511/learn2learn/tree/master/examples
22-
- GitHub: https://github.com/seba-1511/learn2learn/
51+
- learn2learn.algorithms: docs/learn2learn.algorithms.md
52+
- learn2learn.data: docs/learn2learn.data.md
53+
- learn2learn.gym: docs/learn2learn.gym.md
54+
- learn2learn.text: docs/learn2learn.text.md
55+
- learn2learn.vision: docs/learn2learn.vision.md
56+
- Examples: https://github.com/learnables/learn2learn/tree/master/examples
57+
- GitHub: https://github.com/learnables/learn2learn/
2358

2459
# These options all show off their default values. You don't have to add
2560
# them to your configuration if you're fine with the default.
@@ -34,7 +69,7 @@ theme:
3469
custom_dir: 'l2l_theme/'
3570
highlightjs: true
3671
loader: pydocmd.loader.PythonLoader
37-
preprocessor: pydocmd.preprocessors.simple.Preprocessor
72+
preprocessor: pydocmd.preprocessor.Preprocessor
3873
# Whether to output headers as markdown or HTML. Used to workaround
3974
# https://github.com/NiklasRosenstein/pydoc-markdown/issues/11. The default is
4075
# to generate HTML with unique and meaningful id tags, which can't be done with
4.09 MB
Loading

docs/source/tutorials/getting_started.md

+3
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@ For more information about specific meta-learning algorithms, please refer to th
3030
appropriate tutorial.
3131

3232
# How to Use L2L
33+
3334
## Installing
35+
3436
A pip package is available, updated periodically. Use the command:
3537

3638
```pip install learn2learn```
@@ -48,6 +50,7 @@ encounter a problem, feel free to an open an [issue](https://github.com/learnabl
4850
look into it.
4951

5052
## Source Files
53+
5154
Examples of learn2learn in action can be found [here](https://github.com/learnables/learn2learn/tree/master/examples).
5255
The source code for algorithm implementations is also available [here](https://github.com/learnables/learn2learn/tree/master/learn2learn/algorithms).
5356

examples/maml_toy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def forward(self, x=None):
3131
def main():
3232
task_dist = dist.Normal(th.zeros(2 * DIM), th.ones(2 * DIM))
3333
model = Model()
34-
maml = l2l.MAML(model, lr=1e-2)
34+
maml = l2l.algorithms.MAML(model, lr=1e-2)
3535
opt = optim.Adam(maml.parameters())
3636

3737
for i in range(TIMESTEPS):

examples/rl/dist_promp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def make_env():
124124
policy = DiagNormalPolicy(input_size=env.state_size,
125125
output_size=env.action_size,
126126
hiddens=[64, 64])
127-
meta_learner = l2l.MAML(policy, lr=meta_lr)
127+
meta_learner = l2l.algorithms.MAML(policy, lr=meta_lr)
128128
baseline = LinearValue(env.state_size, env.action_size)
129129
opt = optim.Adam(meta_learner.parameters(), lr=meta_lr)
130130

examples/rl/maml_dice.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def make_env():
8989
env.seed(seed)
9090
env = ch.envs.Torch(env)
9191
policy = DiagNormalPolicy(env.state_size, env.action_size)
92-
meta_learner = l2l.MAML(policy, lr=meta_lr)
92+
meta_learner = l2l.algorithms.MAML(policy, lr=meta_lr)
9393
baseline = LinearValue(env.state_size, env.action_size)
9494
opt = optim.Adam(policy.parameters(), lr=meta_lr)
9595
all_rewards = []

examples/rl/metasgd_a2c.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def make_env():
7979
env.seed(seed)
8080
env = ch.envs.Torch(env)
8181
policy = DiagNormalPolicy(env.state_size, env.action_size)
82-
meta_learner = l2l.MetaSGD(policy, lr=meta_lr)
82+
meta_learner = l2l.algorithms.MetaSGD(policy, lr=meta_lr)
8383
baseline = LinearValue(env.state_size, env.action_size)
8484
opt = optim.Adam(policy.parameters(), lr=meta_lr)
8585
all_rewards = []

examples/rl/promp.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def precompute_quantities(states, actions, old_policy, new_policy):
7272

7373

7474
def main(
75-
env_name='Particles2D-v1',
75+
env_name='AntDirection-v1',
7676
adapt_lr=0.1,
7777
meta_lr=3e-4,
78-
adapt_steps=1,
78+
adapt_steps=3,
7979
num_iterations=1000,
8080
meta_bsz=40,
8181
adapt_bsz=20,
@@ -94,7 +94,9 @@ def main(
9494
th.manual_seed(seed)
9595

9696
def make_env():
97-
return gym.make(env_name)
97+
env = gym.make(env_name)
98+
env = ch.envs.ActionSpaceScaler(env)
99+
return env
98100

99101
env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
100102
env.seed(seed)
@@ -104,7 +106,7 @@ def make_env():
104106
output_size=env.action_size,
105107
hiddens=[64, 64],
106108
activation='tanh')
107-
meta_learner = l2l.MAML(policy, lr=meta_lr)
109+
meta_learner = l2l.algorithms.MAML(policy, lr=meta_lr)
108110
baseline = LinearValue(env.state_size, env.action_size)
109111
opt = optim.Adam(meta_learner.parameters(), lr=meta_lr)
110112

examples/vision/maml_miniimagenet.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ def main(
7676
test_generator = l2l.data.TaskGenerator(dataset=test_dataset, ways=ways)
7777

7878
# Create model
79-
model = l2l.models.MiniImagenetCNN(ways)
79+
model = l2l.vision.models.MiniImagenetCNN(ways)
8080
model.to(device)
81-
maml = l2l.MAML(model, lr=fast_lr, first_order=False)
81+
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
8282
opt = optim.Adam(maml.parameters(), meta_lr)
8383
loss = nn.CrossEntropyLoss(size_average=True, reduction='mean')
8484

@@ -95,7 +95,7 @@ def main(
9595
learner = maml.clone()
9696
adaptation_data = train_generator.sample(shots=shots)
9797
evaluation_data = train_generator.sample(shots=shots,
98-
classes=adaptation_data.sampled_classes)
98+
task=adaptation_data.sampled_task)
9999
evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
100100
evaluation_data,
101101
learner,
@@ -110,7 +110,7 @@ def main(
110110
learner = maml.clone()
111111
adaptation_data = valid_generator.sample(shots=shots)
112112
evaluation_data = valid_generator.sample(shots=shots,
113-
classes=adaptation_data.sampled_classes)
113+
task=adaptation_data.sampled_task)
114114
evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
115115
evaluation_data,
116116
learner,
@@ -124,7 +124,7 @@ def main(
124124
learner = maml.clone()
125125
adaptation_data = test_generator.sample(shots=shots)
126126
evaluation_data = test_generator.sample(shots=shots,
127-
classes=adaptation_data.sampled_classes)
127+
task=adaptation_data.sampled_task)
128128
evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
129129
evaluation_data,
130130
learner,

examples/vision/maml_omniglot.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ def main(
7171
test_generator = l2l.data.TaskGenerator(dataset=omniglot, ways=ways, classes=classes[1200:])
7272

7373
# Create model
74-
model = l2l.models.OmniglotFC(28 ** 2, ways)
74+
model = l2l.vision.models.OmniglotFC(28 ** 2, ways)
7575
model.to(device)
76-
maml = l2l.MAML(model, lr=fast_lr, first_order=False)
76+
maml = l2l.algorithms.MAML(model, lr=fast_lr, first_order=False)
7777
opt = optim.Adam(maml.parameters(), meta_lr)
7878
loss = nn.CrossEntropyLoss(size_average=True, reduction='mean')
7979

@@ -90,7 +90,7 @@ def main(
9090
learner = maml.clone()
9191
adaptation_data = train_generator.sample(shots=shots)
9292
evaluation_data = train_generator.sample(shots=shots,
93-
classes=adaptation_data.sampled_classes)
93+
task=adaptation_data.sampled_task)
9494
evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
9595
evaluation_data,
9696
learner,
@@ -105,7 +105,7 @@ def main(
105105
learner = maml.clone()
106106
adaptation_data = valid_generator.sample(shots=shots)
107107
evaluation_data = valid_generator.sample(shots=shots,
108-
classes=adaptation_data.sampled_classes)
108+
task=adaptation_data.sampled_task)
109109
evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
110110
evaluation_data,
111111
learner,
@@ -119,7 +119,7 @@ def main(
119119
learner = maml.clone()
120120
adaptation_data = test_generator.sample(shots=shots)
121121
evaluation_data = test_generator.sample(shots=shots,
122-
classes=adaptation_data.sampled_classes)
122+
task=adaptation_data.sampled_task)
123123
evaluation_error, evaluation_accuracy = fast_adapt(adaptation_data,
124124
evaluation_data,
125125
learner,

examples/vision/meta_mnist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def main(lr=0.005, maml_lr=0.01, iterations=1000, ways=5, shots=1, tps=32, fas=5
7171

7272
model = Net(ways)
7373
model.to(device)
74-
meta_model = l2l.MAML(model, lr=maml_lr)
74+
meta_model = l2l.algorithms.MAML(model, lr=maml_lr)
7575
opt = optim.Adam(meta_model.parameters(), lr=lr)
7676
loss_func = nn.NLLLoss(reduction="sum")
7777

examples/vision/proto_net.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from torchvision import transforms
1818

1919
import learn2learn as l2l
20-
from learn2learn.models import OmniglotCNN
20+
from learn2learn.vision.models import OmniglotCNN
2121
from learn2learn.vision.datasets.full_omniglot import FullOmniglot
2222

2323

learn2learn/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
from . import algorithms
44
from . import data
55
from . import gym
6-
from . import models
76
from . import text
87
from . import vision
98
from ._version import __version__
10-
from .algorithms import MAML, MetaSGD, magic_box
119
from .utils import *

learn2learn/algorithms/__init__.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
22

3-
from .dice import magic_box
4-
from .maml import MAML
5-
from .meta_sgd import MetaSGD
3+
from .maml import MAML, maml_update
4+
from .meta_sgd import MetaSGD, meta_sgd_update

learn2learn/algorithms/dice.py

-9
This file was deleted.

0 commit comments

Comments
 (0)