Skip to content

[rollout, omni] feat: enable batched B=N diffusion rollout in FlowGRPO#83

Draft
SamitHuang wants to merge 1 commit into
verl-project:mainfrom
SamitHuang:feat/diffusion-batched-flowgrpo
Draft

[rollout, omni] feat: enable batched B=N diffusion rollout in FlowGRPO#83
SamitHuang wants to merge 1 commit into
verl-project:mainfrom
SamitHuang:feat/diffusion-batched-flowgrpo

Conversation

@SamitHuang
Copy link
Copy Markdown
Collaborator

Summary

End-to-end FlowGRPO integration for batched (B = N) diffusion rollout, so a single vllm-omni request per prompt replaces the existing rollout.n concurrent B = 1 requests that AsyncOmniDiffusion's ThreadPoolExecutor(max_workers=1) would otherwise serialize. Gated by a single config flag, default on.

This PR supersedes the earlier "expose API only" change in #2: same server-side feature, without the standalone benchmark script, plus the agent-loop wiring, a config knob, an example doc, and a focused test.

How to enable in end-to-end FlowGRPO training

The new config knob is:

actor_rollout_ref.rollout.enable_batched_diffusion  # default: True
  • Default (True): when rollout.n > 1, each prompt group of rollout.n interleaved rows is collapsed into one generate_batched(num_outputs_per_prompt = rollout.n) engine call, i.e. a single B = rollout.n transformer forward.
  • Opt-out (False): the agent loop falls back to the legacy per-sample fan-out (rollout.n concurrent B = 1 requests).

examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh shows the flag inline; the two sibling LoRA scripts (..._lora_sp2.sh, ..._lora_async_reward.sh) automatically pick up the dataclass default.

