[algo] feat: add GRPO-Guard support for Qwen-Image training#48
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements the GRPO-Guard algorithm, an extension of Flow-GRPO designed to stabilize importance-ratio estimates in policy loss. The changes include the core implementation of the grpo_guard loss function, updates to the diffusion scheduler and training adapters to support the required sqrt_dt and proposal mean drift terms, and the addition of comprehensive documentation, example scripts, and unit tests. I have no feedback to provide as there were no review comments to evaluate.
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: Cheung Ka Wai <zhtmike@gmail.com>
Signed-off-by: Cheung Ka Wai <zhtmike@gmail.com>
|
@gemini review |
|
why critic reward mean and val reward have a quick drop at ~125 step? |
There was a problem hiding this comment.
Code Review
This pull request implements the GRPO-Guard algorithm, an extension of Flow-GRPO designed to stabilize importance-ratio estimates in diffusion-model RL. The implementation includes the core loss function in diffusion_algos.py, updates to the training pipeline and schedulers to handle additional parameters like sqrt_dt and old_prev_sample_mean, and the addition of new metrics, documentation, and a Qwen-Image OCR training example. Review feedback suggests refactoring the diffusion_loss utility to reduce tight coupling with specific algorithms by passing arguments more generically.
| if loss_mode == "grpo_guard": | ||
| # GRPO-Guard requires the rollout-time SDE proposal mean and the per-step | ||
| # diffusion coefficient terms; pass them through alongside the standard inputs. | ||
| policy_loss_kwargs.update( | ||
| old_prev_sample_mean=data["old_prev_sample_mean"], | ||
| prev_sample_mean=model_output["prev_sample_mean"], | ||
| std_dev_t=model_output["std_dev_t"], | ||
| sqrt_dt=model_output["sqrt_dt"], | ||
| ) |
There was a problem hiding this comment.
The explicit check for grpo_guard to pass extra arguments makes this utility function tightly coupled with specific algorithm implementations. It would be more maintainable to pass all available keys from data and model_output as keyword arguments to the registered loss function, allowing the registry to handle the signature matching.
There was a problem hiding this comment.
Agreed, it has been noted. We will refactor it once the algorithm becomes more complex.
I think it is because all the rewards at the end of the training are almost We can use: 1. a large PPO batch size (more GPUs) to provide a effective reward signal; 2. a harder reward (not so easy to be saturated); 3. adding KL will help with these. |
| advantages=advantages, | ||
| config=config, | ||
| ) | ||
| if loss_mode == "grpo_guard": |
There was a problem hiding this comment.
is loss_mode the same as actor_rollout_ref.model.algorithm?
There was a problem hiding this comment.
actor_rollout_ref.model.algorithm? -> extract the registered components (trainer side and rollout side).
loss_mode -> pick the right loss.
Here, we use algorithm="flowgrpo" to extract the components; loss_mode="grpo_guard" to select grpo guard loss
| data.val_files=$ocr_test_path \ | ||
| data.train_batch_size=32 \ | ||
| data.max_prompt_length=256 \ | ||
| actor_rollout_ref.model.path=$model_name \ |
There was a problem hiding this comment.
should we set actor_rollout_ref.model.algorithm to grpo_guard?
There was a problem hiding this comment.
grpo-guard is improved based on flowgrpo. The only difference is the loss. Here we just reuse the components from flowgrpo
noted. Thanks. |
Conflicts: - examples/flowgrpo_trainer/README.md: kept upstream's new Ulysses-SP and full-weight Qwen-Image variant blurbs together with our BAGEL recipe section. Additional fix: - verl_omni/pipelines/bagel_flow_grpo/diffusers_training_adapter.py: ``forward_and_sample_previous_step`` now returns the new 4-tuple ``(log_prob, prev_sample_mean, std_dev_t, sqrt_dt)`` to match the GRPO-Guard plumbing introduced upstream in verl-project#48 (BAGEL still trains with ``loss_mode=flow_grpo`` so ``sqrt_dt`` is unused, but the engine layer now unpacks 4-tuples unconditionally). Co-authored-by: GitHub Copilot Signed-off-by: princepride <wangzhipeng628@gmail.com>
What does this PR do?
Checklist Before Starting
[{modules}] {type}: {description}(This will be checked by the CI){modules}includefsdp,vllm_omni,rollout,trainer,ci,training_utils,recipe,ray,worker,single_controller,misc,perf,model,algo,env,tool,ckpt,doc,data,cfg,reward,diffusion,omni,tests,docker,like[diffusion, doc]{type}is infeat,fix,refactor,chore,test[BREAKING]to the beginning of the title.[BREAKING][diffusion, fsdp] feat: new rollout schedulerTest
The

pg_clipfrac_lowerandpg_clipfrac_highernow symmetric as expectedcritic

validation score

API and Usage Example
# Add code snippet or script demonstrating how to use thisDesign & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always