[Feature] reset_parameters for multiagent nets #1970

Merged: 16 commits, Feb 27, 2024
52 changes: 50 additions & 2 deletions test/test_modules.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse

from numbers import Number

import numpy as np
@@ -947,6 +948,31 @@ def test_multiagent_mlp_lazy(self):
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
def test_reset_mlp(
self,
n_agents,
centralised,
share_params,
):
actor_net = MultiAgentMLP(
n_agent_inputs=4,
n_agent_outputs=6,
num_cells=[5, 5],
n_agents=n_agents,
centralised=centralised,
share_params=share_params,
)
params_before = actor_net.params.clone()
actor_net.reset_parameters()
params_after = actor_net.params
for p1, p2 in zip(
params_before.values(True, True), params_after.values(True, True)
):
assert not torch.isclose(p1, p2).all()

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
@@ -1051,6 +1077,30 @@ def test_multiagent_cnn_lazy(self):
if isinstance(p, torch.nn.parameter.UninitializedParameter):
raise AssertionError("UninitializedParameter found")

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize("share_params", [True, False])
@pytest.mark.parametrize("centralised", [True, False])
def test_reset_cnn(
self,
n_agents,
centralised,
share_params,
):
actor_net = MultiAgentConvNet(
in_features=4,
num_cells=[5, 5],
n_agents=n_agents,
centralised=centralised,
share_params=share_params,
)
params_before = actor_net.params.clone()
actor_net.reset_parameters()
params_after = actor_net.params
for p1, p2 in zip(
params_before.values(True, True), params_after.values(True, True)
):
assert not torch.isclose(p1, p2).all()

@pytest.mark.parametrize("n_agents", [1, 3])
@pytest.mark.parametrize(
"batch",
@@ -1271,7 +1321,6 @@ def test_onlinedtactor(self, batch_dims, T=5):
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
def test_python_lstm_cell(device, bias):

lstm_cell1 = LSTMCell(10, 20, device=device, bias=bias)
lstm_cell2 = nn.LSTMCell(10, 20, device=device, bias=bias)

@@ -1307,7 +1356,6 @@ def test_python_lstm_cell(device, bias):
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("bias", [True, False])
def test_python_gru_cell(device, bias):

gru_cell1 = GRUCell(10, 20, device=device, bias=bias)
gru_cell2 = nn.GRUCell(10, 20, device=device, bias=bias)

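Beyond checking that a reset changes the parameters, the point raised in the review thread below is that with `share_params=False` each agent should end up with its own freshly drawn weights (this is what the hard-coded `randomness="different"` inside `reset_parameters` provides). A minimal sketch of that extra check, reusing the constructor arguments from the tests above; treating dim 0 of each leaf of `net.params` as the per-agent stack dimension is an assumption based on how the module vmaps over its parameters, not something asserted by the diff:

    import torch
    from torchrl.modules import MultiAgentMLP

    net = MultiAgentMLP(
        n_agent_inputs=4,
        n_agent_outputs=6,
        num_cells=[5, 5],
        n_agents=3,
        centralised=False,
        share_params=False,
    )
    net.reset_parameters()
    # With unshared parameters, each leaf of `net.params` stacks the per-agent
    # weights along dim 0; after a reset with randomness="different" the
    # per-agent slices should not coincide.
    for leaf in net.params.values(True, True):
        assert not torch.isclose(leaf[0], leaf[1]).all()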
30 changes: 25 additions & 5 deletions torchrl/modules/models/multiagent.py
@@ -16,6 +16,7 @@
from torchrl.data.utils import DEVICE_TYPING

from torchrl.modules.models import ConvNet, MLP
from torchrl.modules.models.utils import _reset_parameters_recursive


class MultiAgentNetBase(nn.Module):
@@ -30,6 +31,7 @@ def __init__(
centralised: bool,
share_params: bool,
agent_dim: int,
vmap_randomness: str = "different",
**kwargs,
):
super().__init__()
@@ -38,6 +40,7 @@ def __init__(
self.share_params = share_params
self.centralised = centralised
self.agent_dim = agent_dim
self._vmap_randomness = vmap_randomness

agent_networks = [
self._build_single_net(**kwargs)
@@ -54,9 +57,11 @@ def __init__(
self.__dict__["_empty_net"] = self._build_single_net(**kwargs)

@property
def _vmap_randomness(self):
def vmap_randomness(self):
if self.initialized:
return "error"
return self._vmap_randomness
# Matteo: There seems to be a problem with lazy layers when using "different" here
# found this bit as legacy, not sure the reason
return "same"

def _make_params(self, agent_networks):
@@ -92,14 +97,14 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
if not self.share_params:
if self.centralised:
output = self.vmap_func_module(
self._empty_net, (0, None), (-2,), randomness=self._vmap_randomness
self._empty_net, (0, None), (-2,), randomness=self.vmap_randomness
)(self.params, inputs)
else:
output = self.vmap_func_module(
self._empty_net,
(0, self.agent_dim),
(-2,),
randomness=self._vmap_randomness,
randomness=self.vmap_randomness,
)(self.params, inputs)

# If parameters are shared, agents use the same network
@@ -125,6 +130,21 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:

return output

def reset_parameters(self):
def vmap_reset_module(module, *args, **kwargs):
def reset_module(params):
with params.to_module(module):
_reset_parameters_recursive(module)
return params

return torch.vmap(reset_module, *args, **kwargs)

if not self.share_params:
vmap_reset_module(self._empty_net, randomness="different")(self.params)
@matteobettini (Contributor Author) commented on Feb 27, 2024:

I would like to discuss the vmap randomness of this class

    @property
    def _vmap_randomness(self):
        if self.initialized:
            return "error"
        return "same"

Why would this be a class property and why are we having those values?

For me this should be

    @property
    def _vmap_randomness(self):
        return "different"

@matteobettini (Contributor Author):

In this case it needs to be "different" to have different reset values for each agent.

But also in the forward pass I feel like it should be "different"

Contributor:

IMO it should be

    @property
    def _vmap_randomness(self):
        if self.initialized:
            return self.vmap_randomness
        return "different"

and users are in charge of telling the module what randomness they want.

@matteobettini (Contributor Author):

Could you explain this a bit? Why do we have a switch on initialization?

Contributor:

For init you need "different" because you must have different weights for each net.
But in other settings you can't tell, and the best is to let the user choose.
They may as well want the same random number for each element of the batch

@matteobettini (Contributor Author):

but what I do not understand is why before we had

    @property
    def _vmap_randomness(self):
        if self.initialized:
            return "error"
        return "same"

@matteobettini (Contributor Author):

if I try to change this to

    @property
    def _vmap_randomness(self):
        if self.initialized:
            return self.vmap_randomness
        return "different"

the lazy layers will crash

Contributor:

> the lazy layers will crash

this is a statement that is hard to reproduce, can you share more?

For instance this works fine on my end:

    from torchrl.modules import MLP
    from tensordict import TensorDict
    import torch
    from functorch import dim

    d0 = dim.dims(1)
    modules = [torch.nn.Linear(2, 3) for _ in range(3)]

    td = TensorDict.from_modules(*modules, as_module=True)

    def reset(td):
        with td.to_module(modules[0]):
            modules[0].reset_parameters()
        return td

    td = torch.vmap(reset, randomness="same")(td)
    print(td["weight"])
    td = torch.vmap(reset, randomness="different")(td)
    print(td["weight"])

the first produces a stack of identical tensors, the second different

else:
with self.params.to_module(self._empty_net):
_reset_parameters_recursive(self._empty_net)


class MultiAgentMLP(MultiAgentNetBase):
"""Mult-agent MLP.
@@ -262,7 +282,6 @@ def __init__(
activation_class: Optional[Type[nn.Module]] = nn.Tanh,
**kwargs,
):

self.n_agents = n_agents
self.n_agent_inputs = n_agent_inputs
self.n_agent_outputs = n_agent_outputs
@@ -477,6 +496,7 @@ def __init__(
share_params=share_params,
device=device,
agent_dim=-4,
**kwargs,
)

def _build_single_net(self, *, device, **kwargs):
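Taken together, the changes to this file expose the vmap randomness as a constructor argument and add an in-place `reset_parameters()`. A construction sketch follows, with the caveats that `vmap_randomness` reaching `MultiAgentNetBase` relies on the `**kwargs` forwarding shown in the last hunk, and that the input layout `(*batch, n_agents, channels, H, W)` is inferred from `agent_dim=-4` rather than stated in the diff:

    import torch
    from torchrl.modules import MultiAgentConvNet

    net = MultiAgentConvNet(
        in_features=4,
        num_cells=[5, 5],
        n_agents=3,
        centralised=False,
        share_params=False,
        # Default is "different"; "same" would replay identical random ops
        # (e.g. dropout masks) across the vmapped agent networks in forward().
        vmap_randomness="different",
    )
    out = net(torch.randn(8, 3, 4, 32, 32))  # (*batch, n_agents, channels, H, W)
    # Re-draw an independent set of weights for every agent network in-place.
    net.reset_parameters()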
25 changes: 25 additions & 0 deletions torchrl/modules/models/utils.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import inspect
import warnings
from typing import Optional, Sequence, Type

import torch
@@ -123,3 +124,27 @@ def create_on_device(
else:
return module_class(*args, **kwargs).to(device)
# .to() is always available for nn.Module, and does nothing if the Module contains no parameters or buffers


def _reset_parameters_recursive(module, warn_if_no_op: bool = True) -> bool:
"""Recursively resets the parameters of a :class:`~torch.nn.Module` in-place.

Args:
module (torch.nn.Module): the module to reset.
warn_if_no_op (bool, optional): whether to raise a warning in case this is a no-op.
Defaults to ``True``.

Returns: whether any parameter has been reset.

"""
any_reset = False
for layer in module.children():
if hasattr(layer, "reset_parameters"):
layer.reset_parameters()
any_reset |= True
any_reset |= _reset_parameters_recursive(layer, warn_if_no_op=False)
if warn_if_no_op and not any_reset:
warnings.warn(
"_reset_parameters_recursive was called without the parameters argument and did not find any parameters to reset"
)
return any_reset
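To make the helper's contract concrete, here is a small illustrative sketch (the `Doubler` module is made up for the example; only `_reset_parameters_recursive` and its import path come from this diff):

    from torch import nn
    from torchrl.modules.models.utils import _reset_parameters_recursive

    # The outer Sequential has no reset_parameters of its own, but the recursion
    # reaches the nested Linear layers, resets them and returns True.
    net = nn.Sequential(nn.Sequential(nn.Linear(3, 16), nn.Tanh()), nn.Linear(16, 1))
    assert _reset_parameters_recursive(net)

    # No resettable child anywhere: returns False and emits the warning
    # (pass warn_if_no_op=False to silence it).
    class Doubler(nn.Module):
        def forward(self, x):
            return 2 * x

    assert not _reset_parameters_recursive(Doubler())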