77 changes: 77 additions & 0 deletions docs/algo/grpo_guard.md
@@ -0,0 +1,77 @@
# GRPO-Guard

Last updated: 05/08/2026.

GRPO-Guard ([paper](https://arxiv.org/abs/2510.22319)) is an extension of
[Flow-GRPO](flowgrpo.md) that stabilizes the importance-ratio estimate used in
the policy loss. The standard Flow-GRPO ratio
$\rho = \exp(\log p_\theta - \log p_{\text{old}})$ can become numerically
unbalanced when only a single Monte-Carlo noise sample $z$ is used per
denoising step, causing high-variance gradients and aggressive clipping.

GRPO-Guard adds a **ratio-mean bias** correction that explicitly penalises
drift in the reverse-SDE proposal mean of the current policy relative to the
rollout policy, and rescales the per-step loss by $1 / (\sqrt{-dt})^2$ so the
gradient magnitude is consistent across denoising steps.

## Algorithm

For step $t$, let $\mu_\theta(x_t)$ be the proposal mean from the current
policy, $\mu_{\text{old}}(x_t)$ the proposal mean from the rollout policy,
$\sigma_t = \mathrm{std\\_dev\\_t}$ the SDE noise scale, and $\sqrt{-dt}$ the
square root of the step size (with $dt < 0$). Then:

$$
b_t = \frac{\lVert \mu_\theta - \mu_{\text{old}} \rVert_{\text{mean}}^2}
{2 (\sqrt{-dt}\, \sigma_t)^2}
$$

$$
\rho_t = \exp\big((\log p_\theta - \log p_{\text{old}} + b_t) \cdot
(\sqrt{-dt}\, \sigma_t)\big)
$$

$$
\mathcal{L}^{\text{guard}}_t =
\frac{1}{(\sqrt{-dt})^2}\;
\mathbb{E}\big[\max(-A_t \rho_t,\ -A_t \mathrm{clip}(\rho_t, 1-\epsilon, 1+\epsilon))\big]
$$

The squared-norm in $b_t$ is averaged over the channel and spatial dimensions
of the latent (see `compute_diffusion_loss_grpo_guard` in
[verl_omni/trainer/diffusion/diffusion_algos.py](../../verl_omni/trainer/diffusion/diffusion_algos.py)).
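
The three formulas above can be sketched for a single sample in plain Python. This is a minimal illustration, not the repository's `compute_diffusion_loss_grpo_guard`: the function names `grpo_guard_ratio` and `grpo_guard_loss` are hypothetical, the latent is flattened into a list, and $\sigma_t$ and $\sqrt{-dt}$ are assumed to be scalars.

```python
import math


def grpo_guard_ratio(log_prob, old_log_prob, mean, old_mean, sigma_t, sqrt_dt):
    """Importance ratio rho_t for one sample, including the bias term b_t.

    `mean` and `old_mean` are flattened latents (hypothetical layout).
    """
    scale = sqrt_dt * sigma_t
    # Squared drift of the proposal mean, averaged over all latent elements.
    sq_drift = sum((m - o) ** 2 for m, o in zip(mean, old_mean)) / len(mean)
    bias = sq_drift / (2.0 * scale ** 2)
    return math.exp((log_prob - old_log_prob + bias) * scale)


def grpo_guard_loss(ratio, advantage, sqrt_dt, clip_eps):
    """Clipped surrogate loss for one sample, rescaled by 1 / sqrt_dt**2."""
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    per_sample = max(-advantage * ratio, -advantage * clipped)
    return per_sample / sqrt_dt ** 2
```

When the current and rollout policies agree, $b_t = 0$ and $\rho_t = 1$, so the loss reduces to $-A_t / (\sqrt{-dt})^2$.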

## Configuration

GRPO-Guard reuses the entire Flow-GRPO training stack — only the actor loss
mode changes. Refer to [Flow-GRPO](flowgrpo.md) for advantage estimator,
rollout, sampling, batch-size, and reward configuration.

To enable GRPO-Guard:

- `actor_rollout_ref.actor.diffusion_loss.loss_mode=grpo_guard`
- `actor_rollout_ref.rollout.algo.sde_type=sde`

A small clip ratio typically works well with the additional bias term:

- `actor_rollout_ref.actor.diffusion_loss.clip_ratio=2e-6`

KL regularisation against a frozen reference policy still works the same way
as Flow-GRPO (`actor_rollout_ref.actor.use_kl_loss=True`,
`actor_rollout_ref.actor.kl_loss_coef=...`).
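
The overrides listed in this section can also be assembled programmatically. The sketch below builds a launch command in Python. The entry-point module `verl_omni.trainer.diffusion.main_flowgrpo` is the one the example script invokes; the helper name `build_grpo_guard_command` is hypothetical, and the KL options are omitted because their values depend on the run.

```python
import shlex


def build_grpo_guard_command(clip_ratio=2e-6, extra_overrides=()):
    """Return the argv list for a GRPO-Guard run (hypothetical helper)."""
    overrides = [
        # Switch the actor loss from Flow-GRPO to GRPO-Guard.
        "actor_rollout_ref.actor.diffusion_loss.loss_mode=grpo_guard",
        # GRPO-Guard is documented with the SDE rollout sampler.
        "actor_rollout_ref.rollout.algo.sde_type=sde",
        f"actor_rollout_ref.actor.diffusion_loss.clip_ratio={clip_ratio}",
    ]
    return [
        "python3", "-m", "verl_omni.trainer.diffusion.main_flowgrpo",
        *overrides, *extra_overrides,
    ]


if __name__ == "__main__":
    # Print a shell-quoted command line for inspection.
    print(shlex.join(build_grpo_guard_command()))
```

All remaining Flow-GRPO settings (data paths, model, rollout, reward) would still need to be passed through `extra_overrides`.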

## Example script

A 4-GPU colocated training script is provided:

```bash
bash examples/grpoguard_trainer/run_qwen_image_ocr_lora.sh
```

It reuses the Flow-GRPO Qwen-Image OCR setup and only flips the actor loss
mode, the clip ratio, and the experiment name. Dataset and model preparation
follow the same instructions as the [Flow-GRPO quick-start](../start/flowgrpo_quickstart.md).

## References

- [Flow-GRPO: Online policy gradient RL for flow matching models](https://arxiv.org/abs/2505.05470)
- [GRPO-Guard: ratio-bias regularisation for diffusion-model RL](https://arxiv.org/abs/2510.22319)
1 change: 1 addition & 0 deletions docs/index.md
@@ -34,6 +34,7 @@ start/metrics.md
:caption: Algorithms

algo/flowgrpo.md
algo/grpo_guard.md
algo/mixgrpo.md
algo/performance.md
```
9 changes: 6 additions & 3 deletions docs/start/metrics.md
@@ -1,16 +1,18 @@
(metrics)=
# Diffusion Training Metrics

-Last updated: 04/23/2026
+Last updated: 05/08/2026

-The table below describes metrics specific to diffusion FlowGRPO training, logged each step to your configured backend (console / W&B).
+The table below describes metrics specific to diffusion FlowGRPO / GRPO-Guard training, logged each step to your configured backend (console / W&B).

| Metric | Definition | Interpretation |
|--------|------------|----------------|
| zero_std_ratio | $\frac{1}{B}\lvert\{i : \sigma_i = 0\}\rvert$ | GRPO derives its learning signal from relative rewards within a group; $\sigma_i = 0$ means group $i$ contributes no gradient regardless of absolute reward. A persistently high value (e.g. $> 0.5$) indicates reward saturation or poorly calibrated task difficulty. |
| std_mean | $\frac{1}{B}\sum\limits_{i=1}^{B} \sigma_i$ | Tracks average reward diversity across the batch. A declining trend is an early warning of saturation, typically visible before zero_std_ratio spikes. |
| pg_clipfrac_higher | $\hat{P}(r > 1 + \varepsilon)$ | The policy is reinforcing high-advantage denoising steps beyond the clip threshold. pg_clipfrac_higher $\gg$ pg_clipfrac_lower signals upward-dominant learning and can guide tuning of the clip ratio or learning rate. |
| pg_clipfrac_lower | $\hat{P}(r < 1 - \varepsilon)$ | The policy is suppressing low-advantage denoising steps beyond the clip threshold. Asymmetry between higher and lower clipfrac reveals the dominant learning direction. |
| ratio_mean | $\mathbb{E}[\rho_t]$ | Mean importance ratio across the batch. Should stay close to 1; persistent drift indicates the current policy is diverging from the rollout policy. |
| ratio_std | $\mathrm{Std}(\rho_t)$ | Spread of the importance ratio. High values signal high-variance gradient updates and may indicate the clip ratio or learning rate is too large. |
| timing_per_image_ms | Latency (ms/image) per stage | Covers rollout, reference log-prob, old log-prob, advantage computation, and actor update; identifies which stage dominates step time and where to focus optimization effort. |
| throughput | $\dfrac{B \times n}{t_\mathrm{step} \times N}$ (images / GPU / s) | Overall training throughput. Use alongside timing_per_image_ms to evaluate scaling efficiency and detect regressions across runs. |

@@ -19,7 +21,8 @@ The table below describes metrics specific to diffusion FlowGRPO training, logge
- $B$ — number of prompts per training batch
- $n$ — number of images generated per prompt
- $\sigma_i$ — reward standard deviation within group $i$
- $r$ — probability ratio $\pi_\theta / \pi_{\theta_\mathrm{old}}$ per (image, denoising-timestep) pair
- $\rho_t$ — importance ratio $\pi_\theta / \pi_{\theta_\mathrm{old}}$ per (image, denoising-timestep) pair
- $r$ — shorthand for $\rho_t$ in clipping expressions
- $\varepsilon$ — clip ratio
- $N$ — number of GPUs
- $t_\mathrm{step}$ — wall-clock time per training step
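
The definitions in the table can be illustrated with a short Python sketch. It is not the trainer's logging code: the function names are hypothetical, and the use of the population standard deviation is an assumption (the implementation may use the sample standard deviation instead).

```python
import statistics


def group_reward_metrics(rewards_per_group):
    """Return (zero_std_ratio, std_mean) from one reward list per prompt."""
    stds = [statistics.pstdev(group) for group in rewards_per_group]
    zero_std_ratio = sum(1 for s in stds if s == 0.0) / len(stds)
    std_mean = sum(stds) / len(stds)
    return zero_std_ratio, std_mean


def ratio_metrics(ratios, clip_eps):
    """Clip fractions and ratio statistics over (image, timestep) pairs."""
    n = len(ratios)
    return {
        "pg_clipfrac_higher": sum(r > 1.0 + clip_eps for r in ratios) / n,
        "pg_clipfrac_lower": sum(r < 1.0 - clip_eps for r in ratios) / n,
        "ratio_mean": statistics.fmean(ratios),
        "ratio_std": statistics.pstdev(ratios),
    }


def throughput(num_prompts, images_per_prompt, step_seconds, num_gpus):
    """Images per GPU per second: B * n / (t_step * N)."""
    return num_prompts * images_per_prompt / (step_seconds * num_gpus)
```

For instance, a batch of 32 prompts with 16 images each that takes 64 seconds on 4 GPUs gives a throughput of 2 images per GPU per second.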
83 changes: 83 additions & 0 deletions examples/grpoguard_trainer/README.md
@@ -0,0 +1,83 @@
# GRPO-Guard Trainer

This example shows how to post-train `Qwen-Image` with GRPO-Guard on an OCR-style image generation task. GRPO-Guard extends Flow-GRPO with a reverse-SDE proposal-mean drift correction and per-step loss rescaling for improved training stability.

For algorithm details, see [`docs/algo/grpo_guard.md`](../../docs/algo/grpo_guard.md). For the base Flow-GRPO setup this example builds on, see [`examples/flowgrpo_trainer/README.md`](../flowgrpo_trainer/README.md).

## Installation

Follow the [installation guide](../../docs/start/install.md) to set up the base environment, then install the GRPO-Guard-specific dependency:

```bash
pip install Levenshtein
```

The provided script is configured for a single node with `4` GPUs.

## Prepare the dataset

Obtain the raw OCR dataset from the original Flow-GRPO repository:

- https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr

Place the raw dataset under `$WORKSPACE/data/ocr` (where `WORKSPACE` defaults to `$HOME`), then preprocess it into parquet files:

```bash
python3 examples/flowgrpo_trainer/data_process/qwenimage_ocr.py \
--input_dir $WORKSPACE/data/ocr \
--output_dir $WORKSPACE/data/ocr
```

This produces:

- `$WORKSPACE/data/ocr/train.parquet`
- `$WORKSPACE/data/ocr/test.parquet`
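
Before launching training, it can help to confirm that both parquet files exist. The check below is a minimal sketch in Python; the helper names `ocr_parquet_paths` and `missing_splits` are hypothetical, and the `WORKSPACE` fallback to the home directory mirrors the default described above.

```python
import os
from pathlib import Path


def ocr_parquet_paths(workspace=None):
    """Resolve the expected train/test parquet paths under WORKSPACE."""
    # Same fallback as the shell default: WORKSPACE if set and non-empty, else $HOME.
    root = Path(workspace or os.environ.get("WORKSPACE") or Path.home())
    data_dir = root / "data" / "ocr"
    return {split: data_dir / f"{split}.parquet" for split in ("train", "test")}


def missing_splits(paths):
    """Return the names of splits whose parquet file is not on disk."""
    return [split for split, path in paths.items() if not path.is_file()]


if __name__ == "__main__":
    missing = missing_splits(ocr_parquet_paths())
    print("missing splits:", missing or "none")
```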

## Prepare the models

**Policy model (Qwen-Image):** the script uses the Hugging Face Hub ID `Qwen/Qwen-Image`, so no manual download is required; Hugging Face caches the weights automatically on first run. To use a local copy instead, edit the `model_name` variable in the script.

**Reward model (Qwen3-VL-8B-Instruct):** the script defaults to the Hugging Face Hub ID `Qwen/Qwen3-VL-8B-Instruct`, which is likewise cached automatically on first run. To use a local copy instead, edit the `reward_model_name` variable in the script.

## Run training

Launch the example from the repository root:

```bash
bash examples/grpoguard_trainer/run_qwen_image_ocr_lora.sh
```

The script runs `python3 -m verl_omni.trainer.diffusion.main_flowgrpo` with:

- `algorithm.adv_estimator=flow_grpo`
- `actor_rollout_ref.model.path=Qwen/Qwen-Image`
- `actor_rollout_ref.model.lora_rank=64`
- `actor_rollout_ref.model.lora_alpha=128`
- `actor_rollout_ref.rollout.name=vllm_omni`
- `actor_rollout_ref.actor.diffusion_loss.loss_mode=grpo_guard`
- `actor_rollout_ref.actor.diffusion_loss.clip_ratio=2e-6`
- `actor_rollout_ref.rollout.algo.sde_type=sde`
- `reward.custom_reward_function.name=compute_score_ocr`
- `trainer.n_gpus_per_node=4`

## Logging

W&B logging is enabled by default in the example script:

```bash
export WANDB_API_KEY=<your_wandb_api_key>
```

The script sets:

```bash
trainer.logger='["console", "wandb"]'
trainer.project_name=grpo_guard
trainer.experiment_name=qwen_image_ocr_lora
```

Override these values on the command line if you want to log under a different project or run name.

### Diffusion-specific metrics

See the [Metrics Documentation](../../docs/start/metrics.md) for a full description of all diffusion-specific training metrics.
75 changes: 75 additions & 0 deletions examples/grpoguard_trainer/run_qwen_image_ocr_lora.sh
@@ -0,0 +1,75 @@
# Qwen-Image LoRA RL with GRPO-Guard (https://arxiv.org/abs/2510.22319), vllm_omni rollout
set -x

# Set WORKSPACE to any writable directory; defaults to $HOME
WORKSPACE=${WORKSPACE:-$HOME}

ocr_train_path=$WORKSPACE/data/ocr/train.parquet
ocr_test_path=$WORKSPACE/data/ocr/test.parquet

model_name=Qwen/Qwen-Image
reward_model_name=Qwen/Qwen3-VL-8B-Instruct
reward_function_path=verl_omni/utils/reward_score/genrm_ocr.py

NUM_GPUS_ACTOR_ROLLOUT_REWARD=4
ROLLOUT_TP=1
REWARD_TP=4

ENGINE=vllm_omni
REWARD_ENGINE=vllm


python3 -m verl_omni.trainer.diffusion.main_flowgrpo \
algorithm.adv_estimator=flow_grpo \
data.train_files=$ocr_train_path \
data.val_files=$ocr_test_path \
data.train_batch_size=32 \
data.max_prompt_length=256 \
actor_rollout_ref.model.path=$model_name \

[Review comment, Collaborator] Should we set actor_rollout_ref.model.algorithm to grpo_guard?

[Reply, Collaborator Author, zhtmike, May 11, 2026] GRPO-Guard is built on Flow-GRPO and differs only in the loss, so this script reuses the Flow-GRPO components.


actor_rollout_ref.model.lora_rank=64 \
actor_rollout_ref.model.lora_alpha=128 \
actor_rollout_ref.model.target_modules="['to_q','to_k','to_v','to_out.0','add_q_proj','add_k_proj','add_v_proj','to_add_out','img_mlp.net.0.proj','img_mlp.net.2','txt_mlp.net.0.proj','txt_mlp.net.2']" \
actor_rollout_ref.actor.optim.lr=3e-4 \
actor_rollout_ref.actor.optim.weight_decay=0.0001 \
actor_rollout_ref.actor.ppo_mini_batch_size=16 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
actor_rollout_ref.actor.diffusion_loss.loss_mode=grpo_guard \
actor_rollout_ref.actor.diffusion_loss.clip_ratio=2e-6 \
actor_rollout_ref.actor.diffusion_loss.adv_clip_max=5.0 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \
actor_rollout_ref.rollout.name=$ENGINE \
actor_rollout_ref.rollout.n=16 \
actor_rollout_ref.rollout.agent.num_workers=$((NUM_GPUS_ACTOR_ROLLOUT_REWARD / ROLLOUT_TP)) \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.rollout.pipeline.true_cfg_scale=4.0 \
actor_rollout_ref.rollout.pipeline.max_sequence_length=256 \
actor_rollout_ref.rollout.algo.noise_level=1.2 \
actor_rollout_ref.rollout.algo.sde_type="sde" \
actor_rollout_ref.rollout.algo.sde_window_size=2 \
actor_rollout_ref.rollout.algo.sde_window_range="[0,5]" \
actor_rollout_ref.rollout.val_kwargs.pipeline.num_inference_steps=50 \
actor_rollout_ref.rollout.val_kwargs.algo.noise_level=0.0 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
reward.num_workers=$((NUM_GPUS_ACTOR_ROLLOUT_REWARD / REWARD_TP)) \
reward.reward_model.enable=True \
reward.reward_model.model_path=$reward_model_name \
reward.reward_model.rollout.name=$REWARD_ENGINE \
reward.reward_model.rollout.tensor_model_parallel_size=$REWARD_TP \
reward.custom_reward_function.path=$reward_function_path \
reward.custom_reward_function.name=compute_score_ocr \
trainer.logger='["console", "wandb"]' \
trainer.project_name=grpo_guard \
trainer.experiment_name=qwen_image_ocr_lora \
trainer.log_val_generations=8 \
trainer.val_before_train=False \
trainer.n_gpus_per_node=$NUM_GPUS_ACTOR_ROLLOUT_REWARD \
trainer.nnodes=1 \
trainer.save_freq=30 \
trainer.test_freq=30 \
trainer.total_epochs=15 \
trainer.total_training_steps=300 "$@"
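
The worker counts and per-step image volume implied by the settings above follow from simple arithmetic. The sketch below recomputes them in Python from the values in the script. The helper name `derived_settings` is hypothetical, and reading `train_batch_size * rollout.n` as the number of images per step is an assumption based on the $B \times n$ term in the throughput metric.

```python
def derived_settings(num_gpus=4, rollout_tp=1, reward_tp=4,
                     train_batch_size=32, rollout_n=16):
    """Recompute values the script derives with shell arithmetic."""
    return {
        # $((NUM_GPUS_ACTOR_ROLLOUT_REWARD / ROLLOUT_TP)): integer division.
        "rollout_agent_workers": num_gpus // rollout_tp,
        # $((NUM_GPUS_ACTOR_ROLLOUT_REWARD / REWARD_TP))
        "reward_workers": num_gpus // reward_tp,
        # Assumed: prompts per batch times images generated per prompt.
        "images_per_step": train_batch_size * rollout_n,
    }


if __name__ == "__main__":
    for name, value in derived_settings().items():
        print(f"{name} = {value}")
```

With the defaults in the script this gives 4 rollout agent workers, 1 reward worker, and 512 images per training step.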
54 changes: 54 additions & 0 deletions tests/trainer/diffusion/test_diffusion_core_algos_on_cpu.py
@@ -114,3 +114,57 @@ def test_compute_policy_loss_flow_grpo() -> None:
    assert "actor/pg_clipfrac" in pg_metrics
    assert "actor/pg_clipfrac_higher" in pg_metrics
    assert "actor/pg_clipfrac_lower" in pg_metrics


def test_compute_policy_loss_grpo_guard() -> None:
    from hydra import compose, initialize_config_dir
    from verl.utils.config import omega_conf_to_dataclass

    from verl_omni.workers.config.diffusion.actor import FSDPDiffusionActorConfig

    batch_size = 4
    rollout_log_probs = torch.randn((batch_size,), dtype=torch.float32)
    current_log_probs = torch.randn((batch_size,), dtype=torch.float32)
    advantages = torch.randn((batch_size,), dtype=torch.float32)
    old_prev_sample_mean = torch.randn((batch_size, 16, 8, 8), dtype=torch.float32)
    prev_sample_mean = old_prev_sample_mean + 0.01 * torch.randn_like(old_prev_sample_mean)
    std_dev_t = torch.full((batch_size, 1, 1, 1), 0.5, dtype=torch.float32)
    sqrt_dt = torch.full((batch_size,), 0.3, dtype=torch.float32)

    with initialize_config_dir(
        config_dir=os.path.abspath("verl_omni/trainer/config/diffusion/actor"), version_base=None
    ):
        cfg = compose(
            config_name="dp_diffusion_actor",
            overrides=[
                "strategy=fsdp",
                "diffusion_loss.loss_mode=grpo_guard",
                "diffusion_loss.clip_ratio=2e-6",
                "diffusion_loss.adv_clip_max=5.0",
                "ppo_micro_batch_size_per_gpu=8",
            ],
        )
    actor_config: FSDPDiffusionActorConfig = omega_conf_to_dataclass(cfg)

    pg_loss, pg_metrics = diffusion_algos.compute_diffusion_loss_grpo_guard(
        old_log_prob=rollout_log_probs,
        log_prob=current_log_probs,
        advantages=advantages,
        config=actor_config,
        old_prev_sample_mean=old_prev_sample_mean,
        prev_sample_mean=prev_sample_mean,
        std_dev_t=std_dev_t,
        sqrt_dt=sqrt_dt,
    )

    assert pg_loss.shape == ()
    assert isinstance(pg_loss.item(), float)
    for key in (
        "actor/ppo_kl",
        "actor/pg_clipfrac",
        "actor/pg_clipfrac_higher",
        "actor/pg_clipfrac_lower",
        "actor/ratio_mean",
        "actor/ratio_std",
    ):
        assert key in pg_metrics, key
3 changes: 3 additions & 0 deletions verl_omni/pipelines/model_base.py
@@ -157,6 +157,9 @@ def forward_and_sample_previous_step(
            scheduler_inputs (Optional[TensorDict | dict[str, torch.Tensor]]): the extra inputs for the scheduler,
                which may contain the latents and timesteps.
            step (int): the current step in the diffusion process.

        Returns:
            tuple: ``(log_prob, prev_sample_mean, std_dev_t, sqrt_dt)``
        """
        pass

@@ -213,7 +213,7 @@ def forward_and_sample_previous_step(
             step (int): Current denoising step index.

         Returns:
-            tuple: A 3-tuple of ``(log_prob, prev_sample_mean, std_dev_t)``.
+            tuple: A 4-tuple of ``(log_prob, prev_sample_mean, std_dev_t, sqrt_dt)``.
         """
         assert scheduler_inputs is not None
         latents = scheduler_inputs["all_latents"]
@@ -226,13 +226,14 @@ def forward_and_sample_previous_step(
             neg_noise_pred = module(**negative_model_inputs)[0]
             noise_pred = apply_true_cfg(noise_pred, neg_noise_pred, true_cfg_scale)

-        _, log_prob, prev_sample_mean, std_dev_t = scheduler.sample_previous_step(
+        _, log_prob, prev_sample_mean, std_dev_t, sqrt_dt = scheduler.sample_previous_step(
             sample=latents[:, step].float(),
             model_output=noise_pred.float(),
             timestep=timesteps[:, step],
             noise_level=model_config.algo.noise_level,
             prev_sample=latents[:, step + 1].float(),
             sde_type=model_config.algo.sde_type,
             return_logprobs=True,
+            return_sqrt_dt=True,
         )
-        return log_prob, prev_sample_mean, std_dev_t
+        return log_prob, prev_sample_mean, std_dev_t, sqrt_dt
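
For intuition about the new `sqrt_dt` return value: the algorithm documentation writes it as $\sqrt{-dt}$, which suggests that $dt$ is the negative change in noise level between consecutive denoising steps. The sketch below computes that quantity for a decreasing schedule. The function name `sqrt_dt_schedule` and the list-of-sigmas input are hypothetical, and the scheduler behind `sample_previous_step` may compute the value differently.

```python
import math


def sqrt_dt_schedule(sigmas):
    """Return sqrt(-dt) for each pair of consecutive noise levels.

    Assumes `sigmas` is strictly decreasing, so dt = sigma_next - sigma < 0.
    """
    result = []
    for sigma, sigma_next in zip(sigmas, sigmas[1:]):
        dt = sigma_next - sigma
        if dt >= 0:
            raise ValueError("sigmas must be strictly decreasing")
        result.append(math.sqrt(-dt))
    return result
```

Under this reading, the loss rescaling factor $1 / (\sqrt{-dt})^2$ is simply $1 / \lvert dt \rvert$, so steps with a smaller step size receive a proportionally larger weight.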