[constant scheduler] fix: model won't be updated on first training step#1463

Merged
eric-haibin-lin merged 3 commits into verl-project:main from 0x404:fix_lr
May 21, 2025
Conversation

@0x404 (Collaborator) commented May 9, 2025

What does this PR do?

While running the FSDP checkpoint test introduced by #1288, I found that after one step of training both comparisons pass: the merged FSDP checkpoint matches the verl-saved HF model, and it also still matches the original HuggingFace model. The latter comparison should fail after a training step, so the FSDP model is not actually being updated.

However, the training log shows a learning rate of 1e-6 at the first step, which seemed inconsistent with the model not updating. Digging in, I found two issues in the existing code:

  1. There's a problem with get_constant_schedule_with_warmup: when num_warmup_steps=0, the first step (current_step=0) gets a multiplier of 0.0 instead of 1.0. This is wrong and inconsistent with the constant-LR definition in Transformers: https://github.com/huggingface/transformers/blob/774dc274ac966f4bccbcd90d55bba23f6cca37ae/src/transformers/optimization.py#L72

  2. The log saves the learning rate after actor_lr_scheduler.step(), which is incorrect since it records the next step's LR, thus hiding the problem with get_constant_schedule_with_warmup.
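Issue 1 can be reproduced with a minimal sketch. The buggy function below is illustrative (not verl's exact code): a common buggy pattern applies the warmup ratio unconditionally, so with num_warmup_steps=0 the first step gets a multiplier of 0.0. The fixed version matches the Transformers reference behavior of holding at 1.0 once warmup is over:

```python
def buggy_constant_warmup(current_step: int, num_warmup_steps: int) -> float:
    # Buggy: always applies the warmup ratio, even past warmup.
    # With num_warmup_steps=0 and current_step=0 this yields 0.0,
    # so the first optimizer step effectively multiplies the LR by zero.
    return min(1.0, current_step / max(1, num_warmup_steps))


def fixed_constant_warmup(current_step: int, num_warmup_steps: int) -> float:
    # Matches transformers' get_constant_schedule_with_warmup:
    # ramp linearly during warmup, then hold the multiplier at 1.0.
    if current_step < num_warmup_steps:
        return current_step / max(1.0, num_warmup_steps)
    return 1.0


print(buggy_constant_warmup(0, 0))  # 0.0 -> zero LR on the first step
print(fixed_constant_warmup(0, 0))  # 1.0 -> full LR from step 0
```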

This PR fixes these issues and decreases the tolerance in the FSDP checkpoint test.
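Issue 2 (the masking effect) can also be sketched with a toy scheduler. The class below is a hypothetical stand-in for a LambdaLR-style scheduler using the buggy warmup lambda from point 1, not verl's actual code; it shows why logging the LR after step() hides the zero LR used on the first step:

```python
class BuggyConstantWarmupScheduler:
    """Illustrative stand-in for a LambdaLR-style scheduler whose
    lambda applies the warmup ratio unconditionally (the bug in point 1)."""

    def __init__(self, base_lr: float, num_warmup_steps: int):
        self.base_lr = base_lr
        self.num_warmup_steps = num_warmup_steps
        self.step_count = 0

    def get_last_lr(self) -> float:
        # Buggy lambda: warmup ratio applied even past warmup.
        return self.base_lr * min(
            1.0, self.step_count / max(1, self.num_warmup_steps)
        )

    def step(self) -> None:
        self.step_count += 1


sched = BuggyConstantWarmupScheduler(base_lr=1e-6, num_warmup_steps=0)

# Correct order: record the LR *before* advancing the scheduler,
# so the log reflects the rate the optimizer actually used.
lr_used = sched.get_last_lr()          # 0.0 -- the first step's real LR
sched.step()
lr_logged_after = sched.get_last_lr()  # 1e-6 -- what logging after step() reports
```

Logging after step() reports 1e-6 and makes the first step look healthy, even though the optimizer actually applied a learning rate of zero.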

Additional Info.

  • Training: FSDP
  • Inference: both

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks.
  • Add [BREAKING] to the PR title if it breaks any API.
  • Update the documentation about your changes in the docs.
  • Add CI test(s) if necessary.

0x404 added 2 commits May 9, 2025 11:03
- Fix LR scheduler step timing to properly record and apply learning rate
- Correct warmup scheduler implementation to maintain constant rate after warmup
- Increase learning rate in test script for better checkpoint validation
@0x404 (Collaborator, Author) commented May 21, 2025

Hi @eric-haibin-lin, could you re-review this? I resolved the conflicts a few days ago.

@eric-haibin-lin eric-haibin-lin merged commit 80af51b into verl-project:main May 21, 2025
34 checks passed
cedricbeta pushed a commit to cedricbeta/verl that referenced this pull request May 21, 2025
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request May 22, 2025
chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025
TimurTaepov pushed a commit to giorgossideris/verl that referenced this pull request Dec 20, 2025
vyomakesh0728 added a commit to vyomakesh0728/verl that referenced this pull request Jan 22, 2026