Fix clone_module when submodules share parameters. (#176)
* Fix clone_module with shared parameters.

* Add _notravis for benchmarks too.

* Update CHANGELOG.
seba-1511 authored Aug 30, 2020
1 parent 69558e0 commit 6a32de2
Showing 5 changed files with 123 additions and 17 deletions.
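For context, the failure mode this commit targets arises when the same submodule, and therefore the same parameter tensors, is registered under more than one parent, as in the test modules added below. A minimal, hypothetical sketch of such a module (the names here are illustrative and not part of the commit):

import torch

class TwoHeads(torch.nn.Module):
    # `self.encoder` is reachable both as a direct attribute and through
    # `self.pipeline`, so its weight and bias are shared between two paths.
    def __init__(self):
        super(TwoHeads, self).__init__()
        self.encoder = torch.nn.Linear(8, 8)
        self.pipeline = torch.nn.Sequential(self.encoder, torch.nn.ReLU())

    def forward(self, x):
        return self.pipeline(x)

model = TwoHeads()
# PyTorch deduplicates by tensor identity, so the shared Linear counts once.
assert len(list(model.parameters())) == 2  # weight and bias

Before this change, `clone_module` cloned every occurrence independently, so the sharing was lost in the clone; issue #174, linked in the diff below, describes the problem.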
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

* Fix `clone_module` for Modules whose submodules share parameters.


## v0.1.2

32 changes: 27 additions & 5 deletions learn2learn/utils.py
@@ -48,7 +48,7 @@ def clone_parameters(param_list):
    return [p.clone() for p in param_list]


-def clone_module(module):
+def clone_module(module, memo=None):
    """
    [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/utils.py)
@@ -91,6 +91,12 @@ def clone_module(module):
    # clone = recursive_shallow_copy(model)
    # clone._apply(lambda t: t.clone())

    if memo is None:
        # Maps original data_ptr to the cloned tensor.
        # Useful when a Module uses parameters from another Module; see:
        # https://github.com/learnables/learn2learn/issues/174
        memo = {}

    # First, create a copy of the module.
    # Adapted from:
    # https://github.com/pytorch/pytorch/blob/65bad41cbec096aa767b3752843eddebf845726f/torch/nn/modules/module.py#L1171
@@ -106,20 +112,36 @@ def clone_module(module):
    if hasattr(clone, '_parameters'):
        for param_key in module._parameters:
            if module._parameters[param_key] is not None:
-                cloned = module._parameters[param_key].clone()
-                clone._parameters[param_key] = cloned
+                param = module._parameters[param_key]
+                param_ptr = param.data_ptr
+                if param_ptr in memo:
+                    clone._parameters[param_key] = memo[param_ptr]
+                else:
+                    cloned = param.clone()
+                    clone._parameters[param_key] = cloned
+                    memo[param_ptr] = cloned

    # Third, handle the buffers if necessary
    if hasattr(clone, '_buffers'):
        for buffer_key in module._buffers:
            if clone._buffers[buffer_key] is not None and \
                    clone._buffers[buffer_key].requires_grad:
-                clone._buffers[buffer_key] = module._buffers[buffer_key].clone()
+                buff = module._buffers[buffer_key]
+                buff_ptr = buff.data_ptr
+                if buff_ptr in memo:
+                    clone._buffers[buffer_key] = memo[buff_ptr]
+                else:
+                    cloned = buff.clone()
+                    clone._buffers[buffer_key] = cloned
+                    memo[buff_ptr] = cloned

    # Then, recurse for each submodule
    if hasattr(clone, '_modules'):
        for module_key in clone._modules:
-            clone._modules[module_key] = clone_module(module._modules[module_key])
+            clone._modules[module_key] = clone_module(
+                module._modules[module_key],
+                memo=memo,
+            )

    # Finally, rebuild the flattened parameters for RNNs
    # See this issue for more details:
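To make the role of the memo concrete, here is a small, self-contained sketch of the same deduplication idea, using a hypothetical helper that is not part of the commit: tensors whose storage has already been cloned are looked up instead of cloned again. For simplicity it keys the memo on `Tensor.data_ptr()`.

import torch

def clone_with_memo(tensors, memo=None):
    # Clone each tensor once; when the same storage (identified by its
    # data pointer) shows up again, re-use the clone already in the memo.
    if memo is None:
        memo = {}
    clones = []
    for t in tensors:
        key = t.data_ptr()
        if key not in memo:
            memo[key] = t.clone()
        clones.append(memo[key])
    return clones

w = torch.randn(3, 3, requires_grad=True)
a, b = clone_with_memo([w, w])
assert a is b                        # the shared tensor was cloned only once
assert a.data_ptr() != w.data_ptr()  # but the clone owns new storage

`clone_module` applies the same idea while recursing over `_parameters`, `_buffers`, and `_modules`, passing the memo down so that a parameter shared across submodules maps to a single cloned tensor.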
61 changes: 50 additions & 11 deletions tests/unit/algorithms/maml_test.py
@@ -18,12 +18,14 @@ def close(x, y):
class TestMAMLAlgorithm(unittest.TestCase):

    def setUp(self):
-        self.model = torch.nn.Sequential(torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
-                                         torch.nn.ReLU(),
-                                         torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
-                                         torch.nn.Sigmoid(),
-                                         torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
-                                         torch.nn.Softmax())
+        self.model = torch.nn.Sequential(
+            torch.nn.Linear(INPUT_SIZE, HIDDEN_SIZE),
+            torch.nn.ReLU(),
+            torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
+            torch.nn.Sigmoid(),
+            torch.nn.Linear(HIDDEN_SIZE, HIDDEN_SIZE),
+            torch.nn.Softmax(),
+        )

        self.model.register_buffer('dummy_buf', torch.zeros(1, 2, 3, 4))

@@ -101,7 +103,11 @@ def test_allow_nograd(self):
        try:
            # Check that without allow_nograd, adaptation fails
            clone.adapt(loss)
-            self.assertTrue(False, 'adaptation successful despite requires_grad=False') # Check that execution never gets here
+            # Check that execution never gets here
+            self.assertTrue(
+                False,
+                'adaptation successful despite requires_grad=False',
+            )
        except:
            # Check that with allow_nograd, adaptation succeeds
            clone.adapt(loss, allow_nograd=True)
@@ -112,17 +118,50 @@ def test_allow_nograd(self):
            if p.requires_grad:
                self.assertTrue(p.grad is not None)

-        maml = l2l.algorithms.MAML(self.model,
-                                   lr=INNER_LR,
-                                   first_order=False,
-                                   allow_nograd=True)
+        maml = l2l.algorithms.MAML(
+            self.model,
+            lr=INNER_LR,
+            first_order=False,
+            allow_nograd=True,
+        )
        clone = maml.clone()
        loss = sum([p.norm(p=2) for p in clone.parameters()])
        # Check that without allow_nograd, adaptation succeeds thanks to init.
        orig_weight = self.model[2].weight.clone().detach()
        clone.adapt(loss)
        self.assertTrue(close(orig_weight, self.model[2].weight))

    def test_module_shared_params(self):

        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                cnn = [
                    torch.nn.Conv2d(3, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                ]
                self.seq = torch.nn.Sequential(*cnn)
                self.head = torch.nn.Sequential(*[
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 100, 3, 2, 1)]
                )
                self.net = torch.nn.Sequential(self.seq, self.head)

            def forward(self, x):
                return self.net(x)

        module = TestModule()
        maml = l2l.algorithms.MAML(module, lr=0.1)
        clone = maml.clone()
        loss = sum(p.norm(p=2) for p in clone.parameters())
        clone.adapt(loss)
        loss = sum(p.norm(p=2) for p in clone.parameters())
        loss.backward()


if __name__ == '__main__':
45 changes: 44 additions & 1 deletion tests/unit/utils_test.py
@@ -13,6 +13,11 @@ def ref_clone_module(module):
    each forward call.
    See this issue for more details:
    https://github.com/learnables/learn2learn/issues/139

    Note: This implementation also does not work for Modules that re-use
    parameters from another Module.
    See this issue for more details:
    https://github.com/learnables/learn2learn/issues/174
    """
    # First, create a copy of the module.
    clone = copy.deepcopy(module)
@@ -191,10 +196,48 @@ def test_rnn_clone(self):
        # Ensure we did better
        self.assertTrue(first_loss > second_loss)

    def test_module_clone_shared_params(self):
        # Tests proper use of memo parameter

        class TestModule(torch.nn.Module):
            def __init__(self):
                super(TestModule, self).__init__()
                cnn = [
                    torch.nn.Conv2d(3, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                ]
                self.seq = torch.nn.Sequential(*cnn)
                self.head = torch.nn.Sequential(*[
                    torch.nn.Conv2d(32, 32, 3, 2, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(32, 100, 3, 2, 1)]
                )
                self.net = torch.nn.Sequential(self.seq, self.head)

            def forward(self, x):
                return self.net(x)

        original = TestModule()
        clone = l2l.clone_module(original)
        self.assertTrue(
            len(list(clone.parameters())) == len(list(original.parameters())),
            'clone and original do not have same number of parameters.',
        )

        orig_params = [p.data_ptr() for p in original.parameters()]
        duplicates = [p.data_ptr() in orig_params for p in clone.parameters()]
        self.assertTrue(not any(duplicates), 'clone() forgot some parameters.')

    def test_module_detach(self):
        original_output = self.model(self.input)
-        original_loss = self.loss_func(original_output, torch.tensor([[0., 0.]]))
+        original_loss = self.loss_func(
+            original_output,
+            torch.tensor([[0., 0.]])
+        )

        original_gradients = torch.autograd.grad(original_loss,
                                                 self.model.parameters(),
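Complementing `test_module_clone_shared_params` above, a short usage sketch (with a hypothetical model; `l2l.clone_module` is used exactly as in these tests) showing that the memo not only avoids re-using the original tensors but also preserves the sharing structure inside the clone:

import torch
import learn2learn as l2l

# The same Linear occurs twice in the Sequential, so its parameters are shared.
shared = torch.nn.Linear(4, 4)
model = torch.nn.Sequential(shared, torch.nn.ReLU(), shared)

clone = l2l.clone_module(model)

# Both positions of the clone reference one and the same cloned weight...
assert clone[0].weight.data_ptr() == clone[2].weight.data_ptr()
# ...and that weight lives in fresh storage, distinct from the original's.
assert clone[0].weight.data_ptr() != model[0].weight.data_ptr()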
File renamed without changes.
