Skip to content

[algo] feat: add GRPO-Guard support for Qwen-Image training#48

Merged
AndyZhou952 merged 17 commits into
verl-project:mainfrom
zhtmike:grpo_guard
May 12, 2026
Merged

[algo] feat: add GRPO-Guard support for Qwen-Image training#48
AndyZhou952 merged 17 commits into
verl-project:mainfrom
zhtmike:grpo_guard

Conversation

@zhtmike
Copy link
Copy Markdown
Collaborator

@zhtmike zhtmike commented Apr 30, 2026

What does this PR do?

  • add GRPO-Guard support for Qwen-Image training

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, vllm_omni, rollout, trainer, ci, training_utils, recipe, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, diffusion, omni, tests, docker
    • If this PR involves multiple modules, separate them with , like [diffusion, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][diffusion, fsdp] feat: new rollout scheduler

Test

The pg_clipfrac_lower and pg_clipfrac_higher now symmetric as expected
螢幕截圖 2026-05-11 上午10 54 04

critic
螢幕截圖 2026-05-11 上午10 56 14

validation score
螢幕截圖 2026-05-11 上午10 56 53

For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation.
  • Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: ...

@zhtmike zhtmike changed the title [algo].feat: add GRPO-Guard support for Qwen-Image training [algo] feat: add GRPO-Guard support for Qwen-Image training Apr 30, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements the GRPO-Guard algorithm, an extension of Flow-GRPO designed to stabilize importance-ratio estimates in policy loss. The changes include the core implementation of the grpo_guard loss function, updates to the diffusion scheduler and training adapters to support the required sqrt_dt and proposal mean drift terms, and the addition of comprehensive documentation, example scripts, and unit tests. I have no feedback to provide as there were no review comments to evaluate.

@SamitHuang SamitHuang mentioned this pull request May 6, 2026
27 tasks
@zhtmike zhtmike marked this pull request as ready for review May 11, 2026 02:57
@zhtmike zhtmike requested a review from SamitHuang as a code owner May 11, 2026 02:57
@zhtmike zhtmike requested a review from AndyZhou952 May 11, 2026 02:59
Signed-off-by: Cheung Ka Wai <zhtmike@gmail.com>
@SamitHuang
Copy link
Copy Markdown
Collaborator

@gemini review

@SamitHuang
Copy link
Copy Markdown
Collaborator

why critic reward mean and val reward have a quick drop at ~125 step?

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements the GRPO-Guard algorithm, an extension of Flow-GRPO designed to stabilize importance-ratio estimates in diffusion-model RL. The implementation includes the core loss function in diffusion_algos.py, updates to the training pipeline and schedulers to handle additional parameters like sqrt_dt and old_prev_sample_mean, and the addition of new metrics, documentation, and a Qwen-Image OCR training example. Review feedback suggests refactoring the diffusion_loss utility to reduce tight coupling with specific algorithms by passing arguments more generically.

Comment on lines +45 to +53
if loss_mode == "grpo_guard":
# GRPO-Guard requires the rollout-time SDE proposal mean and the per-step
# diffusion coefficient terms; pass them through alongside the standard inputs.
policy_loss_kwargs.update(
old_prev_sample_mean=data["old_prev_sample_mean"],
prev_sample_mean=model_output["prev_sample_mean"],
std_dev_t=model_output["std_dev_t"],
sqrt_dt=model_output["sqrt_dt"],
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The explicit check for grpo_guard to pass extra arguments makes this utility function tightly coupled with specific algorithm implementations. It would be more maintainable to pass all available keys from data and model_output as keyword arguments to the registered loss function, allowing the registry to handle the signature matching.

Copy link
Copy Markdown
Collaborator Author

@zhtmike zhtmike May 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, it has been noted. We will refactor it once the algorithm becomes more complex.

@zhtmike
Copy link
Copy Markdown
Collaborator Author

zhtmike commented May 11, 2026

why critic reward mean and val reward have a quick drop at ~125 step?

I think it is because all the rewards at the end of the training are almost 1, causing the reward signal to diminish, GRPO std -> 0, and thus making training unstable.

We can use: 1. a large PPO batch size (more GPUs) to provide a effective reward signal; 2. a harder reward (not so easy to be saturated); 3. adding KL will help with these.

advantages=advantages,
config=config,
)
if loss_mode == "grpo_guard":
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is loss_mode the same as actor_rollout_ref.model.algorithm?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actor_rollout_ref.model.algorithm? -> extract the registered components (trainer side and rollout side).
loss_mode -> pick the right loss.

Here, we use algorithm="flowgrpo" to extract the components; loss_mode="grpo_guard" to select grpo guard loss

data.val_files=$ocr_test_path \
data.train_batch_size=32 \
data.max_prompt_length=256 \
actor_rollout_ref.model.path=$model_name \
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we set actor_rollout_ref.model.algorithm to grpo_guard?

Copy link
Copy Markdown
Collaborator Author

@zhtmike zhtmike May 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grpo-guard is improved based on flowgrpo. The only difference is the loss. Here we just reuse the components from flowgrpo

@SamitHuang SamitHuang added the ready-for-ci read for running CI label May 11, 2026
@github-actions github-actions Bot removed the ready-for-ci read for running CI label May 12, 2026
Copy link
Copy Markdown
Collaborator

@AndyZhou952 AndyZhou952 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might need to update the algo part for clarity later (i.e. highlight the key difference and motivation compared to GRPO, unify notation, etc.).

@zhtmike
Copy link
Copy Markdown
Collaborator Author

zhtmike commented May 12, 2026

Might need to update the algo part for clarity later (i.e. highlight the key difference and motivation compared to GRPO, unify notation, etc.).

noted. Thanks.

@zhtmike zhtmike added the ready-for-ci read for running CI label May 12, 2026
@AndyZhou952 AndyZhou952 merged commit 75227f2 into verl-project:main May 12, 2026
16 of 18 checks passed
princepride added a commit to timzsu/verl-omni that referenced this pull request May 13, 2026
Conflicts:
- examples/flowgrpo_trainer/README.md: kept upstream's new Ulysses-SP and
  full-weight Qwen-Image variant blurbs together with our BAGEL recipe
  section.

Additional fix:
- verl_omni/pipelines/bagel_flow_grpo/diffusers_training_adapter.py:
  ``forward_and_sample_previous_step`` now returns the new 4-tuple
  ``(log_prob, prev_sample_mean, std_dev_t, sqrt_dt)`` to match the
  GRPO-Guard plumbing introduced upstream in verl-project#48 (BAGEL still trains
  with ``loss_mode=flow_grpo`` so ``sqrt_dt`` is unused, but the engine
  layer now unpacks 4-tuples unconditionally).

Co-authored-by: GitHub Copilot
Signed-off-by: princepride <wangzhipeng628@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready-for-ci read for running CI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants