
[RLlib] TorchPolicy.set_state moves extra parameters in optimizer to torch.Tensor #28653

Open
sychen-alternative opened this issue Sep 20, 2022 · 0 comments
Labels
bug: Something that is supposed to be working, but isn't · P1: Issue that should be fixed within a few weeks · rllib: RLlib related issues

sychen-alternative commented Sep 20, 2022

What happened + What you expected to happen

This is not a contribution.

When handling optimizer state, TorchPolicy.get_state converts all torch.Tensor values to numpy.ndarray, and TorchPolicy.set_state is supposed to convert them back. However, it accidentally also converts the extra parameters, such as lr, betas, step, etc., into torch.Tensor on self.device. The function that actually performs the conversion is ray.rllib.utils.torch_utils.convert_to_torch_tensor.

When self.device == "cuda", this leads to a significant slowdown of gradient updates during training.
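
For illustration, the unintended conversion can be reproduced in isolation (a minimal sketch; the param_group dict below is a made-up stand-in for a real optimizer param group):

from ray.rllib.utils.torch_utils import convert_to_torch_tensor

# A stripped-down Adam-style param group containing only scalar hyperparameters.
param_group = {"lr": 5e-5, "betas": (0.9, 0.999), "eps": 1e-8}

# Recursively converting the whole structure turns every scalar into a tensor,
# which is what set_state ends up doing to the optimizer's extra parameters.
print(convert_to_torch_tensor(param_group))
# roughly: {'lr': tensor(5.0000e-05), 'betas': (tensor(0.9000), tensor(0.9990)), 'eps': tensor(1.0000e-08)}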

Versions / Dependencies

ray 1.12.1. Based on the code on master, this is still an issue.
pytorch

Reproduction script

# Import the RL algorithm (Trainer) we would like to use.
from ray.rllib.agents.ppo import PPOTrainer

# Configure the algorithm.
config = {
    "num_gpus": 1.,
    # Environment (RLlib understands openAI gym registered strings).
    "env": "Taxi-v3",
    # Use 2 environment workers (aka "rollout workers") that collect
    # samples in parallel from their own environment clone(s).
    "num_workers": 2,
    # Use PyTorch; the issue is specific to TorchPolicy.
    "framework": "torch",
    # Tweak the default model provided automatically by RLlib,
    # given the environment's observation- and action spaces.
    "model": {
        "fcnet_hiddens": [64, 64],
        "fcnet_activation": "relu",
    },
    # Set up a separate evaluation worker set for the
    # `trainer.evaluate()` call after training (see below).
    "evaluation_num_workers": 1,
    # Only for evaluation runs, render the env.
    "evaluation_config": {
        "render_env": True,
    },
}

# Create our RLlib Trainer.
trainer = PPOTrainer(config=config)

print(trainer.get_policy()._optimizers[0].state_dict())
state = trainer.get_policy().get_state()
trainer.get_policy().set_state(state)
print(trainer.get_policy()._optimizers[0].state_dict())

# Output:
# {'state': {}, 'param_groups': [{'lr': 5e-05, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]}]}
# {'state': {}, 'param_groups': [{'lr': tensor(5.0000e-05, device='cuda:0'), 'betas': (tensor(0.9000, device='cuda:0'), tensor(0.9990, device='cuda:0')), 'eps': tensor(1.0000e-08, device='cuda:0'), 'weight_decay': tensor(0, device='cuda:0'), 'amsgrad': tensor(False, device='cuda:0'), 'maximize': tensor(False, device='cuda:0'), 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]}]}

Issue Severity

Medium: It is a significant difficulty but I can work around it.
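
One possible workaround (a sketch, not an official RLlib fix; it assumes only the hyperparameters in param_groups need restoring and reuses the trainer and state objects from the script above):

# Snapshot each optimizer's state_dict() before set_state(), then restore the
# plain-Python hyperparameters in param_groups afterwards.
policy = trainer.get_policy()
original_state_dicts = [opt.state_dict() for opt in policy._optimizers]

policy.set_state(state)

# Overwrite the tensorized hyperparameters with the original scalar values;
# the "params" index lists are left untouched.
for opt, original in zip(policy._optimizers, original_state_dicts):
    for group, orig_group in zip(opt.param_groups, original["param_groups"]):
        for key, value in orig_group.items():
            if key != "params":
                group[key] = value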

sychen-alternative added the bug and triage labels on Sep 20, 2022
richardliaw added the rllib label on Oct 7, 2022
kouroshHakha added the P1 label and removed the triage label on Oct 26, 2022