feat: optimize get logprobs when cp enabled. #528

Merged: terrykong merged 9 commits into main from joyang/cp_opt on Jul 5, 2025
Member

@joyang-nv joyang-nv commented Jun 18, 2025

What does this PR do?

This PR optimizes get logprobs when context parallelism (CP) is enabled for FSDP2. Issue #549

Issues

In the previous PR, the sharded logits (local tensor shape [b, s / cp_size, v / tp_size]) were gathered into a full tensor of shape [b, s, v] and passed to the loss function.
The key reason was that the sequence order had to be correct when computing log probs.
This PR allows the permuted sequence to be passed to the loss function together with an additional full tensor, seq_index, which records the order of the permuted sequence. Log probs can then be computed in parallel, even when TP is also enabled.
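
A minimal sketch of the idea, in pure Python rather than the PyTorch code this PR actually touches. It assumes a load-balanced permutation in which the sequence is cut into 2 * cp_size chunks and rank r holds chunks r and 2 * cp_size - 1 - r; the description above does not say which permutation is used, so treat that layout, the toy sizes, and the helper names (`reference_logprobs`, `vocab_parallel_logprob`) as hypothetical. The Python `max` and `sum` over shards stand in for all-reduce calls across TP ranks, and the usual one-token shift between logits and targets is omitted for brevity.

```python
import math

SEQ_LEN, VOCAB, CP_SIZE, TP_SIZE = 8, 6, 2, 2

# Deterministic toy logits of shape [s, v] and one target id per position.
full_logits = [[math.sin(1.7 * p + 0.9 * v) for v in range(VOCAB)]
               for p in range(SEQ_LEN)]
targets = [(3 * p + 1) % VOCAB for p in range(SEQ_LEN)]


def reference_logprobs(logits, target_ids):
    """Baseline: log-softmax over the full vocab, in original sequence order."""
    out = []
    for row, t in zip(logits, target_ids):
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[t] - lse)
    return out


def vocab_parallel_logprob(shards, target_id):
    """Log prob of target_id from the per-TP-rank vocab slices of one position."""
    shard_size = len(shards[0])
    global_max = max(max(s) for s in shards)        # stands in for all-reduce(max)
    global_sum = sum(sum(math.exp(x - global_max) for x in s)
                     for s in shards)               # stands in for all-reduce(sum)
    owner, offset = divmod(target_id, shard_size)   # TP rank holding the target logit
    return shards[owner][offset] - global_max - math.log(global_sum)


# Assumed permutation: 2 * CP_SIZE chunks, rank r holds chunks r and 2 * CP_SIZE - 1 - r.
chunk = SEQ_LEN // (2 * CP_SIZE)
seq_index = []
for r in range(CP_SIZE):
    for c in (r, 2 * CP_SIZE - 1 - r):
        seq_index.extend(range(c * chunk, (c + 1) * chunk))
# seq_index[i] is the original position of permuted position i.

local_len = SEQ_LEN // CP_SIZE
vocab_shard = VOCAB // TP_SIZE
restored = [None] * SEQ_LEN
for cp_rank in range(CP_SIZE):
    # Slice of seq_index that describes this CP rank's local (permuted) tokens.
    local_index = seq_index[cp_rank * local_len:(cp_rank + 1) * local_len]
    for pos in local_index:
        # Local logits for this position: one [v / tp_size] slice per TP rank.
        shards = [full_logits[pos][t * vocab_shard:(t + 1) * vocab_shard]
                  for t in range(TP_SIZE)]
        # seq_index says which target belongs to this permuted slot and where
        # the resulting scalar goes in the original order.
        restored[pos] = vocab_parallel_logprob(shards, targets[pos])

reference = reference_logprobs(full_logits, targets)
max_err = max(abs(a - b) for a, b in zip(restored, reference))
print(seq_index, max_err)
```

In this sketch no rank ever builds the [s, v] tensor: each one works on its own [s / cp_size, v / tp_size] slice, and only the per-token scalars are put back into the original order.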

Test Result

cp8

[Figures: convergence, time cost]

tp4cp2

[Figures: convergence, time cost]

| Run | TIMING/TRAIN/POLICY_TRAINING |
| --- | --- |
| TP4CP2-0624 | 62.53703141 |
| CP8-0624 | 52.19302438 |
| LLAMA-3.1-8B-INSTRUCT-CP8-0610 | 55.94293963 |
| LLAMA-3.1-8B-INSTRUCT-TP4CP2-0609 | 66.29522389 |

On average, each step is more than 3 seconds faster.
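
The per-step saving can be checked from the TIMING/TRAIN/POLICY_TRAINING values reported above. This small sketch assumes that the 0609 and 0610 runs are the baselines, that the 0624 runs include this PR, and that the values are in seconds; none of that is stated explicitly in the description.

```python
# Values copied from the TIMING/TRAIN/POLICY_TRAINING figures reported above.
timings = {
    "TP4CP2-0624": 62.53703141,
    "CP8-0624": 52.19302438,
    "LLAMA-3.1-8B-INSTRUCT-CP8-0610": 55.94293963,
    "LLAMA-3.1-8B-INSTRUCT-TP4CP2-0609": 66.29522389,
}

# Assumption: the older runs are the baselines for the matching configuration.
saved_cp8 = timings["LLAMA-3.1-8B-INSTRUCT-CP8-0610"] - timings["CP8-0624"]
saved_tp4cp2 = timings["LLAMA-3.1-8B-INSTRUCT-TP4CP2-0609"] - timings["TP4CP2-0624"]

print(f"cp8 saved per step:    {saved_cp8:.2f}")
print(f"tp4cp2 saved per step: {saved_tp4cp2:.2f}")
```

Under those assumptions, both configurations come out between 3 and 4 seconds faster per step.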

@github-actions github-actions bot added the documentation Improvements or additions to documentation label Jun 18, 2025
@github-actions github-actions bot removed the documentation Improvements or additions to documentation label Jun 18, 2025
@joyang-nv joyang-nv force-pushed the joyang/cp_opt branch 3 times, most recently from 34f829d to a068b07 June 25, 2025 07:14
@joyang-nv joyang-nv changed the title Optimize get logprobs when cp enabled. feat: optimize get logprobs when cp enabled. Jun 25, 2025
@joyang-nv joyang-nv added the CI:L1 Run doctests, unit tests, and functional tests label Jun 25, 2025
@joyang-nv joyang-nv requested review from SahilJain314, abukharin-nv and gshennvm and removed request for SahilJain314 June 25, 2025 08:23
@joyang-nv joyang-nv added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jun 25, 2025
@joyang-nv joyang-nv requested a review from terrykong June 25, 2025 16:06
@joyang-nv joyang-nv requested review from terrykong and yuki-97 July 1, 2025 14:42
@joyang-nv joyang-nv added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jul 1, 2025
joyang-nv added 2 commits July 3, 2025 10:43
Signed-off-by: Jonas yang <joyang@nvidia.com>
Signed-off-by: Jonas yang <joyang@nvidia.com>
joyang-nv added 3 commits July 3, 2025 10:43
Signed-off-by: Jonas yang <joyang@nvidia.com>
Signed-off-by: Jonas yang <joyang@nvidia.com>
Signed-off-by: Jonas yang <joyang@nvidia.com>
@joyang-nv joyang-nv requested a review from SahilJain314 July 3, 2025 02:43
@joyang-nv joyang-nv marked this pull request as ready for review July 3, 2025 02:48
@joyang-nv joyang-nv added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jul 3, 2025
Collaborator

@terrykong terrykong left a comment

change lgtm, but @SahilJain314 should also give approval since it affects CP in mcore

@joyang-nv joyang-nv added CI:L1 Run doctests, unit tests, and functional tests and removed CI:L1 Run doctests, unit tests, and functional tests labels Jul 3, 2025
@SahilJain314 SahilJain314 enabled auto-merge July 3, 2025 23:28
@SahilJain314 SahilJain314 added this pull request to the merge queue Jul 3, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jul 4, 2025
@terrykong terrykong added this pull request to the merge queue Jul 5, 2025
Merged via the queue into main with commit b48acf7 Jul 5, 2025
13 of 14 checks passed
@terrykong terrykong deleted the joyang/cp_opt branch July 5, 2025 08:03
therealnaveenkamal pushed a commit to therealnaveenkamal/RL that referenced this pull request Jul 7, 2025
Signed-off-by: Jonas yang <joyang@nvidia.com>
@yuki-97 yuki-97 linked an issue Jul 7, 2025 that may be closed by this pull request
jialei777 pushed a commit to jialei777/nemo-rl that referenced this pull request Jul 23, 2025
Signed-off-by: Jonas yang <joyang@nvidia.com>
Signed-off-by: Jialei Chen <jialeic@google.com>
KiddoZhu pushed a commit that referenced this pull request Jul 28, 2025
Signed-off-by: Jonas yang <joyang@nvidia.com>
FannYYW pushed a commit to xxman-google/NeMo-RL that referenced this pull request Aug 5, 2025
Signed-off-by: Jonas yang <joyang@nvidia.com>
Labels

CI:L1 Run doctests, unit tests, and functional tests

Development

Successfully merging this pull request may close these issues.

Parallel get logprobs when CP/TP mixed case for FSDP2.

3 participants