
Conversation

@zhandaz
Contributor

@zhandaz zhandaz commented Aug 18, 2025

What does this PR do?

Fix the temperature-related accuracy issues in both the DTensor path and the mcore path.

As reported in #887 and by some other researchers, GRPO training (with some Qwen 2.5 7B variants, such as AceMath) does not work. The failing runs use a command like the following:

uv run python -u NeMo-Skills/nemo_skills/training/nemo_rl/start_grpo.py \
    ++policy.model_name=models/final_hf_checkpoint_1220 \
    ++cluster.gpus_per_node=8 \
    ++cluster.num_nodes=${NUM_ACTOR_NODES} \
    +data.train_data_path=data/deepseek_r1/rl_data/acemath_rl/shuffled_acemath_rl_920k.jsonl \
    +data.val_data_path=data/deepseek_r1/rl_data/acemath_rl/shuffled_acemath_rl_920k.jsonl \
    ++checkpointing.checkpoint_dir=results/${TASK_NAME}_${EXP_SUFFIX}/checkpoints \
    ++logger.log_dir=results/${TASK_NAME}_${EXP_SUFFIX}/training-logs \
    ++logger.wandb_enabled=True \
    ++logger.wandb.project=${WANDB_PROJECT} \
    ++logger.wandb.name=${WANDB_NAME} \
    ++logger.wandb.id=${WANDB_NAME} \
    ++policy.train_global_batch_size=1024 \
    ++policy.train_micro_batch_size=1 \
    ++policy.max_total_sequence_length=8192 \
    ++policy.logprob_batch_size=4 \
    ++policy.sequence_packing.enabled=True \
    ++policy.sequence_packing.train_mb_tokens=8192 \
    ++policy.sequence_packing.logprob_mb_tokens=8192 \
    ++policy.sequence_packing.algorithm="modified_first_fit_decreasing" \
    ++policy.sequence_packing.sequence_length_round=64 \
    ++policy.generation.vllm_cfg.tensor_parallel_size=4 \
    ++policy.generation.vllm_cfg.gpu_memory_utilization=0.8 \
    ++policy.generation.vllm_cfg.pipeline_parallel_size=1 \
    ++policy.generation.temperature=0.6 \
    ++policy.optimizer.kwargs.lr=1e-06 \
    ++loss_fn.reference_policy_kl_penalty=0.001 \
    ++policy.dtensor_cfg.enabled=True \
    ++policy.dtensor_cfg.tensor_parallel_size=2 \
    ++policy.dtensor_cfg.sequence_parallel=False \
    ++policy.dtensor_cfg.activation_checkpointing=True \
    ++data.prompt.prompt_config=qwen/math-cot \
    ++data.prompt.prompt_template=qwen-instruct \
    ++grpo.num_prompts_per_step=128 \
    ++grpo.num_generations_per_prompt=8

1. Reproduce the errors

As we can see in the images below and in the wandb run, the reward drops sharply after around 50 steps, and the token_mult_prob_error (which measures the mismatch between the token probabilities computed by the training framework and by the inference engine) is large.

[Image: reward and token_mult_prob_error curves showing the regression]

2. Fix the temperature error in the DTensor Policy Worker

We found the likely cause: PR #660 made the temperature scaling conditional on the vLLM engine version (since the vLLM V1 engine does not return the final logprobs with the sampling parameters applied). That change is actually wrong. We should follow the pattern in #316 and always scale the logits by the temperature. Otherwise training may become off-policy: the inference and training behaviors do not match.

By removing the `if` guard in

if not is_vllm_v1_engine_enabled():
    logits.div_(self.cfg["generation"]["temperature"])

so that the logits are always divided by the temperature,

we can see that the reward is back to normal and the token_mult_prob_error is much more reasonable (though it still has one large spike above 2).
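To illustrate the mismatch, here is a minimal, self-contained PyTorch sketch (not the repo's code; shapes and values are made up for illustration): when the sampler draws tokens from temperature-scaled logits but the trainer evaluates unscaled logits, the per-token probability ratio between the two deviates from 1.

```python
import torch

torch.manual_seed(0)
temperature = 0.6
logits = torch.randn(4, 100)  # [num_tokens, vocab_size], stand-in for model output

# Distribution the sampler actually draws from at generation time:
sampling_logprobs = torch.log_softmax(logits / temperature, dim=-1)

# Distribution the trainer computes if the temperature division is skipped:
raw_logprobs = torch.log_softmax(logits, dim=-1)

# Sample tokens the way generation would, then compare the two policies:
tokens = torch.multinomial(sampling_logprobs.exp(), num_samples=1)
ratio = (raw_logprobs.gather(-1, tokens) - sampling_logprobs.gather(-1, tokens)).exp()
print(ratio.squeeze())  # deviates from 1.0 whenever temperature != 1.0
```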

[Image: reward and token_mult_prob_error curves after the DTensor fix]

3. Fix the returned logprobs in vLLM V1

We know that the logprobs returned by vLLM V1 are not correct for post-training, because they are computed from the raw logits (before any sampling parameters are applied). Therefore, the token_mult_prob_error is actually an inaccurate indicator. Convergence should not be affected, since use_importance_sampling_correction and use_on_policy_kl_approximation are False, and the vLLM-returned logprobs are not involved in the actual training (they are only used for metrics calculation).
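For reference, here is a hedged sketch of how a multiplicative token-probability error of this kind can be computed from trainer and generation logprobs (illustrative only, not necessarily the exact NeMo-RL formula; the function name and masking convention are assumptions):

```python
import torch

def mult_prob_error(train_logprobs: torch.Tensor,
                    gen_logprobs: torch.Tensor,
                    mask: torch.Tensor) -> torch.Tensor:
    # exp of the masked mean absolute logprob gap; 1.0 means the trainer and
    # the generation engine assign identical probabilities to every token.
    gap = (train_logprobs - gen_logprobs).abs() * mask
    return torch.exp(gap.sum() / mask.sum().clamp_min(1))
```

With raw (pre-temperature) logprobs coming back from vLLM V1, this metric is inflated even when training itself is on-policy.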

To confirm this, we apply a file patch to vLLM so that it returns the logprobs after temperature scaling. The reward curve remains almost the same as in case 2, and the token_mult_prob_error stays around 1 (as expected).
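Concretely, the patch amounts to the one-line change in vLLM's sampler that is quoted later in this thread (shown here as a sketch; the surrounding class is elided):

```python
# Before: returned logprobs are computed from the raw logits,
# ignoring the sampling parameters.
raw_logprobs = self.compute_logprobs(logits)

# After: apply the temperature before computing the returned logprobs.
raw_logprobs = self.compute_logprobs(
    self.apply_temperature(logits.to(torch.float32), sampling_metadata.temperature)
)
```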

[Image: reward and token_mult_prob_error curves with the vLLM logprob patch]

Related

Fixes for top-p and top-k are not supported yet. Patching vLLM to support top-p and top-k would not be elegant, as it involves patching multiple files and many code blocks. There is more discussion in #773 and #69.

Issues

Fixes #902, #887

@zhandaz zhandaz requested a review from parthchadha August 18, 2025 14:47
@zhandaz zhandaz self-assigned this Aug 18, 2025
@zhandaz zhandaz requested a review from wangshangsam August 18, 2025 14:48
@zhandaz zhandaz changed the title from "fix: fix temperature related issues in the DTensor path" to "fix: fix temperature-related issues in the DTensor path" Aug 18, 2025
@zhandaz zhandaz force-pushed the zhanda/debug-accuracy branch from 51133a8 to 761c82f on August 18, 2025 14:51
parthchadha
parthchadha previously approved these changes Aug 18, 2025
@wangshangsam wangshangsam linked an issue Aug 18, 2025 that may be closed by this pull request
wangshangsam
wangshangsam previously approved these changes Aug 18, 2025
@zhandaz zhandaz dismissed stale reviews from wangshangsam and parthchadha via 68aa949 August 19, 2025 18:26
@zhandaz zhandaz force-pushed the zhanda/debug-accuracy branch from 68aa949 to 300d1e2 on August 19, 2025 18:26
@github-actions

❌ Submodule Fast-Forward Check Failed

Check based on commit: e9523c2 (PR #935 from zhanda/debug-accuracy)

❌ Submodules that need attention:

NeMo: ❌ PR branch is BEHIND main branch
TARGET (main branch): https://github.com/NVIDIA/NeMo/commits/5c42641e344a487c7ca5b253a7483f0af8ef40e6/
CURRENT (PR #935 from zhanda/debug-accuracy): https://github.com/NVIDIA/NeMo/commits/aaefedd1d13f4ccd5cd06a19e06f1df33589a235/

Please ensure all submodule commits are fast-forwards of the main branch before merging.

@zhandaz zhandaz force-pushed the zhanda/debug-accuracy branch from 03b3b43 to 67d3db3 on August 19, 2025 20:20
@github-actions github-actions bot added the documentation (Improvements or additions to documentation) and CI (Relating to CI) labels Aug 19, 2025
@zhandaz zhandaz force-pushed the zhanda/debug-accuracy branch from 67d3db3 to 4932ede on August 19, 2025 20:30
@github-actions github-actions bot removed the documentation (Improvements or additions to documentation) and CI (Relating to CI) labels Aug 19, 2025
parthchadha
parthchadha previously approved these changes Aug 19, 2025
@zhandaz zhandaz requested a review from wangshangsam August 19, 2025 21:40
@zhandaz
Contributor Author

zhandaz commented Aug 20, 2025

I have rerun the experiments after merging main. There is a small spike, but it should be fine.

If there are no further comments, this should be ready to merge.

[Image: reward curve after merging main]

wangshangsam
wangshangsam previously approved these changes Aug 20, 2025
Contributor

@wangshangsam wangshangsam left a comment


LGTM!

I don't have further comments, but I just have a quick question: does it mean that, once vLLM 0.11 is released, we no longer need to patch `raw_logprobs = self.compute_logprobs(logits)` into `raw_logprobs = self.compute_logprobs(self.apply_temperature(logits.to(torch.float32), sampling_metadata.temperature))` at all?

@zhandaz
Contributor Author

zhandaz commented Aug 20, 2025

> I don't have further comments, but I just have a quick question: does it mean that, once vLLM 0.11 is released, we no longer need to patch `raw_logprobs = self.compute_logprobs(logits)` into `raw_logprobs = self.compute_logprobs(self.apply_temperature(logits.to(torch.float32), sampling_metadata.temperature))` at all?

Yes! All of this logic can be greatly simplified once vLLM can return the final logprobs. Hopefully that is available in vLLM 0.11.0.

@terrykong terrykong added this pull request to the merge queue Aug 20, 2025
@github-actions

ℹ️ File Synchronization Check

Check based on commit: 8769667 (PR #935 from zhanda/debug-accuracy)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@zhandaz zhandaz force-pushed the zhanda/debug-accuracy branch from 0dade0c to 45e571b on August 25, 2025 20:45
@zhandaz zhandaz changed the title from "fix: fix temperature-related issues in the DTensor path" to "fix: fix temperature-related issues" Aug 25, 2025
@NVIDIA-NeMo NVIDIA-NeMo deleted a comment from github-actions bot Aug 25, 2025
@NVIDIA-NeMo NVIDIA-NeMo deleted a comment from github-actions bot Aug 25, 2025
@github-actions

ℹ️ File Synchronization Check

Check based on commit: 1dd2d46 (PR #935 from zhanda/debug-accuracy)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@zhandaz zhandaz added and then removed the CI:L1 (Run doctests, unit tests, and functional tests) label Aug 25, 2025

@zhandaz
Contributor Author

zhandaz commented Aug 26, 2025

The reward looks similar for the mcore path after this patch. However, the token_mult_prob_error issue is not mitigated much in this case. I am testing with deepseek-ai/DeepSeek-R1-Distill-Qwen-7B.

[Image: mcore-path reward and token_mult_prob_error curves]

@github-actions

ℹ️ File Consistency Check

Check based on commit: 589d0d8 (PR #935 from zhanda/debug-accuracy)

✅ DTensor Policy Worker Synchronization Check

Both DTensor policy worker files were modified in this PR:

  • nemo_rl/models/policy/dtensor_policy_worker.py
  • nemo_rl/models/policy/dtensor_policy_worker_v2.py

Please ensure that the changes are consistent between both files where applicable.


This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

@terrykong terrykong added this pull request to the merge queue Aug 26, 2025
Merged via the queue into main with commit 2aea5ad Aug 27, 2025
37 of 38 checks passed
@terrykong terrykong deleted the zhanda/debug-accuracy branch August 27, 2025 00:09
soodoshll pushed a commit to soodoshll/RL that referenced this pull request Aug 28, 2025
Signed-off-by: Zhanda <[email protected]>
Signed-off-by: Zhanda Zhu <[email protected]>
Co-authored-by: Shang Wang <[email protected]>
Signed-off-by: Qidong Su <[email protected]>
skirdey-inflection pushed a commit to skirdey-inflection/RL that referenced this pull request Aug 30, 2025
Signed-off-by: Zhanda <[email protected]>
Signed-off-by: Zhanda Zhu <[email protected]>
Co-authored-by: Shang Wang <[email protected]>
Signed-off-by: Stanislav Kirdey <[email protected]>
soodoshll pushed a commit to soodoshll/RL that referenced this pull request Sep 4, 2025
Signed-off-by: Zhanda <[email protected]>
Signed-off-by: Zhanda Zhu <[email protected]>
Co-authored-by: Shang Wang <[email protected]>
Signed-off-by: Qidong Su <[email protected]>
PrinsYin pushed a commit to PrinsYin/RL that referenced this pull request Nov 30, 2025
Signed-off-by: Zhanda <[email protected]>
Signed-off-by: Zhanda Zhu <[email protected]>
Co-authored-by: Shang Wang <[email protected]>