fix: fix temperature-related issues #935
Conversation
❌ Submodule Fast-Forward Check Failed (based on commit e9523c2, PR #935). Submodules that need attention: NeMo: PR branch is BEHIND main branch. Please ensure all submodule commits are fast-forwards of the main branch before merging.
LGTM!

I don't have further comments, but I do have a quick question: does this mean that, once vLLM 0.11 is released, we no longer need to patch `raw_logprobs = self.compute_logprobs(logits)` into `raw_logprobs = self.compute_logprobs(self.apply_temperature(logits.to(torch.float32), sampling_metadata.temperature))` at all?

Yes! All this logic can be greatly simplified once the vLLM dependency is upgraded.
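For reference, here is a minimal sketch of what the patched line computes, assuming per-request temperatures as in the snippet quoted above (illustrative only; the real change sits inside vLLM's V1 sampler):

```python
import torch

def temperature_scaled_logprobs(logits: torch.Tensor,
                                temperature: torch.Tensor) -> torch.Tensor:
    # Equivalent to compute_logprobs(apply_temperature(logits, temperature)):
    # divide each request's logits by its sampling temperature, then take
    # log-softmax over the vocabulary dimension.
    scaled = logits.to(torch.float32) / temperature.unsqueeze(-1)
    return torch.log_softmax(scaled, dim=-1)
```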
ℹ️ File Synchronization Check (based on commit 8769667, PR #935)

✅ DTensor Policy Worker Synchronization Check: both DTensor policy worker files were modified in this PR. Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
What does this PR do?
Fix the temperature-related accuracy issues in both the DTensor path and the Mcore path.

As reported in #887 and by other researchers, GRPO training (on some variants of Qwen 2.5 7B, e.g. AceMath) does not work. The failing run was launched with:
```bash
uv run python -u NeMo-Skills/nemo_skills/training/nemo_rl/start_grpo.py \
    ++policy.model_name=models/final_hf_checkpoint_1220 \
    ++cluster.gpus_per_node=8 \
    ++cluster.num_nodes=${NUM_ACTOR_NODES} \
    +data.train_data_path=data/deepseek_r1/rl_data/acemath_rl/shuffled_acemath_rl_920k.jsonl \
    +data.val_data_path=data/deepseek_r1/rl_data/acemath_rl/shuffled_acemath_rl_920k.jsonl \
    ++checkpointing.checkpoint_dir=results/${TASK_NAME}_${EXP_SUFFIX}/checkpoints \
    ++logger.log_dir=results/${TASK_NAME}_${EXP_SUFFIX}/training-logs \
    ++logger.wandb_enabled=True \
    ++logger.wandb.project=${WANDB_PROJECT} \
    ++logger.wandb.name=${WANDB_NAME} \
    ++logger.wandb.id=${WANDB_NAME} \
    ++policy.train_global_batch_size=1024 \
    ++policy.train_micro_batch_size=1 \
    ++policy.max_total_sequence_length=8192 \
    ++policy.logprob_batch_size=4 \
    ++policy.sequence_packing.enabled=True \
    ++policy.sequence_packing.train_mb_tokens=8192 \
    ++policy.sequence_packing.logprob_mb_tokens=8192 \
    ++policy.sequence_packing.algorithm="modified_first_fit_decreasing" \
    ++policy.sequence_packing.sequence_length_round=64 \
    ++policy.generation.vllm_cfg.tensor_parallel_size=4 \
    ++policy.generation.vllm_cfg.gpu_memory_utilization=0.8 \
    ++policy.generation.vllm_cfg.pipeline_parallel_size=1 \
    ++policy.generation.temperature=0.6 \
    ++policy.optimizer.kwargs.lr=1e-06 \
    ++loss_fn.reference_policy_kl_penalty=0.001 \
    ++policy.dtensor_cfg.enabled=True \
    ++policy.dtensor_cfg.tensor_parallel_size=2 \
    ++policy.dtensor_cfg.sequence_parallel=False \
    ++policy.dtensor_cfg.activation_checkpointing=True \
    ++data.prompt.prompt_config=qwen/math-cot \
    ++data.prompt.prompt_template=qwen-instruct \
    ++grpo.num_prompts_per_step=128 \
    ++grpo.num_generations_per_prompt=8
```

1. Reproduce the errors
As we can see in the images below and the wandb run, the rewards decrease a lot after around 50 steps, and the token_mult_prob_error (a measure of the mismatch between training-side and inference-side token probabilities) is large.

2. Fix the temperature error in the DTensor Policy Worker
We found the likely cause: PR #660 scales the logits by the temperature conditionally on the vLLM engine version (since the vLLM V1 engine does not return the final logprobs with respect to the sampling parameters). That fix is actually wrong. We should follow the pattern in #316 and scale the logits by the temperature regardless of the engine version; otherwise training may become off-policy, because the inference and training behaviors do not match.
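As a rough sketch of the intended pattern (a hypothetical helper, not the actual worker code), the temperature division should happen unconditionally before the logprob computation:

```python
import torch

def training_logprobs(logits: torch.Tensor,
                      token_ids: torch.Tensor,
                      temperature: float) -> torch.Tensor:
    # Scale by the sampling temperature unconditionally (no vLLM engine
    # version check), so training-side logprobs match the distribution
    # the inference engine actually sampled from.
    scaled = logits.to(torch.float32) / temperature
    logprobs = torch.log_softmax(scaled, dim=-1)
    # Gather the logprob of each generated token id.
    return logprobs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
```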
By removing the `if` guard in nemo_rl/models/policy/dtensor_policy_worker.py (lines 471 to 472 at commit df31c1b), we can see that the reward is back to normal and the token_mult_prob_error is much more reasonable (though it still has one large spike above 2).

3. Fix the returned logprobs in vLLM V1
We know that the logprobs returned by vLLM V1 are not correct for post-training, because they are computed from the raw logits (before applying any sampling parameters). Therefore, the token_mult_prob_error is actually an inaccurate indicator. Convergence should not be affected, since use_importance_sampling_correction and use_on_policy_kl_approximation are False and the vLLM-returned logprobs are not involved in the actual training (they are only used for metrics calculation).

To confirm this, we apply a file patch to vLLM so that it returns the logprobs after applying the temperature. The reward curve remains almost the same as in case 2, and the token_mult_prob_error stays at almost 1 (which is expected).

Related
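To see why the metric settles near 1 once both sides apply the temperature, here is a toy computation (the exponentiated mean absolute logprob difference below is a hypothetical stand-in for the real metric definition):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 1000)           # toy per-token logits
tokens = torch.randint(0, 1000, (4, 1))  # sampled token ids
temp = 0.6

# Training-side logprobs: temperature applied before log-softmax.
train_lp = torch.log_softmax(logits / temp, dim=-1).gather(-1, tokens)

# Unpatched vLLM V1: logprobs computed from the raw logits.
raw_lp = torch.log_softmax(logits, dim=-1).gather(-1, tokens)

def mult_prob_error(a: torch.Tensor, b: torch.Tensor) -> float:
    # Hypothetical stand-in: exponentiated mean |logprob difference|.
    return torch.exp((a - b).abs().mean()).item()

print(mult_prob_error(train_lp, raw_lp))    # noticeably above 1
print(mult_prob_error(train_lp, train_lp))  # exactly 1.0 when they match
```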
Fixing top-p and top-k is not supported yet. Patching vLLM to support them would not be very elegant, as it involves patching multiple files and many code blocks. There is more discussion in #773 and #69.
Issues
Fixes #902, #887