SP GRPO support + batch SP fixes #2643

Merged
djsaunde merged 54 commits into main from sp-rl-v4 on May 12, 2025

@djsaunde djsaunde commented May 7, 2025

Collaborator

Description

This PR implements support for sequence parallelism (SP) for the GRPO trainer. This includes a custom sampler, similar to the one implemented in TRL, but with logic for replicating samples across processes in the same SP group.
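
The replication idea can be sketched in plain Python. This is only an illustration and not the sampler added in this PR: the function name `sp_group_indices`, its parameters, and the assumption that consecutive ranks form one SP group are all hypothetical, and the real sampler follows TRL's batching and repetition logic more closely.

```python
import random
from typing import List


def sp_group_indices(
    dataset_size: int,
    rank: int,
    world_size: int,
    sp_degree: int,
    num_generations: int = 1,
    seed: int = 0,
) -> List[int]:
    """Return the sample indices one rank should process (hypothetical helper)."""
    if world_size % sp_degree != 0:
        raise ValueError("world_size must be divisible by sp_degree")

    num_groups = world_size // sp_degree
    # Assumption: ranks [0..sp_degree-1] form group 0, the next sp_degree form group 1, etc.
    group_id = rank // sp_degree

    # Every rank uses the same seed, so every rank builds the same permutation.
    order = list(range(dataset_size))
    random.Random(seed).shuffle(order)

    # Shard by SP group rather than by rank: ranks in one group share a shard.
    group_order = order[group_id::num_groups]

    # GRPO samples several completions per prompt, so repeat each index.
    return [idx for idx in group_order for _ in range(num_generations)]


if __name__ == "__main__":
    for r in range(4):
        print(r, sp_group_indices(8, rank=r, world_size=4, sp_degree=2, num_generations=2))
```

Ranks in the same SP group receive identical index lists, so each rank can then work on its own slice of the same sequences, while different SP groups see disjoint data.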

To make this happen, we need to override a lot more code from the GRPO trainer (e.g., the __init__ function). This could be greatly improved by refactoring the trainer upstream to be a lot more modular and therefore more easily extensible; most of the code we're taking from the superclass is unchanged.

We also fixed an issue in the pad_to_sequence_len: false case: for specific input lengths, the ring_flash_attn batch ring attention function produced extremely large or inf gradient norms. Calling torch.compile on the function resolves this 🤔

Plus some small changes / refactors.

Motivation and Context

GRPO for sufficiently advanced tasks will likely require longer sequence lengths than we can fit on a single GPU. Hence, let's add SP support.

Follow-ups:

  • Add zigzag, stripe batch ring attn adapters + data splitting, gathering logic
  • Upstream base GRPO trainer refactor (?)
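
For context on the first follow-up, here is a rough sketch of how the three data-splitting schemes differ when a sequence is divided across SP ranks. It assumes the layout conventions commonly used by ring-attention implementations (contiguous blocks, paired front/back chunks for zigzag, round-robin tokens for stripe). The function names are hypothetical and this is not code from this PR.

```python
from typing import List, Sequence


def split_contiguous(tokens: Sequence[int], rank: int, world_size: int) -> List[int]:
    """Batch layout: each rank takes one contiguous block of the sequence."""
    if len(tokens) % world_size != 0:
        raise ValueError("sequence length must be divisible by world_size")
    chunk = len(tokens) // world_size
    return list(tokens[rank * chunk : (rank + 1) * chunk])


def split_zigzag(tokens: Sequence[int], rank: int, world_size: int) -> List[int]:
    """Zigzag layout: cut into 2 * world_size chunks; rank r takes chunk r
    and its mirror chunk from the end, which balances causal-attention work."""
    num_chunks = 2 * world_size
    if len(tokens) % num_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * world_size")
    chunk = len(tokens) // num_chunks
    chunks = [tokens[i * chunk : (i + 1) * chunk] for i in range(num_chunks)]
    return list(chunks[rank]) + list(chunks[num_chunks - 1 - rank])


def split_stripe(tokens: Sequence[int], rank: int, world_size: int) -> List[int]:
    """Stripe layout: tokens are dealt out round-robin, so rank r takes
    positions r, r + world_size, r + 2 * world_size, ..."""
    if len(tokens) % world_size != 0:
        raise ValueError("sequence length must be divisible by world_size")
    return list(tokens[rank::world_size])


if __name__ == "__main__":
    seq = list(range(8))
    for r in range(2):
        print(r, split_contiguous(seq, r, 2), split_zigzag(seq, r, 2), split_stripe(seq, r, 2))
```

Gathering is the inverse of whichever split was used, so each adapter needs a matching reassembly step to restore the original token order.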

How has this been tested?

  • Pytests (need additional coverage)
  • Manual testing of:
    • SP + SFT (pad_to_sequence_len: false vs. pad_to_sequence_len: true, sample_packing: false vs. sample_packing: true, etc.)
    • SP + GRPO

Screenshots (if appropriate)

Example GRPO + SP training run:

[Image: GRPO + SP training run]

Note that the config differed a fair bit from the one in our blog post, so they're not directly comparable.

Types of changes

Social Handles (Optional)

@djsaunde djsaunde self-assigned this May 7, 2025
Comment thread src/axolotl/core/trainers/grpo/trainer.py
Comment thread src/axolotl/train.py Outdated
Comment thread src/axolotl/utils/ctx_managers/sequence_parallel.py
@winglian winglian added this to the Axolotl v0.10.0 milestone May 7, 2025
Comment thread src/axolotl/common/datasets.py
Comment thread src/axolotl/core/trainer_builder.py
Comment thread src/axolotl/core/trainers/grpo/sampler.py
Comment thread src/axolotl/core/trainers/grpo/trainer.py
Comment thread src/axolotl/monkeypatch/attention/ring_attn/adapters/batch.py
Comment thread src/axolotl/train.py Outdated
Comment thread src/axolotl/train.py Outdated
Comment thread src/axolotl/utils/ctx_managers/sequence_parallel.py Outdated
Comment thread src/axolotl/utils/ctx_managers/sequence_parallel.py Outdated
Comment thread src/axolotl/utils/ctx_managers/sequence_parallel.py Outdated
Comment thread src/axolotl/utils/schemas/config.py Outdated

@salmanmohammadi salmanmohammadi left a comment

Contributor

Really really nice. A few nits, and we need to sync the GRPOTrainer changes with upstream, but this looks good to me.

djsaunde commented May 9, 2025

Collaborator Author

Note: GRPO + SP + Liger results in exploding losses. I suggest we merge this now (once tests pass) and follow up with a fix for this case.

@djsaunde djsaunde merged commit 80304c2 into main May 12, 2025
9 of 12 checks passed
@djsaunde djsaunde deleted the sp-rl-v4 branch May 12, 2025 21:52

3 participants