
[RLlib] PPO algorithm can't be trained from checkpoint #50136

Open
kronion opened this issue Jan 30, 2025 · 0 comments
Labels
bug Something that is supposed to be working; but isn't
rllib RLlib related issues
triage Needs triage (eg: priority, bug/not-bug, and owning component)

Comments


kronion commented Jan 30, 2025

What happened + What you expected to happen

What happened?
I can create a PPO algorithm and train it like so:

# Imports added for context (paths assumed from Ray 2.41; the *_config dicts
# and the "custom_env" registration are defined elsewhere in my project):
from pathlib import Path

from ray.rllib.algorithms import ppo
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
    ActionMaskingTorchRLModule,
)

ppo_config = (
    ppo.PPOConfig()
    .framework("torch")
    .environment("custom_env", env_config=env_config)
    .rl_module(
        model_config=model_config,
        rl_module_spec=RLModuleSpec(
            module_class=ActionMaskingTorchRLModule,
        ),
    )
    .env_runners(**env_runner_config)
    .training(**rl_config)
    .learners(**resources_config)
)

algo = ppo_config.build_algo()

algo.train()
path = Path("./tmp_checkpoints").absolute()
algo.save_to_path(path)

If I restore the checkpoint and try to resume training, I get an error:

from ray.rllib.algorithms.algorithm import Algorithm

algo = Algorithm.from_checkpoint(path)
algo.train()
[traceback shortened for readability]
...
  File "/home/kronion/.cache/pypoetry/virtualenvs/env-xAjDonli-py3.10/lib/python3.10/site-packages/ray/rllib/core/learner/torch/torch_learner.py", line 252, in apply_gradients                                                                                         
    optim.step()                                                                                                                         
  File "/home/kronion/.cache/pypoetry/virtualenvs/env-xAjDonli-py3.10/lib/python3.10/site-packages/torch/optim/optimizer.py", line 487, in wrapper                                                                                                                      
    out = func(*args, **kwargs)                                     
  File "/home/kronion/.cache/pypoetry/virtualenvs/env-xAjDonli-py3.10/lib/python3.10/site-packages/torch/optim/optimizer.py", line 91, in _use_grad                                                                                                                     
    ret = func(self, *args, **kwargs)                                                                                                    
  File "/home/kronion/.cache/pypoetry/virtualenvs/env-xAjDonli-py3.10/lib/python3.10/site-packages/torch/optim/adam.py", line 223, in step                                                                                                                              
    adam(                                                           
  File "/home/kronion/.cache/pypoetry/virtualenvs/env-xAjDonli-py3.10/lib/python3.10/site-packages/torch/optim/optimizer.py", line 154, in maybe_fallback                                                                                                               
    return func(*args, **kwargs)                                                                                                         
  File "/home/kronion/.cache/pypoetry/virtualenvs/env-xAjDonli-py3.10/lib/python3.10/site-packages/torch/optim/adam.py", line 784, in adam                                                                                                                              
    func(                                                                                                                                
  File "/home/kronion/.cache/pypoetry/virtualenvs/env-xAjDonli-py3.10/lib/python3.10/site-packages/torch/optim/adam.py", line 543, in _multi_tensor_adam                                                                                                                
    torch._foreach_addcmul_(                                                                                                             
RuntimeError: Expected scalars to be on CPU, got cuda:0 instead.  

What I expected
I expected to be able to resume training from a checkpoint without issue.

Additional context
It appears that the optimizer's betas are supposed to be plain Python floats, but they get converted to CUDA tensors during the checkpoint restore.
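
For reference, here's a minimal torch-only sketch of what I believe the failure mode is (my own reconstruction, independent of RLlib; requires a CUDA device):

import torch

# Mimic the restore: load a state dict whose param_group hyperparameters were
# converted to 0-dim CUDA tensors, then step. (Hedged sketch, not RLlib code.)
param = torch.nn.Parameter(torch.randn(4, device="cuda"))
opt = torch.optim.Adam([param], lr=1e-5)

state = opt.state_dict()
state["param_groups"][0]["betas"] = (
    torch.tensor(0.9, device="cuda"),
    torch.tensor(0.999, device="cuda"),
)
opt.load_state_dict(state)

param.grad = torch.randn_like(param)
opt.step()  # expected to raise: Expected scalars to be on CPU, got cuda:0 instead.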

Constructing the algo from scratch (this trains just fine):

> algo = ppo_config.build_algo()
> algo.learner_group._learner._named_optimizers
{'default_policy_default_optimizer': Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 1e-05
    maximize: False
    weight_decay: 0
)}

Vs restoring the algo from a checkpoint:

> algo = Algorithm.from_checkpoint(path)
> algo.learner_group._learner._named_optimizers
{'default_policy_default_optimizer': Adam (
Parameter Group 0
    amsgrad: False
    betas: (tensor(0.9000, device='cuda:0'), tensor(0.9990, device='cuda:0'))
    capturable: False
    differentiable: False
    eps: 9.99999993922529e-09
    foreach: None
    fused: None
    lr: 9.999999747378752e-06
    maximize: False
    weight_decay: 0
)}

I traced through the checkpoint restoration code. The betas are unpickled as floats, but they're cast to tensors in rllib/core/learner/torch/torch_learner.py:

@override(Learner)
def _set_optimizer_state(self, state: StateDict) -> None:
    for name, state_dict in state.items():
        # Ignore updating optimizers matching to submodules not present in this
        # Learner's MultiRLModule.
        module_id = state_dict["module_id"]
        if name not in self._named_optimizers and module_id in self.module:
            self.configure_optimizers_for_module(
                module_id=module_id,
                config=self.config.get_config_for_module(module_id=module_id),
            )
        if name in self._named_optimizers:
            self._named_optimizers[name].load_state_dict(
                convert_to_torch_tensor(state_dict["state"], device=self._device)  # <- the cast to tensors happens here
            )

Looking at the relevant piece of state_dict["state"], you can see the values were initially floats:

{..., 'param_groups': [{'lr': 1e-05, 'betas': (0.9, 0.999), 'eps': 1e-08,
 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None,
 'capturable': False, 'differentiable': False, 'fused': None,
 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]}]}
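
For illustration, here's how I believe convert_to_torch_tensor treats that structure (an assumption on my part, consistent with the float32-rounded values in the restored repr above): it recurses into nested containers and turns every float leaf into a float32 tensor on the target device, the param_groups hyperparameters included.

from ray.rllib.utils.torch_utils import convert_to_torch_tensor

# Assumed behavior (not verified line by line): float leaves anywhere in the
# nested structure become float32 tensors on the target device.
restored = convert_to_torch_tensor(
    {"param_groups": [{"lr": 1e-05, "betas": (0.9, 0.999)}]},
    device="cuda:0",
)
# assumed output:
# {'param_groups': [{'lr': tensor(1.0000e-05, device='cuda:0'),
#   'betas': (tensor(0.9000, device='cuda:0'),
#             tensor(0.9990, device='cuda:0'))}]}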

I'm not sure how this issue should be resolved. Perhaps certain optimizer parameters should be cast back to scalars in load_state_dict()?
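
One direction that might work (a hedged sketch of my own, not a tested patch): in TorchLearner._set_optimizer_state, convert only the per-parameter "state" entries to device tensors and pass "param_groups" through untouched, e.g.:

if name in self._named_optimizers:
    optim_state = state_dict["state"]
    # Only the per-parameter state needs to live on the device; the
    # hyperparameters in param_groups should stay plain Python scalars.
    optim_state["state"] = convert_to_torch_tensor(
        optim_state["state"], device=self._device
    )
    self._named_optimizers[name].load_state_dict(optim_state)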

I'm also wondering why this isn't a common issue for other users (though I did find one similar bug report, see #28653).
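
As a stopgap, I may try casting the restored hyperparameters back to Python scalars after from_checkpoint() (again a hedged, untested sketch; it reaches into private attributes and only touches the local Learner):

import torch

algo = Algorithm.from_checkpoint(path)
for optim in algo.learner_group._learner._named_optimizers.values():
    for group in optim.param_groups:
        for key, value in group.items():
            # Cast 0-dim tensors (and tuples of them, like betas) back to scalars.
            if isinstance(value, torch.Tensor):
                group[key] = value.item()
            elif isinstance(value, tuple):
                group[key] = tuple(
                    v.item() if isinstance(v, torch.Tensor) else v for v in value
                )
algo.train()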

Versions / Dependencies

Ray 2.41.0
torch 2.5.1+cu121
Python 3.10.12

Reproduction script

I need more time to create a standalone reproduction script.

Issue Severity

High: It blocks me from completing my task.

kronion added the bug and triage labels Jan 30, 2025
jcotant1 added the rllib label Jan 31, 2025