[Feature] Reset parameters of multiagent networks #1967

Closed
matteobettini opened this issue Feb 26, 2024 · 4 comments · Fixed by #1970

@matteobettini
Contributor

Hello!

I have been using a function like this to reset the parameters of multi-agent networks:

def reset_child_params(module):
    for layer in module.children():
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
        reset_child_params(layer)

After #1921 this seems to have no effect.
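
For comparison, here is a minimal sketch of the same helper applied to a plain nn.Sequential, where each nn.Linear child exposes reset_parameters(). The layer sizes and the snapshot helper are made up for illustration. Values are cloned before the reset because parameters() returns live references to the tensors being reset.

```python
import torch
from torch import nn


def reset_child_params(module: nn.Module) -> None:
    """Recursively call reset_parameters() on every descendant that defines it."""
    for layer in module.children():
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
        reset_child_params(layer)


def snapshot(module: nn.Module) -> list:
    """Copy the current parameter values (hypothetical helper for this example)."""
    return [p.detach().clone() for p in module.parameters()]


# Illustrative sizes only: two Linear layers give four parameter tensors.
net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 6))
before = snapshot(net)
reset_child_params(net)
after = snapshot(net)

# One flag per parameter tensor: True if its values changed.
changed = [not torch.equal(b, a) for b, a in zip(before, after)]
```

On a network like this every entry of changed should be True, which is the behaviour that seems to be lost for the multi-agent networks.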

Is there a suggested way to reset the parameters?

Thanks!

@matteobettini matteobettini added the enhancement New feature or request label Feb 26, 2024
@matteobettini
Contributor Author

matteobettini commented Feb 26, 2024

Even resetting through the dedicated tensordict function does not work:

from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.modules.models.multiagent import MultiAgentMLP

if __name__ == "__main__":
    actor_net = MultiAgentMLP(
        n_agent_inputs=4,
        n_agent_outputs=6,
        n_agents=2,
        centralised=False,
        share_params=False,
        device="cpu",
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    )

    policy_module = TensorDictModule(
        actor_net,
        in_keys=[("agents", "observation")],
        out_keys=[("agents", "action")],
    )
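    # Clone the values: parameters() yields live references, which an
    # in-place reset would modify along with the module itself.
    params_before = [p.detach().clone() for p in policy_module.parameters()]
    policy_module.reset_parameters_recursive()
    params_after = list(policy_module.parameters())
    for p1, p2 in zip(params_before, params_after):
        # After a successful reset each tensor should differ somewhere.
        assert (p1 != p2).any()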

@vmoens
Contributor

vmoens commented Feb 26, 2024

That sounds like something we should support! I have limited bandwidth, and this doesn't seem very complex, so feel free to submit a PR if you need it (semi-)urgently.

@matteobettini
Contributor Author

Do we have any insight into why policy_module.reset_parameters_recursive() runs without error but does not apply the reset?

@matteobettini matteobettini changed the title [Feature Request] Reset parameters of multiagent networks [BUG] Reset parameters of multiagent networks Feb 26, 2024
@matteobettini
Contributor Author

matteobettini commented Feb 26, 2024

From what I investigated, when reset_parameters_recursive is called it looks for nn.Module instances that define reset_parameters(). It does not find any when the parameters are held in TensorDictParams, so nothing is reset.

I made a PR to warn when the reset is a no-op.
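
A rough sketch of what such a warning could look like, using plain PyTorch only. StackedParams is a hypothetical stand-in for a module whose weights live directly on it as one stacked nn.Parameter with no child layers. It is not the tensordict or torchrl implementation, and count_resets and reset_with_warning are illustrative names.

```python
import warnings

import torch
from torch import nn


class StackedParams(nn.Module):
    """Hypothetical stand-in: per-agent weights kept as a single stacked parameter."""

    def __init__(self, n_agents: int = 2, in_dim: int = 4, out_dim: int = 6):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n_agents, out_dim, in_dim))


def count_resets(module: nn.Module) -> int:
    """Reset every descendant defining reset_parameters(); return how many were reset."""
    n = 0
    for layer in module.children():
        if hasattr(layer, "reset_parameters"):
            layer.reset_parameters()
            n += 1
        n += count_resets(layer)
    return n


def reset_with_warning(module: nn.Module) -> int:
    """Run the recursive reset and warn if it did not touch anything."""
    n = count_resets(module)
    if n == 0:
        warnings.warn("reset was a no-op: no submodule defines reset_parameters()")
    return n


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    n_plain = reset_with_warning(nn.Sequential(nn.Linear(4, 6)))  # finds the Linear
    n_stacked = reset_with_warning(StackedParams())  # finds nothing, so it warns

noop_warnings = [w for w in caught if "no-op" in str(w.message)]
```

Counting the resets keeps the silent failure visible: the walk itself never raises, so a zero count is the only signal that nothing happened.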

I will make another PR to implement reset_parameters() for the multiagent nets.
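
One possible shape for this, sketched with plain PyTorch: a hypothetical StackedLinear that keeps one weight and bias slice per agent and defines reset_parameters() to re-initialise each slice with the same scheme nn.Linear uses. This is an assumption about the general approach, not the code from the follow-up PR.

```python
import math

import torch
from torch import nn


class StackedLinear(nn.Module):
    """Hypothetical per-agent linear layer with stacked (non-shared) parameters."""

    def __init__(self, n_agents: int, in_dim: int, out_dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n_agents, out_dim, in_dim))
        self.bias = nn.Parameter(torch.empty(n_agents, out_dim))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Re-initialise each agent's slice in place, mirroring nn.Linear's scheme.
        with torch.no_grad():
            for w, b in zip(self.weight, self.bias):
                nn.init.kaiming_uniform_(w, a=math.sqrt(5))
                bound = 1 / math.sqrt(w.shape[1])  # w.shape[1] is the fan-in
                b.uniform_(-bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_agents, in_dim) -> (batch, n_agents, out_dim)
        return torch.einsum("bni,noi->bno", x, self.weight) + self.bias


layer = StackedLinear(n_agents=2, in_dim=4, out_dim=6)
before = layer.weight.detach().clone()
layer.reset_parameters()
changed = not torch.equal(before, layer.weight)
out = layer(torch.zeros(3, 2, 4))
```

Because the module itself now defines reset_parameters(), a recursive walk over children() would pick it up whenever it is nested inside a wrapper module.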

@matteobettini matteobettini changed the title [BUG] Reset parameters of multiagent networks [Feature] Reset parameters of multiagent networks Feb 26, 2024