-
Notifications
You must be signed in to change notification settings - Fork 24
[algo] feat: add GRPO-Guard support for Qwen-Image training #48
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+460
−10
Merged
Changes from all commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
e5721ba
add GRPO-Guard Algo
zhtmike 16910a6
fix readme
zhtmike 5a1b2bd
clean test
zhtmike 19619c7
Merge branch 'main' into grpo_guard
zhtmike 5987f95
mv scripts
zhtmike f0b0c67
revert chagne
zhtmike 4a5a308
update script
zhtmike cd96340
fix merge
zhtmike 40c493d
clean comment
zhtmike 65e583b
update metric & documents
zhtmike 0134cd0
update metrics
zhtmike 5276914
update
zhtmike 8ff16f2
Potential fix for pull request finding
zhtmike 4a06bb3
Merge branch 'main' into grpo_guard
zhtmike 83e6fd7
Merge branch 'main' into grpo_guard
zhtmike d52176d
Merge branch 'main' into grpo_guard
zhtmike 3d09b7a
Merge branch 'main' into grpo_guard
zhtmike File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| # GRPO-Guard | ||
|
|
||
| Last updated: 05/08/2026. | ||
|
|
||
| GRPO-Guard ([paper](https://arxiv.org/abs/2510.22319)) is an extension of | ||
| [Flow-GRPO](flowgrpo.md) that stabilizes the importance-ratio estimate used in | ||
| the policy loss. The standard Flow-GRPO ratio | ||
| $\rho = \exp(\log p_\theta - \log p_{\text{old}})$ can become numerically | ||
| unbalanced when only a single Monte-Carlo noise sample $z$ is used per | ||
| denoising step, causing high-variance gradients and aggressive clipping. | ||
|
|
||
| GRPO-Guard adds a **ratio-mean bias** correction that explicitly penalises | ||
| drift in the reverse-SDE proposal mean of the current policy relative to the | ||
| rollout policy, and rescales the per-step loss by $1 / (\sqrt{-dt})^2$ so the | ||
| gradient magnitude is consistent across denoising steps. | ||
|
|
||
| ## Algorithm | ||
|
|
||
| For step $t$ with proposal mean $\mu_\theta(x_t)$ from the current policy and | ||
| $\mu_{\text{old}}(x_t)$ from the rollout policy, SDE noise scale | ||
| $\sigma_t = \mathrm{std\\_dev\\_t}$, and $\sqrt{-dt}$: | ||
|
|
||
| $$ | ||
| b_t = \frac{\lVert \mu_\theta - \mu_{\text{old}} \rVert_{\text{mean}}^2} | ||
| {2 (\sqrt{-dt}\, \sigma_t)^2} | ||
| $$ | ||
|
|
||
| $$ | ||
| \rho_t = \exp\big((\log p_\theta - \log p_{\text{old}} + b_t) \cdot | ||
| (\sqrt{-dt}\, \sigma_t)\big) | ||
| $$ | ||
|
|
||
| $$ | ||
| \mathcal{L}^{\text{guard}}_t = | ||
| \frac{1}{(\sqrt{-dt})^2}\; | ||
| \mathbb{E}\big[\max(-A_t \rho_t,\ -A_t \mathrm{clip}(\rho_t, 1-\epsilon, 1+\epsilon))\big] | ||
| $$ | ||
|
|
||
| The squared-norm in $b_t$ is averaged over the channel and spatial dimensions | ||
| of the latent (see `compute_diffusion_loss_grpo_guard` in | ||
| [verl_omni/trainer/diffusion/diffusion_algos.py](../../verl_omni/trainer/diffusion/diffusion_algos.py)). | ||
|
|
||
| ## Configuration | ||
|
|
||
| GRPO-Guard reuses the entire Flow-GRPO training stack — only the actor loss | ||
| mode changes. Refer to [Flow-GRPO](flowgrpo.md) for advantage estimator, | ||
| rollout, sampling, batch-size, and reward configuration. | ||
|
|
||
| To enable GRPO-Guard: | ||
|
|
||
| - `actor_rollout_ref.actor.diffusion_loss.loss_mode=grpo_guard` | ||
| - `actor_rollout_ref.rollout.algo.sde_type=sde` | ||
|
|
||
| A typical small clip ratio works well with the additional bias term: | ||
|
|
||
| - `actor_rollout_ref.actor.diffusion_loss.clip_ratio=2e-6` | ||
|
|
||
| KL regularisation against a frozen reference policy still works the same way | ||
| as Flow-GRPO (`actor_rollout_ref.actor.use_kl_loss=True`, | ||
| `actor_rollout_ref.actor.kl_loss_coef=...`). | ||
|
|
||
| ## Example script | ||
|
|
||
| A 4-card collocated training script is provided: | ||
|
|
||
| ```bash | ||
| bash examples/grpoguard_trainer/run_qwen_image_ocr_lora.sh | ||
| ``` | ||
|
|
||
| It reuses the Flow-GRPO Qwen-Image OCR setup and only flips the actor loss | ||
| mode, the clip ratio, and the experiment name. Dataset and model preparation | ||
| follow the same instructions as the [Flow-GRPO quick-start](../start/flowgrpo_quickstart.md). | ||
|
|
||
| ## References | ||
|
|
||
| - [Flow-GRPO: Online policy gradient RL for flow matching models](https://arxiv.org/abs/2505.05470) | ||
| - [GRPO-Guard: ratio-bias regularisation for diffusion-model RL](https://arxiv.org/abs/2510.22319) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,83 @@ | ||
| # GRPO-Guard Trainer | ||
|
|
||
| This example shows how to post-train `Qwen-Image` with GRPO-Guard on an OCR-style image generation task. GRPO-Guard extends Flow-GRPO with a reverse-SDE proposal-mean drift correction and per-step loss rescaling for improved training stability. | ||
|
|
||
| For algorithm details, see [`docs/algo/grpo_guard.md`](../../docs/algo/grpo_guard.md). For the base Flow-GRPO setup this example builds on, see [`examples/flowgrpo_trainer/README.md`](../flowgrpo_trainer/README.md). | ||
|
|
||
| ## Installation | ||
|
|
||
| Follow the [installation guide](../../docs/start/install.md) to set up the base environment, then install the GRPO-Guard-specific dependency: | ||
|
|
||
| ```bash | ||
| pip install Levenshtein | ||
| ``` | ||
|
|
||
| The provided script is configured for a single node with `4` GPUs. | ||
|
|
||
| ## Prepare the dataset | ||
|
|
||
| Obtain the raw OCR dataset from the original Flow-GRPO repository: | ||
|
|
||
| - https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr | ||
|
|
||
| Place the raw dataset under `$WORKSPACE/data/ocr` (where `WORKSPACE` defaults to `$HOME`), then preprocess it into parquet files: | ||
|
|
||
| ```bash | ||
| python3 examples/flowgrpo_trainer/data_process/qwenimage_ocr.py \ | ||
| --input_dir $WORKSPACE/data/ocr \ | ||
| --output_dir $WORKSPACE/data/ocr | ||
| ``` | ||
|
|
||
| This produces: | ||
|
|
||
| - `$WORKSPACE/data/ocr/train.parquet` | ||
| - `$WORKSPACE/data/ocr/test.parquet` | ||
|
|
||
| ## Prepare the models | ||
|
|
||
| **Policy model (Qwen-Image):** the script uses the Hugging Face Hub ID `Qwen/Qwen-Image` directly — no manual download is required. Hugging Face will cache the weights automatically on first run. To use a local copy instead, edit the `model_name` variable in the script directly. | ||
|
|
||
| **Reward model (Qwen3-VL-8B-Instruct):** the script defaults to the Hugging Face Hub ID `Qwen/Qwen3-VL-8B-Instruct`, so no manual download is required — Hugging Face will cache it automatically on first run. To use a local copy instead, edit the `reward_model_name` variable in the script directly. | ||
|
|
||
| ## Run training | ||
|
|
||
| Launch the example from the repository root: | ||
|
|
||
| ```bash | ||
| bash examples/grpoguard_trainer/run_qwen_image_ocr_lora.sh | ||
| ``` | ||
|
|
||
| The script runs `python3 -m verl_omni.trainer.diffusion.main_flowgrpo` with: | ||
|
|
||
| - `algorithm.adv_estimator=flow_grpo` | ||
| - `actor_rollout_ref.model.path=Qwen/Qwen-Image` | ||
| - `actor_rollout_ref.model.lora_rank=64` | ||
| - `actor_rollout_ref.model.lora_alpha=128` | ||
| - `actor_rollout_ref.rollout.name=vllm_omni` | ||
| - `actor_rollout_ref.actor.diffusion_loss.loss_mode=grpo_guard` | ||
| - `actor_rollout_ref.actor.diffusion_loss.clip_ratio=2e-6` | ||
| - `actor_rollout_ref.rollout.algo.sde_type=sde` | ||
| - `reward.custom_reward_function.name=compute_score_ocr` | ||
| - `trainer.n_gpus_per_node=4` | ||
|
|
||
| ## Logging | ||
|
|
||
| W&B logging is enabled by default in the example script: | ||
|
|
||
| ```bash | ||
| export WANDB_API_KEY=<your_wandb_api_key> | ||
| ``` | ||
|
|
||
| The script sets: | ||
|
|
||
| ```bash | ||
| trainer.logger='["console", "wandb"]' | ||
| trainer.project_name=grpo_guard | ||
| trainer.experiment_name=qwen_image_ocr_lora | ||
| ``` | ||
|
|
||
| Override these values on the command line if you want to log under a different project or run name. | ||
|
|
||
| ### Diffusion-specific metrics | ||
|
|
||
| See the [Metrics Documentation](../../docs/start/metrics.md) for a full description of all diffusion-specific training metrics. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| # Qwen-Image lora RL with GRPO-Guard (https://arxiv.org/abs/2510.22319), vllm_omni rollout | ||
| set -x | ||
|
|
||
| # Set WORKSPACE to any writable directory; defaults to $HOME | ||
| WORKSPACE=${WORKSPACE:-$HOME} | ||
|
|
||
| ocr_train_path=$WORKSPACE/data/ocr/train.parquet | ||
| ocr_test_path=$WORKSPACE/data/ocr/test.parquet | ||
|
|
||
| model_name=Qwen/Qwen-Image | ||
| reward_model_name=Qwen/Qwen3-VL-8B-Instruct | ||
| reward_function_path=verl_omni/utils/reward_score/genrm_ocr.py | ||
|
|
||
| NUM_GPUS_ACTOR_ROLLOUT_REWARD=4 | ||
| ROLLOUT_TP=1 | ||
| REWARD_TP=4 | ||
|
|
||
| ENGINE=vllm_omni | ||
| REWARD_ENGINE=vllm | ||
|
|
||
|
|
||
| python3 -m verl_omni.trainer.diffusion.main_flowgrpo \ | ||
| algorithm.adv_estimator=flow_grpo \ | ||
| data.train_files=$ocr_train_path \ | ||
| data.val_files=$ocr_test_path \ | ||
| data.train_batch_size=32 \ | ||
| data.max_prompt_length=256 \ | ||
| actor_rollout_ref.model.path=$model_name \ | ||
| actor_rollout_ref.model.lora_rank=64 \ | ||
| actor_rollout_ref.model.lora_alpha=128 \ | ||
| actor_rollout_ref.model.target_modules="['to_q','to_k','to_v','to_out.0','add_q_proj','add_k_proj','add_v_proj','to_add_out','img_mlp.net.0.proj','img_mlp.net.2','txt_mlp.net.0.proj','txt_mlp.net.2']" \ | ||
| actor_rollout_ref.actor.optim.lr=3e-4 \ | ||
| actor_rollout_ref.actor.optim.weight_decay=0.0001 \ | ||
| actor_rollout_ref.actor.ppo_mini_batch_size=16 \ | ||
| actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ | ||
| actor_rollout_ref.actor.fsdp_config.param_offload=True \ | ||
| actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ | ||
| actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ | ||
| actor_rollout_ref.actor.diffusion_loss.loss_mode=grpo_guard \ | ||
| actor_rollout_ref.actor.diffusion_loss.clip_ratio=2e-6 \ | ||
| actor_rollout_ref.actor.diffusion_loss.adv_clip_max=5.0 \ | ||
| actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ | ||
| actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \ | ||
| actor_rollout_ref.rollout.name=$ENGINE \ | ||
| actor_rollout_ref.rollout.n=16 \ | ||
| actor_rollout_ref.rollout.agent.num_workers=$((NUM_GPUS_ACTOR_ROLLOUT_REWARD / ROLLOUT_TP)) \ | ||
| actor_rollout_ref.rollout.load_format=safetensors \ | ||
| actor_rollout_ref.rollout.layered_summon=True \ | ||
| actor_rollout_ref.rollout.pipeline.true_cfg_scale=4.0 \ | ||
| actor_rollout_ref.rollout.pipeline.max_sequence_length=256 \ | ||
| actor_rollout_ref.rollout.algo.noise_level=1.2 \ | ||
| actor_rollout_ref.rollout.algo.sde_type="sde" \ | ||
| actor_rollout_ref.rollout.algo.sde_window_size=2 \ | ||
| actor_rollout_ref.rollout.algo.sde_window_range="[0,5]" \ | ||
| actor_rollout_ref.rollout.val_kwargs.pipeline.num_inference_steps=50 \ | ||
| actor_rollout_ref.rollout.val_kwargs.algo.noise_level=0.0 \ | ||
| actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ | ||
| reward.num_workers=$((NUM_GPUS_ACTOR_ROLLOUT_REWARD / REWARD_TP)) \ | ||
| reward.reward_model.enable=True \ | ||
| reward.reward_model.model_path=$reward_model_name \ | ||
| reward.reward_model.rollout.name=$REWARD_ENGINE \ | ||
| reward.reward_model.rollout.tensor_model_parallel_size=$REWARD_TP \ | ||
| reward.custom_reward_function.path=$reward_function_path \ | ||
| reward.custom_reward_function.name=compute_score_ocr \ | ||
| trainer.logger='["console", "wandb"]' \ | ||
| trainer.project_name=grpo_guard \ | ||
| trainer.experiment_name=qwen_image_ocr_lora \ | ||
| trainer.log_val_generations=8 \ | ||
| trainer.val_before_train=False \ | ||
| trainer.n_gpus_per_node=$NUM_GPUS_ACTOR_ROLLOUT_REWARD \ | ||
| trainer.nnodes=1 \ | ||
| trainer.save_freq=30 \ | ||
| trainer.test_freq=30 \ | ||
| trainer.total_epochs=15 \ | ||
| trainer.total_training_steps=300 "$@" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we set actor_rollout_ref.model.algorithm to grpo_guard?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
grpo-guard is improved based on flowgrpo. The only difference is the loss. Here we just reuse the components from flowgrpo