diff --git a/docs/algo/grpo_guard.md b/docs/algo/grpo_guard.md new file mode 100644 index 00000000..bca0ec1c --- /dev/null +++ b/docs/algo/grpo_guard.md @@ -0,0 +1,77 @@ +# GRPO-Guard + +Last updated: 05/08/2026. + +GRPO-Guard ([paper](https://arxiv.org/abs/2510.22319)) is an extension of +[Flow-GRPO](flowgrpo.md) that stabilizes the importance-ratio estimate used in +the policy loss. The standard Flow-GRPO ratio +$\rho = \exp(\log p_\theta - \log p_{\text{old}})$ can become numerically +unbalanced when only a single Monte-Carlo noise sample $z$ is used per +denoising step, causing high-variance gradients and aggressive clipping. + +GRPO-Guard adds a **ratio-mean bias** correction that explicitly penalises +drift in the reverse-SDE proposal mean of the current policy relative to the +rollout policy, and rescales the per-step loss by $1 / (\sqrt{-dt})^2$ so the +gradient magnitude is consistent across denoising steps. + +## Algorithm + +For step $t$ with proposal mean $\mu_\theta(x_t)$ from the current policy and +$\mu_{\text{old}}(x_t)$ from the rollout policy, SDE noise scale +$\sigma_t = \mathrm{std\\_dev\\_t}$, and $\sqrt{-dt}$: + +$$ +b_t = \frac{\lVert \mu_\theta - \mu_{\text{old}} \rVert_{\text{mean}}^2} + {2 (\sqrt{-dt}\, \sigma_t)^2} +$$ + +$$ +\rho_t = \exp\big((\log p_\theta - \log p_{\text{old}} + b_t) \cdot + (\sqrt{-dt}\, \sigma_t)\big) +$$ + +$$ +\mathcal{L}^{\text{guard}}_t = + \frac{1}{(\sqrt{-dt})^2}\; + \mathbb{E}\big[\max(-A_t \rho_t,\ -A_t \mathrm{clip}(\rho_t, 1-\epsilon, 1+\epsilon))\big] +$$ + +The squared-norm in $b_t$ is averaged over the channel and spatial dimensions +of the latent (see `compute_diffusion_loss_grpo_guard` in +[verl_omni/trainer/diffusion/diffusion_algos.py](../../verl_omni/trainer/diffusion/diffusion_algos.py)). + +## Configuration + +GRPO-Guard reuses the entire Flow-GRPO training stack — only the actor loss +mode changes. Refer to [Flow-GRPO](flowgrpo.md) for advantage estimator, +rollout, sampling, batch-size, and reward configuration. + +To enable GRPO-Guard: + +- `actor_rollout_ref.actor.diffusion_loss.loss_mode=grpo_guard` +- `actor_rollout_ref.rollout.algo.sde_type=sde` + +A typical small clip ratio works well with the additional bias term: + +- `actor_rollout_ref.actor.diffusion_loss.clip_ratio=2e-6` + +KL regularisation against a frozen reference policy still works the same way +as Flow-GRPO (`actor_rollout_ref.actor.use_kl_loss=True`, +`actor_rollout_ref.actor.kl_loss_coef=...`). + +## Example script + +A 4-card collocated training script is provided: + +```bash +bash examples/grpoguard_trainer/run_qwen_image_ocr_lora.sh +``` + +It reuses the Flow-GRPO Qwen-Image OCR setup and only flips the actor loss +mode, the clip ratio, and the experiment name. Dataset and model preparation +follow the same instructions as the [Flow-GRPO quick-start](../start/flowgrpo_quickstart.md). + +## References + +- [Flow-GRPO: Online policy gradient RL for flow matching models](https://arxiv.org/abs/2505.05470) +- [GRPO-Guard: ratio-bias regularisation for diffusion-model RL](https://arxiv.org/abs/2510.22319) diff --git a/docs/index.md b/docs/index.md index 62058338..461d6c62 100644 --- a/docs/index.md +++ b/docs/index.md @@ -34,6 +34,7 @@ start/metrics.md :caption: Algorithms algo/flowgrpo.md +algo/grpo_guard.md algo/mixgrpo.md algo/performance.md ``` diff --git a/docs/start/metrics.md b/docs/start/metrics.md index 7cb718c7..8393fd72 100644 --- a/docs/start/metrics.md +++ b/docs/start/metrics.md @@ -1,9 +1,9 @@ (metrics)= # Diffusion Training Metrics -Last updated: 04/23/2026 +Last updated: 05/08/2026 -The table below describes metrics specific to diffusion FlowGRPO training, logged each step to your configured backend (console / W&B). +The table below describes metrics specific to diffusion FlowGRPO / GRPO-Guard training, logged each step to your configured backend (console / W&B). | Metric | Definition | Interpretation | |--------|------------|----------------| @@ -11,6 +11,8 @@ The table below describes metrics specific to diffusion FlowGRPO training, logge | std_mean | $\frac{1}{B}\sum\limits_{i=1}^{B} \sigma_i$ | Tracks average reward diversity across the batch. A declining trend is an early warning of saturation, typically visible before zero_std_ratio spikes. | | pg_clipfrac_higher | $\hat{P}(r > 1 + \varepsilon)$ | The policy is reinforcing high-advantage denoising steps beyond the clip threshold. pg_clipfrac_higher $\gg$ pg_clipfrac_lower signals upward-dominant learning and can guide tuning of the clip ratio or learning rate. | | pg_clipfrac_lower | $\hat{P}(r < 1 - \varepsilon)$ | The policy is suppressing low-advantage denoising steps beyond the clip threshold. Asymmetry between higher and lower clipfrac reveals the dominant learning direction. | +| ratio_mean | $\mathbb{E}[\rho_t]$ | Mean importance ratio across the batch. Should stay close to 1; persistent drift indicates the current policy is diverging from the rollout policy. | +| ratio_std | $\mathrm{Std}(\rho_t)$ | Spread of the importance ratio. High values signal high-variance gradient updates and may indicate the clip ratio or learning rate is too large. | | timing_per_image_ms | Latency (ms/image) per stage | Covers rollout, reference log-prob, old log-prob, advantage computation, and actor update; identifies which stage dominates step time and where to focus optimization effort. | | throughput | $\dfrac{B \times n}{t_\mathrm{step} \times N}$ (images / GPU / s) | Overall training throughput. Use alongside timing_per_image_ms to evaluate scaling efficiency and detect regressions across runs. | @@ -19,7 +21,8 @@ The table below describes metrics specific to diffusion FlowGRPO training, logge - $B$ — number of prompts per training batch - $n$ — number of images generated per prompt - $\sigma_i$ — reward standard deviation within group $i$ -- $r$ — probability ratio $\pi_\theta / \pi_{\theta_\mathrm{old}}$ per (image, denoising-timestep) pair +- $\rho_t$ — importance ratio $\pi_\theta / \pi_{\theta_\mathrm{old}}$ per (image, denoising-timestep) pair +- $r$ — shorthand for $\rho_t$ in clipping expressions - $\varepsilon$ — clip ratio - $N$ — number of GPUs - $t_\mathrm{step}$ — wall-clock time per training step diff --git a/examples/grpoguard_trainer/README.md b/examples/grpoguard_trainer/README.md new file mode 100644 index 00000000..abf67e91 --- /dev/null +++ b/examples/grpoguard_trainer/README.md @@ -0,0 +1,83 @@ +# GRPO-Guard Trainer + +This example shows how to post-train `Qwen-Image` with GRPO-Guard on an OCR-style image generation task. GRPO-Guard extends Flow-GRPO with a reverse-SDE proposal-mean drift correction and per-step loss rescaling for improved training stability. + +For algorithm details, see [`docs/algo/grpo_guard.md`](../../docs/algo/grpo_guard.md). For the base Flow-GRPO setup this example builds on, see [`examples/flowgrpo_trainer/README.md`](../flowgrpo_trainer/README.md). + +## Installation + +Follow the [installation guide](../../docs/start/install.md) to set up the base environment, then install the GRPO-Guard-specific dependency: + +```bash +pip install Levenshtein +``` + +The provided script is configured for a single node with `4` GPUs. + +## Prepare the dataset + +Obtain the raw OCR dataset from the original Flow-GRPO repository: + +- https://github.com/yifan123/flow_grpo/tree/main/dataset/ocr + +Place the raw dataset under `$WORKSPACE/data/ocr` (where `WORKSPACE` defaults to `$HOME`), then preprocess it into parquet files: + +```bash +python3 examples/flowgrpo_trainer/data_process/qwenimage_ocr.py \ + --input_dir $WORKSPACE/data/ocr \ + --output_dir $WORKSPACE/data/ocr +``` + +This produces: + +- `$WORKSPACE/data/ocr/train.parquet` +- `$WORKSPACE/data/ocr/test.parquet` + +## Prepare the models + +**Policy model (Qwen-Image):** the script uses the Hugging Face Hub ID `Qwen/Qwen-Image` directly — no manual download is required. Hugging Face will cache the weights automatically on first run. To use a local copy instead, edit the `model_name` variable in the script directly. + +**Reward model (Qwen3-VL-8B-Instruct):** the script defaults to the Hugging Face Hub ID `Qwen/Qwen3-VL-8B-Instruct`, so no manual download is required — Hugging Face will cache it automatically on first run. To use a local copy instead, edit the `reward_model_name` variable in the script directly. + +## Run training + +Launch the example from the repository root: + +```bash +bash examples/grpoguard_trainer/run_qwen_image_ocr_lora.sh +``` + +The script runs `python3 -m verl_omni.trainer.diffusion.main_flowgrpo` with: + +- `algorithm.adv_estimator=flow_grpo` +- `actor_rollout_ref.model.path=Qwen/Qwen-Image` +- `actor_rollout_ref.model.lora_rank=64` +- `actor_rollout_ref.model.lora_alpha=128` +- `actor_rollout_ref.rollout.name=vllm_omni` +- `actor_rollout_ref.actor.diffusion_loss.loss_mode=grpo_guard` +- `actor_rollout_ref.actor.diffusion_loss.clip_ratio=2e-6` +- `actor_rollout_ref.rollout.algo.sde_type=sde` +- `reward.custom_reward_function.name=compute_score_ocr` +- `trainer.n_gpus_per_node=4` + +## Logging + +W&B logging is enabled by default in the example script: + +```bash +export WANDB_API_KEY= +``` + +The script sets: + +```bash +trainer.logger='["console", "wandb"]' +trainer.project_name=grpo_guard +trainer.experiment_name=qwen_image_ocr_lora +``` + +Override these values on the command line if you want to log under a different project or run name. + +### Diffusion-specific metrics + +See the [Metrics Documentation](../../docs/start/metrics.md) for a full description of all diffusion-specific training metrics. diff --git a/examples/grpoguard_trainer/run_qwen_image_ocr_lora.sh b/examples/grpoguard_trainer/run_qwen_image_ocr_lora.sh new file mode 100755 index 00000000..4c60600c --- /dev/null +++ b/examples/grpoguard_trainer/run_qwen_image_ocr_lora.sh @@ -0,0 +1,75 @@ +# Qwen-Image lora RL with GRPO-Guard (https://arxiv.org/abs/2510.22319), vllm_omni rollout +set -x + +# Set WORKSPACE to any writable directory; defaults to $HOME +WORKSPACE=${WORKSPACE:-$HOME} + +ocr_train_path=$WORKSPACE/data/ocr/train.parquet +ocr_test_path=$WORKSPACE/data/ocr/test.parquet + +model_name=Qwen/Qwen-Image +reward_model_name=Qwen/Qwen3-VL-8B-Instruct +reward_function_path=verl_omni/utils/reward_score/genrm_ocr.py + +NUM_GPUS_ACTOR_ROLLOUT_REWARD=4 +ROLLOUT_TP=1 +REWARD_TP=4 + +ENGINE=vllm_omni +REWARD_ENGINE=vllm + + +python3 -m verl_omni.trainer.diffusion.main_flowgrpo \ + algorithm.adv_estimator=flow_grpo \ + data.train_files=$ocr_train_path \ + data.val_files=$ocr_test_path \ + data.train_batch_size=32 \ + data.max_prompt_length=256 \ + actor_rollout_ref.model.path=$model_name \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=128 \ + actor_rollout_ref.model.target_modules="['to_q','to_k','to_v','to_out.0','add_q_proj','add_k_proj','add_v_proj','to_add_out','img_mlp.net.0.proj','img_mlp.net.2','txt_mlp.net.0.proj','txt_mlp.net.2']" \ + actor_rollout_ref.actor.optim.lr=3e-4 \ + actor_rollout_ref.actor.optim.weight_decay=0.0001 \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.actor.fsdp_config.param_offload=True \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.actor.diffusion_loss.loss_mode=grpo_guard \ + actor_rollout_ref.actor.diffusion_loss.clip_ratio=2e-6 \ + actor_rollout_ref.actor.diffusion_loss.adv_clip_max=5.0 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.rollout.agent.num_workers=$((NUM_GPUS_ACTOR_ROLLOUT_REWARD / ROLLOUT_TP)) \ + actor_rollout_ref.rollout.load_format=safetensors \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.rollout.pipeline.true_cfg_scale=4.0 \ + actor_rollout_ref.rollout.pipeline.max_sequence_length=256 \ + actor_rollout_ref.rollout.algo.noise_level=1.2 \ + actor_rollout_ref.rollout.algo.sde_type="sde" \ + actor_rollout_ref.rollout.algo.sde_window_size=2 \ + actor_rollout_ref.rollout.algo.sde_window_range="[0,5]" \ + actor_rollout_ref.rollout.val_kwargs.pipeline.num_inference_steps=50 \ + actor_rollout_ref.rollout.val_kwargs.algo.noise_level=0.0 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ + reward.num_workers=$((NUM_GPUS_ACTOR_ROLLOUT_REWARD / REWARD_TP)) \ + reward.reward_model.enable=True \ + reward.reward_model.model_path=$reward_model_name \ + reward.reward_model.rollout.name=$REWARD_ENGINE \ + reward.reward_model.rollout.tensor_model_parallel_size=$REWARD_TP \ + reward.custom_reward_function.path=$reward_function_path \ + reward.custom_reward_function.name=compute_score_ocr \ + trainer.logger='["console", "wandb"]' \ + trainer.project_name=grpo_guard \ + trainer.experiment_name=qwen_image_ocr_lora \ + trainer.log_val_generations=8 \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=$NUM_GPUS_ACTOR_ROLLOUT_REWARD \ + trainer.nnodes=1 \ + trainer.save_freq=30 \ + trainer.test_freq=30 \ + trainer.total_epochs=15 \ + trainer.total_training_steps=300 "$@" diff --git a/tests/trainer/diffusion/test_diffusion_core_algos_on_cpu.py b/tests/trainer/diffusion/test_diffusion_core_algos_on_cpu.py index 8be8121a..219dd296 100644 --- a/tests/trainer/diffusion/test_diffusion_core_algos_on_cpu.py +++ b/tests/trainer/diffusion/test_diffusion_core_algos_on_cpu.py @@ -114,3 +114,57 @@ def test_compute_policy_loss_flow_grpo() -> None: assert "actor/pg_clipfrac" in pg_metrics assert "actor/pg_clipfrac_higher" in pg_metrics assert "actor/pg_clipfrac_lower" in pg_metrics + + +def test_compute_policy_loss_grpo_guard() -> None: + from hydra import compose, initialize_config_dir + from verl.utils.config import omega_conf_to_dataclass + + from verl_omni.workers.config.diffusion.actor import FSDPDiffusionActorConfig + + batch_size = 4 + rollout_log_probs = torch.randn((batch_size,), dtype=torch.float32) + current_log_probs = torch.randn((batch_size,), dtype=torch.float32) + advantages = torch.randn((batch_size,), dtype=torch.float32) + old_prev_sample_mean = torch.randn((batch_size, 16, 8, 8), dtype=torch.float32) + prev_sample_mean = old_prev_sample_mean + 0.01 * torch.randn_like(old_prev_sample_mean) + std_dev_t = torch.full((batch_size, 1, 1, 1), 0.5, dtype=torch.float32) + sqrt_dt = torch.full((batch_size,), 0.3, dtype=torch.float32) + + with initialize_config_dir( + config_dir=os.path.abspath("verl_omni/trainer/config/diffusion/actor"), version_base=None + ): + cfg = compose( + config_name="dp_diffusion_actor", + overrides=[ + "strategy=fsdp", + "diffusion_loss.loss_mode=grpo_guard", + "diffusion_loss.clip_ratio=2e-6", + "diffusion_loss.adv_clip_max=5.0", + "ppo_micro_batch_size_per_gpu=8", + ], + ) + actor_config: FSDPDiffusionActorConfig = omega_conf_to_dataclass(cfg) + + pg_loss, pg_metrics = diffusion_algos.compute_diffusion_loss_grpo_guard( + old_log_prob=rollout_log_probs, + log_prob=current_log_probs, + advantages=advantages, + config=actor_config, + old_prev_sample_mean=old_prev_sample_mean, + prev_sample_mean=prev_sample_mean, + std_dev_t=std_dev_t, + sqrt_dt=sqrt_dt, + ) + + assert pg_loss.shape == () + assert isinstance(pg_loss.item(), float) + for key in ( + "actor/ppo_kl", + "actor/pg_clipfrac", + "actor/pg_clipfrac_higher", + "actor/pg_clipfrac_lower", + "actor/ratio_mean", + "actor/ratio_std", + ): + assert key in pg_metrics, key diff --git a/verl_omni/pipelines/model_base.py b/verl_omni/pipelines/model_base.py index 8aebdeda..36455702 100644 --- a/verl_omni/pipelines/model_base.py +++ b/verl_omni/pipelines/model_base.py @@ -157,6 +157,9 @@ def forward_and_sample_previous_step( scheduler_inputs (Optional[TensorDict | dict[str, torch.Tensor]]): the extra inputs for the scheduler, which may contain the latents and timesteps. step (int): the current step in the diffusion process. + + Returns: + tuple: ``(log_prob, prev_sample_mean, std_dev_t, sqrt_dt)`` """ pass diff --git a/verl_omni/pipelines/qwen_image_flow_grpo/diffusers_training_adapter.py b/verl_omni/pipelines/qwen_image_flow_grpo/diffusers_training_adapter.py index b9627577..8177320e 100644 --- a/verl_omni/pipelines/qwen_image_flow_grpo/diffusers_training_adapter.py +++ b/verl_omni/pipelines/qwen_image_flow_grpo/diffusers_training_adapter.py @@ -213,7 +213,7 @@ def forward_and_sample_previous_step( step (int): Current denoising step index. Returns: - tuple: A 3-tuple of ``(log_prob, prev_sample_mean, std_dev_t)``. + tuple: A 4-tuple of ``(log_prob, prev_sample_mean, std_dev_t, sqrt_dt)``. """ assert scheduler_inputs is not None latents = scheduler_inputs["all_latents"] @@ -226,7 +226,7 @@ def forward_and_sample_previous_step( neg_noise_pred = module(**negative_model_inputs)[0] noise_pred = apply_true_cfg(noise_pred, neg_noise_pred, true_cfg_scale) - _, log_prob, prev_sample_mean, std_dev_t = scheduler.sample_previous_step( + _, log_prob, prev_sample_mean, std_dev_t, sqrt_dt = scheduler.sample_previous_step( sample=latents[:, step].float(), model_output=noise_pred.float(), timestep=timesteps[:, step], @@ -234,5 +234,6 @@ def forward_and_sample_previous_step( prev_sample=latents[:, step + 1].float(), sde_type=model_config.algo.sde_type, return_logprobs=True, + return_sqrt_dt=True, ) - return log_prob, prev_sample_mean, std_dev_t + return log_prob, prev_sample_mean, std_dev_t, sqrt_dt diff --git a/verl_omni/pipelines/schedulers/flow_match_sde.py b/verl_omni/pipelines/schedulers/flow_match_sde.py index 3f0764e5..dcc3b2a8 100644 --- a/verl_omni/pipelines/schedulers/flow_match_sde.py +++ b/verl_omni/pipelines/schedulers/flow_match_sde.py @@ -155,7 +155,37 @@ def sample_previous_step( prev_sample: Optional[torch.Tensor] = None, sde_type: Literal["cps", "sde"] = "sde", return_logprobs: bool = True, + return_sqrt_dt: bool = False, ): + """ + Run a single SDE / CPS reverse step. + + Args: + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`torch.FloatTensor`, *optional*): + The current discrete timestep in the diffusion chain. When `None`, the internal + `step_index` is used (sequential denoising loop). + generator (`torch.Generator`, *optional*): + A random number generator. + per_token_timesteps (`torch.Tensor`, *optional*): + The timesteps for each token in the sample. Currently not supported. + noise_level (`float`, *optional*, defaults to 0.7): + The noise level used in the SDE. + prev_sample (`torch.FloatTensor`, *optional*): + The sample from the previous timestep. If provided, it is used directly for + log-probability computation instead of being sampled. + sde_type (`str`, *optional*, defaults to "sde"): + The type of SDE to use. Choose between "sde" and "cps". + return_logprobs (`bool`, *optional*, defaults to True): + Whether to return log probabilities of the previous sample. + return_sqrt_dt (`bool`, *optional*, defaults to False): + Whether to additionally return `sqrt(-dt)` as a tensor of shape `(batch_size,)`. + Used by GRPO-Guard to compute the importance-ratio normalization + (see `compute_diffusion_loss_grpo_guard`). + """ assert sde_type in ["sde", "cps"] assert sample.dtype == torch.float32 if prev_sample is not None: @@ -226,4 +256,11 @@ def sample_previous_step( # mean along all but batch dimension log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) if log_prob is not None else None + if return_sqrt_dt: + sqrt_dt = torch.sqrt(-1 * dt) + if sqrt_dt.ndim == 0: + sqrt_dt = sqrt_dt.expand(sample.shape[0]).clone() + else: + sqrt_dt = sqrt_dt.reshape(sqrt_dt.shape[0]) + return prev_sample, log_prob, prev_sample_mean, std_dev_t, sqrt_dt return prev_sample, log_prob, prev_sample_mean, std_dev_t diff --git a/verl_omni/trainer/diffusion/diffusion_algos.py b/verl_omni/trainer/diffusion/diffusion_algos.py index 7e412009..0521703b 100644 --- a/verl_omni/trainer/diffusion/diffusion_algos.py +++ b/verl_omni/trainer/diffusion/diffusion_algos.py @@ -233,12 +233,110 @@ def compute_diffusion_loss_flow_grpo( pg_clipfrac = torch.mean((torch.abs(ratio - 1.0) > loss_cfg.clip_ratio).float()) pg_clipfrac_higher = torch.mean((ratio - 1.0 > loss_cfg.clip_ratio).float()) pg_clipfrac_lower = torch.mean((1.0 - ratio > loss_cfg.clip_ratio).float()) + ratio_mean = ratio.mean() + ratio_std = ratio.std() pg_metrics = { "actor/ppo_kl": ppo_kl.detach().item(), "actor/pg_clipfrac": pg_clipfrac.detach().item(), "actor/pg_clipfrac_higher": pg_clipfrac_higher.detach().item(), "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + "actor/ratio_mean": ratio_mean.detach().item(), + "actor/ratio_std": ratio_std.detach().item(), + } + return pg_loss, pg_metrics + + +@register_diffusion_loss("grpo_guard") +def compute_diffusion_loss_grpo_guard( + old_log_prob: torch.Tensor, + log_prob: torch.Tensor, + advantages: torch.Tensor, + config: Optional[DictConfig | DiffusionActorConfig] = None, + *, + old_prev_sample_mean: Optional[torch.Tensor] = None, + prev_sample_mean: Optional[torch.Tensor] = None, + std_dev_t: Optional[torch.Tensor] = None, + sqrt_dt: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, dict[str, Any]]: + """Compute the GRPO-Guard policy objective. + + GRPO-Guard (https://arxiv.org/abs/2510.22319) augments the standard + Flow-GRPO importance ratio with a "ratio-mean bias" term that explicitly + penalises drift in the reverse-SDE proposal mean of the current policy + relative to the rollout policy. The mean drift is then projected onto the + same scale as ``log_prob - old_log_prob`` via the per-step diffusion + coefficient ``sqrt_dt * sigma_t``, and the final policy loss is rescaled + by ``1 / sqrt_dt**2`` so that gradients have a consistent magnitude across + timesteps. + + Args: + old_log_prob (torch.Tensor): Log-probabilities under the old policy, + shape ``(B,)``. + log_prob (torch.Tensor): Log-probabilities under the current policy, + shape ``(B,)``. + advantages (torch.Tensor): Advantage estimates, shape ``(B,)``. + config: Actor configuration; ``diffusion_loss.clip_ratio`` and + ``diffusion_loss.adv_clip_max`` are read from it. + old_prev_sample_mean (torch.Tensor): Reverse-SDE mean from the rollout + policy, shape ``(B, ...)``. + prev_sample_mean (torch.Tensor): Reverse-SDE mean from the current + policy, shape ``(B, ...)``. + std_dev_t (torch.Tensor): Per-step SDE standard deviation, shape + ``(B, 1, 1, ...)`` or scalar. + sqrt_dt (torch.Tensor): ``sqrt(-dt)`` for the current denoising step, + shape ``(B,)`` or scalar. + """ + assert config is not None + assert isinstance(config, DiffusionActorConfig) + assert old_prev_sample_mean is not None, "GRPO-Guard requires `old_prev_sample_mean`" + assert prev_sample_mean is not None, "GRPO-Guard requires `prev_sample_mean`" + assert std_dev_t is not None, "GRPO-Guard requires `std_dev_t`" + assert sqrt_dt is not None, "GRPO-Guard requires `sqrt_dt`" + + loss_cfg = config.diffusion_loss + advantages = torch.clamp( + advantages, + -loss_cfg.adv_clip_max, + loss_cfg.adv_clip_max, + ) + + sigma_t = std_dev_t.mean() + sqrt_dt_mean = sqrt_dt.mean() + scale = sqrt_dt_mean * sigma_t # shared per-step scalar + + # mean over all non-batch dimensions: (B, ...) -> (B,) + mean_diff_sq = (prev_sample_mean - old_prev_sample_mean).pow(2) + if mean_diff_sq.ndim > 1: + mean_diff_sq = mean_diff_sq.mean(dim=tuple(range(1, mean_diff_sq.ndim))) + ratio_mean_bias = mean_diff_sq / (2 * scale**2) + + log_ratio = log_prob - old_log_prob + ratio = torch.exp((log_ratio + ratio_mean_bias) * scale) + + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, + 1.0 - loss_cfg.clip_ratio, + 1.0 + loss_cfg.clip_ratio, + ) + pg_loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) / (sqrt_dt_mean**2) + + with torch.no_grad(): + ppo_kl = torch.mean(-log_ratio) + pg_clipfrac = torch.mean((torch.abs(ratio - 1.0) > loss_cfg.clip_ratio).float()) + pg_clipfrac_higher = torch.mean((ratio - 1.0 > loss_cfg.clip_ratio).float()) + pg_clipfrac_lower = torch.mean((1.0 - ratio > loss_cfg.clip_ratio).float()) + ratio_mean = ratio.mean() + ratio_std = ratio.std() + + pg_metrics = { + "actor/ppo_kl": ppo_kl.detach().item(), + "actor/pg_clipfrac": pg_clipfrac.detach().item(), + "actor/pg_clipfrac_higher": pg_clipfrac_higher.detach().item(), + "actor/pg_clipfrac_lower": pg_clipfrac_lower.detach().item(), + "actor/ratio_mean": ratio_mean.detach().item(), + "actor/ratio_std": ratio_std.detach().item(), } return pg_loss, pg_metrics diff --git a/verl_omni/trainer/diffusion/ray_diffusion_trainer.py b/verl_omni/trainer/diffusion/ray_diffusion_trainer.py index c51f4ff3..954e3a18 100644 --- a/verl_omni/trainer/diffusion/ray_diffusion_trainer.py +++ b/verl_omni/trainer/diffusion/ray_diffusion_trainer.py @@ -805,7 +805,11 @@ def _compute_old_log_prob(self, batch: DataProto): ) output = self.actor_rollout_wg.compute_log_prob(batch_td) log_probs = tu.get(output, "log_probs") - old_log_prob = tu.get_tensordict({"old_log_probs": log_probs.float()}) + old_log_prob_dict = {"old_log_probs": log_probs.float()} + prev_sample_mean = tu.get(output, "prev_sample_mean") + if prev_sample_mean is not None: + old_log_prob_dict["old_prev_sample_mean"] = prev_sample_mean.float() + old_log_prob = tu.get_tensordict(old_log_prob_dict) return DataProto.from_tensordict(old_log_prob) def _update_actor(self, batch: DataProto) -> DataProto: diff --git a/verl_omni/workers/config/diffusion/actor.py b/verl_omni/workers/config/diffusion/actor.py index 47a7d820..e43417e7 100644 --- a/verl_omni/workers/config/diffusion/actor.py +++ b/verl_omni/workers/config/diffusion/actor.py @@ -39,7 +39,7 @@ class DiffusionLossConfig(BaseConfig): def __post_init__(self): """Validate diffusion loss configuration.""" - valid_modes = ["flow_grpo"] + valid_modes = ["flow_grpo", "grpo_guard"] if self.loss_mode not in valid_modes: raise ValueError(f"Invalid diffusion loss_mode: {self.loss_mode}. Must be one of {valid_modes}") diff --git a/verl_omni/workers/engine/fsdp/diffusers_impl.py b/verl_omni/workers/engine/fsdp/diffusers_impl.py index 55ba8063..d55b6538 100644 --- a/verl_omni/workers/engine/fsdp/diffusers_impl.py +++ b/verl_omni/workers/engine/fsdp/diffusers_impl.py @@ -612,11 +612,12 @@ def prepare_model_inputs(self, micro_batch: TensorDict, step: int): ) def prepare_model_outputs(self, output, micro_batch: TensorDict): - log_prob, prev_sample_mean, std_dev_t = output + log_prob, prev_sample_mean, std_dev_t, sqrt_dt = output return { "log_probs": log_prob, "prev_sample_mean": prev_sample_mean, "std_dev_t": std_dev_t, + "sqrt_dt": sqrt_dt, } def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, step): @@ -653,6 +654,9 @@ def forward_step(self, micro_batch: TensorDict, loss_function, forward_only, ste if micro_batch.get("ref_prev_sample_mean", None) is not None: data["ref_prev_sample_mean"] = micro_batch["ref_prev_sample_mean"][:, step] + if micro_batch.get("old_prev_sample_mean", None) is not None: + data["old_prev_sample_mean"] = micro_batch["old_prev_sample_mean"][:, step] + loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group()) else: assert forward_only, "forward_only must be True when loss_function is None" diff --git a/verl_omni/workers/utils/losses.py b/verl_omni/workers/utils/losses.py index 6edbb5da..b96354ce 100644 --- a/verl_omni/workers/utils/losses.py +++ b/verl_omni/workers/utils/losses.py @@ -36,12 +36,22 @@ def diffusion_loss(config: DiffusionActorConfig, model_output, data: TensorDict, loss_mode = config.diffusion_loss.get("loss_mode", "flow_grpo") policy_loss_fn = get_diffusion_loss_fn(loss_mode) - pg_loss, pg_metrics = policy_loss_fn( + policy_loss_kwargs = dict( old_log_prob=old_log_prob, log_prob=log_prob, advantages=advantages, config=config, ) + if loss_mode == "grpo_guard": + # GRPO-Guard requires the rollout-time SDE proposal mean and the per-step + # diffusion coefficient terms; pass them through alongside the standard inputs. + policy_loss_kwargs.update( + old_prev_sample_mean=data["old_prev_sample_mean"], + prev_sample_mean=model_output["prev_sample_mean"], + std_dev_t=model_output["std_dev_t"], + sqrt_dt=model_output["sqrt_dt"], + ) + pg_loss, pg_metrics = policy_loss_fn(**policy_loss_kwargs) pg_metrics = Metric.from_dict(pg_metrics, aggregation=AggregationType.MEAN)