diff --git a/test/test_modules.py b/test/test_modules.py
index 3d01fd04768..94c8a809170 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import argparse
+
 from numbers import Number
 
 import numpy as np
@@ -859,7 +860,7 @@ def _get_mock_input_td(
     @pytest.mark.parametrize("share_params", [True, False])
     @pytest.mark.parametrize("centralised", [True, False])
     @pytest.mark.parametrize("n_agent_inputs", [6, None])
-    @pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
+    @pytest.mark.parametrize("batch", [(4,), (4, 3), ()])
     def test_multiagent_mlp(
         self,
         n_agents,
@@ -923,7 +924,7 @@ def test_multiagent_mlp_lazy(self):
             share_params=False,
             depth=2,
         )
-        optim = torch.optim.Adam(mlp.parameters())
+        optim = torch.optim.SGD(mlp.parameters())
         for p in mlp.parameters():
             if isinstance(p, torch.nn.parameter.UninitializedParameter):
                 break
@@ -938,6 +939,11 @@ def test_multiagent_mlp_lazy(self):
         td = self._get_mock_input_td(3, 4, batch=(10,))
         obs = td.get(("agents", "observation"))
         out = mlp(obs)
+        assert (
+            not mlp.params[0]
+            .apply(lambda x, y: torch.isclose(x, y), mlp.params[1])
+            .any()
+        )
         out.mean().backward()
         optim.step()
         for p in mlp.parameters():
@@ -947,11 +953,41 @@ def test_multiagent_mlp_lazy(self):
             if isinstance(p, torch.nn.parameter.UninitializedParameter):
                 raise AssertionError("UninitializedParameter found")
 
+    @pytest.mark.parametrize("n_agents", [1, 3])
+    @pytest.mark.parametrize("share_params", [True, False])
+    @pytest.mark.parametrize("centralised", [True, False])
+    def test_multiagent_reset_mlp(
+        self,
+        n_agents,
+        centralised,
+        share_params,
+    ):
+        actor_net = MultiAgentMLP(
+            n_agent_inputs=4,
+            n_agent_outputs=6,
+            num_cells=(4, 4),
+            n_agents=n_agents,
+            centralised=centralised,
+            share_params=share_params,
+        )
+        params_before = actor_net.params.clone()
+        actor_net.reset_parameters()
+        params_after = actor_net.params
+        assert not params_before.apply(
+            lambda x, y: torch.isclose(x, y), params_after, batch_size=[]
+        ).any()
+        if params_after.numel() > 1:
+            assert (
+                not params_after[0]
+                .apply(lambda x, y: torch.isclose(x, y), params_after[1], batch_size=[])
+                .any()
+            )
+
     @pytest.mark.parametrize("n_agents", [1, 3])
     @pytest.mark.parametrize("share_params", [True, False])
     @pytest.mark.parametrize("centralised", [True, False])
     @pytest.mark.parametrize("channels", [3, None])
-    @pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
+    @pytest.mark.parametrize("batch", [(4,), (4, 3), ()])
     def test_multiagent_cnn(
         self,
         n_agents,
@@ -959,8 +995,8 @@ def test_multiagent_cnn(
         share_params,
         batch,
         channels,
-        x=50,
-        y=50,
+        x=15,
+        y=15,
     ):
         torch.manual_seed(0)
         cnn = MultiAgentConvNet(
@@ -968,6 +1004,7 @@ def test_multiagent_cnn(
             centralised=centralised,
             share_params=share_params,
             in_features=channels,
+            kernel_sizes=3,
         )
         if channels is None:
             channels = 3
@@ -983,21 +1020,20 @@ def test_multiagent_cnn(
         obs = td[("agents", "observation")]
         out = cnn(obs)
         assert out.shape[:-1] == (*batch, n_agents)
-        for i in range(n_agents):
-            if centralised and share_params:
-                assert torch.allclose(out[..., i, :], out[..., 0, :])
-            else:
+        if centralised and share_params:
+            torch.testing.assert_close(out, out[..., :1, :].expand_as(out))
+        else:
+            for i in range(n_agents):
                 for j in range(i + 1, n_agents):
                     assert not torch.allclose(out[..., i, :], out[..., j, :])
-
         obs[..., 0, 0, 0, 0] += 1
         out2 = cnn(obs)
-        for i in range(n_agents):
-            if centralised:
-                # a modification to the input of agent 0 will impact all agents
-                assert not torch.allclose(out[..., i, :], out2[..., i, :])
-            elif i > 0:
-                assert torch.allclose(out[..., i, :], out2[..., i, :])
+        if centralised:
+            # a modification to the input of agent 0 will impact all agents
+            assert not torch.isclose(out, out2).all()
+        elif n_agents > 1:
+            assert not torch.isclose(out[..., 0, :], out2[..., 0, :]).all()
+            torch.testing.assert_close(out[..., 1:, :], out2[..., 1:, :])
 
         obs = torch.randn(*batch, 1, channels, x, y).expand(
             *batch, n_agents, channels, x, y
@@ -1013,13 +1049,16 @@ def test_multiagent_cnn(
                     assert not torch.allclose(out[..., i, :], out[..., j, :])
 
     def test_multiagent_cnn_lazy(self):
+        n_agents = 5
+        n_channels = 3
         cnn = MultiAgentConvNet(
-            n_agents=5,
+            n_agents=n_agents,
             centralised=False,
             share_params=False,
             in_features=None,
+            kernel_sizes=3,
         )
-        optim = torch.optim.Adam(cnn.parameters())
+        optim = torch.optim.SGD(cnn.parameters())
         for p in cnn.parameters():
             if isinstance(p, torch.nn.parameter.UninitializedParameter):
                 break
@@ -1034,14 +1073,19 @@ def test_multiagent_cnn_lazy(self):
         td = TensorDict(
             {
                 "agents": TensorDict(
-                    {"observation": torch.randn(10, 5, 3, 50, 50)},
-                    [10, 5],
+                    {"observation": torch.randn(4, n_agents, n_channels, 15, 15)},
+                    [4, 5],
                 )
             },
-            batch_size=[10],
+            batch_size=[4],
         )
         obs = td[("agents", "observation")]
         out = cnn(obs)
+        assert (
+            not cnn.params[0]
+            .apply(lambda x, y: torch.isclose(x, y), cnn.params[1])
+            .any()
+        )
         out.mean().backward()
         optim.step()
         for p in cnn.parameters():
@@ -1052,17 +1096,36 @@ def test_multiagent_cnn_lazy(self):
                 raise AssertionError("UninitializedParameter found")
 
     @pytest.mark.parametrize("n_agents", [1, 3])
-    @pytest.mark.parametrize(
-        "batch",
-        [
-            (10,),
-            (
-                10,
-                3,
-            ),
-            (),
-        ],
-    )
+    @pytest.mark.parametrize("share_params", [True, False])
+    @pytest.mark.parametrize("centralised", [True, False])
+    def test_multiagent_reset_cnn(
+        self,
+        n_agents,
+        centralised,
+        share_params,
+    ):
+        actor_net = MultiAgentConvNet(
+            in_features=4,
+            num_cells=[5, 5],
+            n_agents=n_agents,
+            centralised=centralised,
+            share_params=share_params,
+        )
+        params_before = actor_net.params.clone()
+        actor_net.reset_parameters()
+        params_after = actor_net.params
+        assert not params_before.apply(
+            lambda x, y: torch.isclose(x, y), params_after, batch_size=[]
+        ).any()
+        if params_after.numel() > 1:
+            assert (
+                not params_after[0]
+                .apply(lambda x, y: torch.isclose(x, y), params_after[1], batch_size=[])
+                .any()
+            )
+
+    @pytest.mark.parametrize("n_agents", [1, 3])
+    @pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
     def test_vdn(self, n_agents, batch):
         torch.manual_seed(0)
         mixer = VDNMixer(n_agents=n_agents, device="cpu")
@@ -1075,17 +1138,7 @@ def test_vdn(self, n_agents, batch):
         assert torch.equal(obs.sum(-2), out)
 
     @pytest.mark.parametrize("n_agents", [1, 3])
-    @pytest.mark.parametrize(
-        "batch",
-        [
-            (10,),
-            (
-                10,
-                3,
-            ),
-            (),
-        ],
-    )
+    @pytest.mark.parametrize("batch", [(10,), (10, 3), ()])
     @pytest.mark.parametrize("state_shape", [(64, 64, 3), (10,)])
     def test_qmix(self, n_agents, batch, state_shape):
         torch.manual_seed(0)
@@ -1271,7 +1324,6 @@ def test_onlinedtactor(self, batch_dims, T=5):
 @pytest.mark.parametrize("device", get_default_devices())
 @pytest.mark.parametrize("bias", [True, False])
 def test_python_lstm_cell(device, bias):
-
     lstm_cell1 = LSTMCell(10, 20, device=device, bias=bias)
     lstm_cell2 = nn.LSTMCell(10, 20, device=device, bias=bias)
 
@@ -1307,7 +1359,6 @@ def test_python_lstm_cell(device, bias):
 @pytest.mark.parametrize("device", get_default_devices())
 @pytest.mark.parametrize("bias", [True, False])
 def test_python_gru_cell(device, bias):
-
     gru_cell1 = GRUCell(10, 20, device=device, bias=bias)
     gru_cell2 = nn.GRUCell(10, 20, device=device, bias=bias)
 
diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py
index 6229aa30fe3..46d4c0de20b 100644
--- a/torchrl/modules/models/multiagent.py
+++ b/torchrl/modules/models/multiagent.py
@@ -16,6 +16,7 @@
 
 from torchrl.data.utils import DEVICE_TYPING
 from torchrl.modules.models import ConvNet, MLP
+from torchrl.modules.models.utils import _reset_parameters_recursive
 
 
 class MultiAgentNetBase(nn.Module):
@@ -30,6 +31,7 @@ def __init__(
         centralised: bool,
         share_params: bool,
         agent_dim: int,
+        vmap_randomness: str = "different",
         **kwargs,
     ):
         super().__init__()
@@ -38,6 +40,7 @@ def __init__(
         self.share_params = share_params
         self.centralised = centralised
         self.agent_dim = agent_dim
+        self._vmap_randomness = vmap_randomness
 
         agent_networks = [
             self._build_single_net(**kwargs)
@@ -54,9 +57,13 @@ def __init__(
         self.__dict__["_empty_net"] = self._build_single_net(**kwargs)
 
     @property
-    def _vmap_randomness(self):
+    def vmap_randomness(self):
         if self.initialized:
-            return "error"
+            return self._vmap_randomness
+        # The class _BatchedUninitializedParameter and buffer are not batched
+        # by vmap so using "different" will raise an exception because vmap can't find
+        # the batch dimension. This is ok though since we won't have the same config
+        # for every element (as one might expect from "same").
         return "same"
 
     def _make_params(self, agent_networks):
@@ -92,14 +99,14 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
         if not self.share_params:
             if self.centralised:
                 output = self.vmap_func_module(
-                    self._empty_net, (0, None), (-2,), randomness=self._vmap_randomness
+                    self._empty_net, (0, None), (-2,), randomness=self.vmap_randomness
                 )(self.params, inputs)
             else:
                 output = self.vmap_func_module(
                     self._empty_net,
                     (0, self.agent_dim),
                     (-2,),
-                    randomness=self._vmap_randomness,
+                    randomness=self.vmap_randomness,
                 )(self.params, inputs)
 
         # If parameters are shared, agents use the same network
@@ -125,6 +132,23 @@ def forward(self, *inputs: Tuple[torch.Tensor]) -> torch.Tensor:
 
         return output
 
+    def reset_parameters(self):
+        """Resets the parameters of the model."""
+
+        def vmap_reset_module(module, *args, **kwargs):
+            def reset_module(params):
+                with params.to_module(module):
+                    _reset_parameters_recursive(module)
+                return params
+
+            return torch.vmap(reset_module, *args, **kwargs)
+
+        if not self.share_params:
+            vmap_reset_module(self._empty_net, randomness="different")(self.params)
+        else:
+            with self.params.to_module(self._empty_net):
+                _reset_parameters_recursive(self._empty_net)
+
 
 class MultiAgentMLP(MultiAgentNetBase):
     """Mult-agent MLP.
@@ -262,7 +286,6 @@ def __init__(
         activation_class: Optional[Type[nn.Module]] = nn.Tanh,
         **kwargs,
     ):
-
         self.n_agents = n_agents
         self.n_agent_inputs = n_agent_inputs
         self.n_agent_outputs = n_agent_outputs
@@ -477,6 +500,7 @@ def __init__(
             share_params=share_params,
             device=device,
             agent_dim=-4,
+            **kwargs,
         )
 
     def _build_single_net(self, *, device, **kwargs):
diff --git a/torchrl/modules/models/utils.py b/torchrl/modules/models/utils.py
index b4fa7eb58fd..392a4ec4376 100644
--- a/torchrl/modules/models/utils.py
+++ b/torchrl/modules/models/utils.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import inspect
+import warnings
 from typing import Optional, Sequence, Type
 
 import torch
@@ -123,3 +124,27 @@ def create_on_device(
     else:
         return module_class(*args, **kwargs).to(device)
         # .to() is always available for nn.Module, and does nothing if the Module contains no parameters or buffers
+
+
+def _reset_parameters_recursive(module, warn_if_no_op: bool = True) -> bool:
+    """Recursively resets the parameters of a :class:`~torch.nn.Module` in-place.
+
+    Args:
+        module (torch.nn.Module): the module to reset.
+        warn_if_no_op (bool, optional): whether to raise a warning in case this is a no-op.
+            Defaults to ``True``.
+
+    Returns: whether any parameter has been reset.
+
+    """
+    any_reset = False
+    for layer in module.children():
+        if hasattr(layer, "reset_parameters"):
+            layer.reset_parameters()
+            any_reset |= True
+        any_reset |= _reset_parameters_recursive(layer, warn_if_no_op=False)
+    if warn_if_no_op and not any_reset:
+        warnings.warn(
+            "_reset_parameters_recursive was called without the parameters argument and did not find any parameters to reset"
+        )
+    return any_reset
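Illustrative usage (not part of the patch itself): a minimal sketch of the reset_parameters() API that MultiAgentNetBase gains above, mirroring the assertions of test_multiagent_reset_mlp; the constructor arguments are copied from that test and MultiAgentMLP is assumed to be importable from torchrl.modules.

    import torch
    from torchrl.modules import MultiAgentMLP

    # One small MLP per agent; share_params=False keeps a separate parameter set per agent.
    mlp = MultiAgentMLP(
        n_agent_inputs=4,
        n_agent_outputs=6,
        num_cells=(4, 4),
        n_agents=3,
        centralised=False,
        share_params=False,
    )
    params_before = mlp.params.clone()
    mlp.reset_parameters()  # re-draws every agent's weights in-place
    # No entry should still match the pre-reset parameters ...
    assert not params_before.apply(
        lambda x, y: torch.isclose(x, y), mlp.params, batch_size=[]
    ).any()
    # ... and, since parameters are not shared, agents 0 and 1 should hold distinct weights.
    assert not mlp.params[0].apply(
        lambda x, y: torch.isclose(x, y), mlp.params[1], batch_size=[]
    ).any()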