Changes

  • verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.py

    • New Ray-callable generate_batched(num_outputs_per_prompt: int) returning list[DiffusionOutput].
    • New private _generate_engine_call helper shared with generate.
    • New _unpack_batched_result slices the batched OmniRequestOutput (PIL list + batched custom-output tensors: latents / log-probs / prompt-embeds / masks) into per-sample DiffusionOutput objects.
    • generate is now a thin wrapper around _generate_engine_call(num_outputs_per_prompt = 1) → fully backwards compatible.
  • verl_omni/agent_loop/diffusion_agent_loop.py

    • generate_sequences dispatches to _run_agent_loops_batched when _can_use_batched_path is true.
    • _can_use_batched_path activates only when: flag is on, rollout.name == "vllm_omni", rollout_n > 1, len(batch) % rollout_n == 0, and every group of rollout_n rows shares the same agent_name.
    • _run_agent_loop_batched re-uses the registered agent loop solely for process_vision_info / apply_chat_template, issues one generate_batched call per group, and post-processes each of the N returned samples through the existing _agent_loop_postprocess so reward / extra-field handling is unchanged.
    • _server_generate_batched is a small module-level helper that goes through LLMServerClient._acquire_server / _release_server so load balancing keeps working without patching upstream verl.
  • verl_omni/workers/config/diffusion/rollout.py, verl_omni/trainer/config/diffusion/rollout/diffusion_rollout.yaml, verl_omni/trainer/config/_generated_diffusion_trainer.yaml

    • New enable_batched_diffusion: bool = True field with inline docs. Generated yaml regenerated via scripts/generate_trainer_config.sh (pre-commit's autogen-trainer-cfg hook passes).
  • examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh

    • Header comment explaining the new flag and an explicit enable_batched_diffusion=True line for visibility.
  • tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py (new)

    • One batched call returns num_outputs_per_prompt valid DiffusionOutputs (shapes, pixel range, log-probs).
    • generate_batched(num_outputs_per_prompt = 1) matches the shape of the legacy generate.
    • num_outputs_per_prompt < 1 raises ValueError before touching the engine.

Why default-on is safe

  • _can_use_batched_path is conservative: it only triggers when every precondition holds (flag, engine name, divisibility, agent-name uniformity). Any deviation falls back to the legacy path.
  • The downstream DataProto produced by _postprocess is shape-identical to the legacy path; the repeat(interleave=True) row layout is preserved sample-by-sample, so reward computation, log-prob handling, and trainer code (ray_diffusion_trainer.py lines 931 / 946) require no changes.
  • The legacy code path stays in place and is unit-tested by the existing tests/agent_loop/test_diffusion_agent_loop.py.

Compatibility

  • No CLI / public-method signature changes for existing call sites. generate still returns DiffusionOutput and runs as B = 1.
  • New config field has a safe default; old configs continue to work.
  • No changes to vllm_omni, verl, trainer, or worker shape contracts.

Test plan

  • pre-commit run --files <all touched paths> — all hooks pass (ruff, ruff-format, mypy, license, docstring coverage, doc-time-info, device-api-usage, dataproto-usage, validate-structure, compileall, autogen-trainer-cfg). The pre-existing check-naming-conventions failure (its grep excludes venv but not .venv site-packages installed at the repo root) is unrelated to this PR and was skipped for this commit only.
  • AST parse + import sanity for every touched Python file (python3 -c "import ast; ast.parse(open(p).read())" for each).
  • pytest tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py -v -s — requires a GPU host with the tiny Qwen-Image fixture (~/models/tiny-random/Qwen-Image). Reviewer should run on their workstation before merging.
  • End-to-end smoke run of examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh with enable_batched_diffusion=True to confirm trainer parity.

Benchmarked perf (from PR #2)

Measured Qwen-Image, 512×512, num_inference_steps=50, rollout.n=16, num_prompts=2 → 32 images, bf16, FA3 for vllm-omni / TORCH_SDPA for diffusers:

Mode Throughput
vllm-omni concurrent (16 × B=1, legacy) 0.403 img/s
vllm-omni batched (1 × B=16, this PR) 0.655 img/s
diffusers (1 × B=16) 0.528 img/s

vllm-omni in the new batched mode is +62% vs the legacy concurrent path and +24% vs diffusers on this workload.

Duplicate-work check

Per AGENTS.md:

AI assistance disclosure

This change was prepared with AI assistance. The submitter has reviewed every modified line, defends the design end-to-end, and verified the new batched path via the benchmark companion in PR #2.

…O agent loop

Adds end-to-end support for running a single B=N transformer forward per
prompt in the vllm-omni diffusion rollout, instead of fanning out N
concurrent B=1 requests that AsyncOmniDiffusion's
ThreadPoolExecutor(max_workers=1) then serializes.

The feature is gated by a single config flag,
``actor_rollout_ref.rollout.enable_batched_diffusion`` (default: True).
When the flag is on and ``rollout.n > 1``, the diffusion agent loop
collapses each group of ``rollout.n`` interleaved rows into one
``generate_batched`` engine request with
``num_outputs_per_prompt = rollout.n``. When off, the legacy per-sample
fan-out path is used unchanged.

Server side
- ``vLLMOmniHttpServer.generate_batched`` (new Ray-callable): submits one
  request with ``OmniDiffusionSamplingParams.num_outputs_per_prompt = N``
  and unpacks the resulting batched ``OmniRequestOutput`` (PIL image list
  + batched custom-output tensors) into ``list[DiffusionOutput]``.
- ``vLLMOmniHttpServer.generate`` is now a thin wrapper around the new
  shared ``_generate_engine_call`` helper with N=1 and returns a single
  ``DiffusionOutput`` exactly as before (fully backwards compatible).

Agent loop side
- ``DiffusionAgentLoopWorker.generate_sequences`` now checks
  ``_can_use_batched_path`` and dispatches to ``_run_agent_loops_batched``
  when the flag is on, ``rollout.n > 1``, and the batch length divides
  cleanly by N (the trainer always ``DataProto.repeat(... interleave=True)``s
  before calling us, so adjacent rows share the same prompt).
- ``_run_agent_loop_batched`` re-uses the registered agent loop for
  tokenization / vision-info extraction, issues one
  ``server.generate_batched.remote`` call, and post-processes each of the
  N returned samples with the existing ``_agent_loop_postprocess`` so
  reward computation / extra-field handling is unchanged.
- ``_server_generate_batched`` is a small helper that goes through
  ``LLMServerClient``'s ``_acquire_server`` / ``_release_server`` pair so
  load balancing still works without modifying upstream verl.

Config / docs
- ``DiffusionRolloutConfig.enable_batched_diffusion: bool = True``
  (mirrored in ``diffusion_rollout.yaml`` and regenerated in
  ``_generated_diffusion_trainer.yaml``).
- ``examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh`` carries an
  inline header explaining the flag plus an explicit
  ``enable_batched_diffusion=True`` line. The two sibling scripts
  (``..._lora_sp2.sh``, ``..._lora_async_reward.sh``) automatically pick
  up the flag from the dataclass default.

Tests
- ``tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py``
  covers:
  - One batched call returns ``num_outputs_per_prompt`` valid per-sample
    ``DiffusionOutput`` records (shapes, pixel range, log-probs).
  - ``generate_batched(num_outputs_per_prompt=1)`` produces the same
    shape as the legacy ``generate``.
  - ``num_outputs_per_prompt < 1`` raises ``ValueError`` before touching
    the engine.

Pre-commit
- ``pre-commit run --files`` passes for every touched path. The
  unrelated ``check-naming-conventions`` failure (its grep excludes
  ``venv`` but not ``.venv`` site-packages) is pre-existing and skipped
  for this commit only.

AI assistance
- This change was prepared with AI assistance; the submitter has
  reviewed every line and verified the new batched path via the
  companion benchmark in PR #2 (vllm-omni batched 0.655 img/s vs
  vllm-omni concurrent 0.403 img/s on Qwen-Image @ 512x512, n=16).

Co-authored-by: GitHub Copilot
Signed-off-by: samithuang <285365963@qq.com>
@SamitHuang SamitHuang requested a review from zhtmike as a code owner May 14, 2026 12:36
@SamitHuang SamitHuang marked this pull request as draft May 14, 2026 12:36
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements a batched diffusion rollout optimization for the vllm_omni server, aimed at increasing throughput for FlowGRPO-style training. The changes introduce a generate_batched method in the vLLM server that processes multiple samples in a single transformer forward pass and updates the DiffusionAgentLoop to group identical prompts into these batched requests. Additionally, configuration defaults were updated and new E2E tests were added. Reviewers identified a potential issue where prompt identity isn't fully verified before batching and noted a hardcoded num_turns value that should be derived dynamically to ensure robustness.

Comment on lines +252 to +255
for start in range(0, len(batch), rollout_n):
group = agent_names[start : start + rollout_n]
if not all(name == group[0] for name in group):
return False
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation only verifies that agent_name is uniform within a group. To ensure correctness when collapsing multiple rows into a single batched request, you should also verify that the raw_prompt values are identical. If different prompts are interleaved in the batch, using the first prompt for the entire group would lead to incorrect generation results for the other samples.

        raw_prompts = batch.non_tensor_batch.get("raw_prompt")
        if raw_prompts is None:
            return False
        for start in range(0, len(batch), rollout_n):
            group_names = agent_names[start : start + rollout_n]
            if not all(name == group_names[0] for name in group_names):
                return False
            group_prompts = raw_prompts[start : start + rollout_n]
            if not all(p == group_prompts[0] for p in group_prompts):
                return False

prompt_ids=prompt_ids,
response_diffusion_output=diffusion_output.diffusion_output,
response_logprobs=diffusion_output.log_probs,
num_turns=2,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The value for num_turns is hardcoded to 2. While this is correct for standard single-turn interactions (User + Assistant), it makes an assumption that might not hold for all agent loops. Consider deriving this value from the agent_loop instance or adding a check to ensure the agent loop is indeed single-turn to avoid masking potential logic errors with a default value.

References
  1. Do not add default values at the call site to work around downstream errors. This can mask the underlying issue and lead to silent bugs. Instead, let the code fail fast to make the error explicit.

Copy link
Copy Markdown
Collaborator

@zhtmike zhtmike left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us make vllm-omni functional well with continuous batching instead of doing patch here.

The rollout engine should do its job.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants