[rollout, omni] feat: enable batched B=N diffusion rollout in FlowGRPO#83
[rollout, omni] feat: enable batched B=N diffusion rollout in FlowGRPO#83SamitHuang wants to merge 1 commit into
Conversation
…O agent loop
Adds end-to-end support for running a single B=N transformer forward per
prompt in the vllm-omni diffusion rollout, instead of fanning out N
concurrent B=1 requests that AsyncOmniDiffusion's
ThreadPoolExecutor(max_workers=1) then serializes.
The feature is gated by a single config flag,
``actor_rollout_ref.rollout.enable_batched_diffusion`` (default: True).
When the flag is on and ``rollout.n > 1``, the diffusion agent loop
collapses each group of ``rollout.n`` interleaved rows into one
``generate_batched`` engine request with
``num_outputs_per_prompt = rollout.n``. When off, the legacy per-sample
fan-out path is used unchanged.
Server side
- ``vLLMOmniHttpServer.generate_batched`` (new Ray-callable): submits one
request with ``OmniDiffusionSamplingParams.num_outputs_per_prompt = N``
and unpacks the resulting batched ``OmniRequestOutput`` (PIL image list
+ batched custom-output tensors) into ``list[DiffusionOutput]``.
- ``vLLMOmniHttpServer.generate`` is now a thin wrapper around the new
shared ``_generate_engine_call`` helper with N=1 and returns a single
``DiffusionOutput`` exactly as before (fully backwards compatible).
Agent loop side
- ``DiffusionAgentLoopWorker.generate_sequences`` now checks
``_can_use_batched_path`` and dispatches to ``_run_agent_loops_batched``
when the flag is on, ``rollout.n > 1``, and the batch length divides
cleanly by N (the trainer always ``DataProto.repeat(... interleave=True)``s
before calling us, so adjacent rows share the same prompt).
- ``_run_agent_loop_batched`` re-uses the registered agent loop for
tokenization / vision-info extraction, issues one
``server.generate_batched.remote`` call, and post-processes each of the
N returned samples with the existing ``_agent_loop_postprocess`` so
reward computation / extra-field handling is unchanged.
- ``_server_generate_batched`` is a small helper that goes through
``LLMServerClient``'s ``_acquire_server`` / ``_release_server`` pair so
load balancing still works without modifying upstream verl.
Config / docs
- ``DiffusionRolloutConfig.enable_batched_diffusion: bool = True``
(mirrored in ``diffusion_rollout.yaml`` and regenerated in
``_generated_diffusion_trainer.yaml``).
- ``examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh`` carries an
inline header explaining the flag plus an explicit
``enable_batched_diffusion=True`` line. The two sibling scripts
(``..._lora_sp2.sh``, ``..._lora_async_reward.sh``) automatically pick
up the flag from the dataclass default.
Tests
- ``tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py``
covers:
- One batched call returns ``num_outputs_per_prompt`` valid per-sample
``DiffusionOutput`` records (shapes, pixel range, log-probs).
- ``generate_batched(num_outputs_per_prompt=1)`` produces the same
shape as the legacy ``generate``.
- ``num_outputs_per_prompt < 1`` raises ``ValueError`` before touching
the engine.
Pre-commit
- ``pre-commit run --files`` passes for every touched path. The
unrelated ``check-naming-conventions`` failure (its grep excludes
``venv`` but not ``.venv`` site-packages) is pre-existing and skipped
for this commit only.
AI assistance
- This change was prepared with AI assistance; the submitter has
reviewed every line and verified the new batched path via the
companion benchmark in PR #2 (vllm-omni batched 0.655 img/s vs
vllm-omni concurrent 0.403 img/s on Qwen-Image @ 512x512, n=16).
Co-authored-by: GitHub Copilot
Signed-off-by: samithuang <285365963@qq.com>
There was a problem hiding this comment.
Code Review
This pull request implements a batched diffusion rollout optimization for the vllm_omni server, aimed at increasing throughput for FlowGRPO-style training. The changes introduce a generate_batched method in the vLLM server that processes multiple samples in a single transformer forward pass and updates the DiffusionAgentLoop to group identical prompts into these batched requests. Additionally, configuration defaults were updated and new E2E tests were added. Reviewers identified a potential issue where prompt identity isn't fully verified before batching and noted a hardcoded num_turns value that should be derived dynamically to ensure robustness.
| for start in range(0, len(batch), rollout_n): | ||
| group = agent_names[start : start + rollout_n] | ||
| if not all(name == group[0] for name in group): | ||
| return False |
There was a problem hiding this comment.
The current implementation only verifies that agent_name is uniform within a group. To ensure correctness when collapsing multiple rows into a single batched request, you should also verify that the raw_prompt values are identical. If different prompts are interleaved in the batch, using the first prompt for the entire group would lead to incorrect generation results for the other samples.
raw_prompts = batch.non_tensor_batch.get("raw_prompt")
if raw_prompts is None:
return False
for start in range(0, len(batch), rollout_n):
group_names = agent_names[start : start + rollout_n]
if not all(name == group_names[0] for name in group_names):
return False
group_prompts = raw_prompts[start : start + rollout_n]
if not all(p == group_prompts[0] for p in group_prompts):
return False| prompt_ids=prompt_ids, | ||
| response_diffusion_output=diffusion_output.diffusion_output, | ||
| response_logprobs=diffusion_output.log_probs, | ||
| num_turns=2, |
There was a problem hiding this comment.
The value for num_turns is hardcoded to 2. While this is correct for standard single-turn interactions (User + Assistant), it makes an assumption that might not hold for all agent loops. Consider deriving this value from the agent_loop instance or adding a check to ensure the agent loop is indeed single-turn to avoid masking potential logic errors with a default value.
References
- Do not add default values at the call site to work around downstream errors. This can mask the underlying issue and lead to silent bugs. Instead, let the code fail fast to make the error explicit.
zhtmike
left a comment
There was a problem hiding this comment.
Let us make vllm-omni functional well with continuous batching instead of doing patch here.
The rollout engine should do its job.
Summary
End-to-end FlowGRPO integration for batched (B = N) diffusion rollout, so a single
vllm-omnirequest per prompt replaces the existingrollout.nconcurrent B = 1 requests thatAsyncOmniDiffusion'sThreadPoolExecutor(max_workers=1)would otherwise serialize. Gated by a single config flag, default on.This PR supersedes the earlier "expose API only" change in #2: same server-side feature, without the standalone benchmark script, plus the agent-loop wiring, a config knob, an example doc, and a focused test.
How to enable in end-to-end FlowGRPO training
The new config knob is:
True): whenrollout.n > 1, each prompt group ofrollout.ninterleaved rows is collapsed into onegenerate_batched(num_outputs_per_prompt = rollout.n)engine call, i.e. a single B =rollout.ntransformer forward.False): the agent loop falls back to the legacy per-sample fan-out (rollout.nconcurrent B = 1 requests).examples/flowgrpo_trainer/run_qwen_image_ocr_lora.shshows the flag inline; the two sibling LoRA scripts (..._lora_sp2.sh,..._lora_async_reward.sh) automatically pick up the dataclass default.Changes
verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.pygenerate_batched(num_outputs_per_prompt: int)returninglist[DiffusionOutput]._generate_engine_callhelper shared withgenerate._unpack_batched_resultslices the batchedOmniRequestOutput(PIL list + batched custom-output tensors: latents / log-probs / prompt-embeds / masks) into per-sampleDiffusionOutputobjects.generateis now a thin wrapper around_generate_engine_call(num_outputs_per_prompt = 1)→ fully backwards compatible.verl_omni/agent_loop/diffusion_agent_loop.pygenerate_sequencesdispatches to_run_agent_loops_batchedwhen_can_use_batched_pathis true._can_use_batched_pathactivates only when: flag is on,rollout.name == "vllm_omni",rollout_n > 1,len(batch) % rollout_n == 0, and every group ofrollout_nrows shares the sameagent_name._run_agent_loop_batchedre-uses the registered agent loop solely forprocess_vision_info/apply_chat_template, issues onegenerate_batchedcall per group, and post-processes each of the N returned samples through the existing_agent_loop_postprocessso reward / extra-field handling is unchanged._server_generate_batchedis a small module-level helper that goes throughLLMServerClient._acquire_server/_release_serverso load balancing keeps working without patching upstream verl.verl_omni/workers/config/diffusion/rollout.py,verl_omni/trainer/config/diffusion/rollout/diffusion_rollout.yaml,verl_omni/trainer/config/_generated_diffusion_trainer.yamlenable_batched_diffusion: bool = Truefield with inline docs. Generated yaml regenerated viascripts/generate_trainer_config.sh(pre-commit'sautogen-trainer-cfghook passes).examples/flowgrpo_trainer/run_qwen_image_ocr_lora.shenable_batched_diffusion=Trueline for visibility.tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py(new)num_outputs_per_promptvalidDiffusionOutputs (shapes, pixel range, log-probs).generate_batched(num_outputs_per_prompt = 1)matches the shape of the legacygenerate.num_outputs_per_prompt < 1raisesValueErrorbefore touching the engine.Why default-on is safe
_can_use_batched_pathis conservative: it only triggers when every precondition holds (flag, engine name, divisibility, agent-name uniformity). Any deviation falls back to the legacy path.DataProtoproduced by_postprocessis shape-identical to the legacy path; therepeat(interleave=True)row layout is preserved sample-by-sample, so reward computation, log-prob handling, and trainer code (ray_diffusion_trainer.pylines 931 / 946) require no changes.tests/agent_loop/test_diffusion_agent_loop.py.Compatibility
generatestill returnsDiffusionOutputand runs as B = 1.vllm_omni,verl, trainer, or worker shape contracts.Test plan
pre-commit run --files <all touched paths>— all hooks pass (ruff, ruff-format, mypy, license, docstring coverage, doc-time-info, device-api-usage, dataproto-usage, validate-structure, compileall, autogen-trainer-cfg). The pre-existingcheck-naming-conventionsfailure (its grep excludesvenvbut not.venvsite-packages installed at the repo root) is unrelated to this PR and was skipped for this commit only.python3 -c "import ast; ast.parse(open(p).read())"for each).pytest tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py -v -s— requires a GPU host with the tiny Qwen-Image fixture (~/models/tiny-random/Qwen-Image). Reviewer should run on their workstation before merging.examples/flowgrpo_trainer/run_qwen_image_ocr_lora.shwithenable_batched_diffusion=Trueto confirm trainer parity.Benchmarked perf (from PR #2)
Measured Qwen-Image, 512×512,
num_inference_steps=50,rollout.n=16,num_prompts=2→ 32 images, bf16, FA3 for vllm-omni / TORCH_SDPA for diffusers:vllm-omni in the new batched mode is +62% vs the legacy concurrent path and +24% vs
diffuserson this workload.Duplicate-work check
Per
AGENTS.md:gh api search/issues -f q="repo:verl-project/verl-omni is:pr is:open vllm-omni rollout batch"→ 4 hits ([diffusion, rollout, trainer] feat: add BAGEL FlowGRPO support #66 BAGEL FlowGRPO, [Bugfix] Enable step-wise execution #81 step-wise execution, [Algo] DPO (online) training with SD3.5-medium #77 SD3.5 DPO, [trainer, diffusion] feat: add z-image support for flowgrpo training #76 z-image FlowGRPO). None of them touch the diffusion server's request-batching API or the FlowGRPO agent-loop dispatch.gh pr list --repo SamitHuang/verl-omni --state open→ only PR [doc] chore: supply documentation for flowgrpo training #2 (the bench-script counterpart). This PR replaces [doc] chore: supply documentation for flowgrpo training #2 with the production-grade wiring + tests and no benchmark script, per reviewer request.AI assistance disclosure
This change was prepared with AI assistance. The submitter has reviewed every modified line, defends the design end-to-end, and verified the new batched path via the benchmark companion in PR #2.