[rollout, omni] feat: enable batched B=N diffusion rollout in FlowGRPO agent loop#3
[rollout, omni] feat: enable batched B=N diffusion rollout in FlowGRPO agent loop#3SamitHuang wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a batched diffusion rollout path for the vllm_omni engine, significantly improving throughput for FlowGRPO-style training by collapsing multiple requests for the same prompt into a single transformer forward pass. Key changes include the addition of generate_batched to the vLLMOmniHttpServer, logic in DiffusionAgentLoopWorker to identify and group identical prompts, and a new end-to-end test suite for batched generation. Feedback highlights a potential correctness issue where prompt identity isn't fully verified before batching, a performance bottleneck caused by sequential post-processing of results, and a maintenance risk regarding the use of private LLMServerClient methods. Additionally, the hardcoded num_turns value in the batched path may cause issues if multi-turn interactions are introduced.
| def _can_use_batched_path(self, batch: DataProto, rollout_n: int) -> bool: | ||
| """Decide whether ``generate_sequences`` can dispatch via ``generate_batched``. | ||
|
|
||
| The trainer expands each prompt with ``DataProto.repeat(repeat_times=n, | ||
| interleave=True)`` before calling us, so consecutive groups of | ||
| ``rollout_n`` rows share the same prompt (and same ``agent_name``). | ||
| We only collapse a group when (a) the rollout server is ``vllm_omni``, | ||
| (b) the flag is enabled, (c) ``rollout_n > 1`` and divides the batch | ||
| cleanly, and (d) every row in the group uses the same ``agent_name``. | ||
| """ | ||
| if not bool(getattr(self.rollout_config, "enable_batched_diffusion", False)): | ||
| return False | ||
| if rollout_n <= 1: | ||
| return False | ||
| if getattr(self.rollout_config, "name", None) != "vllm_omni": | ||
| return False | ||
| if len(batch) == 0 or len(batch) % rollout_n != 0: | ||
| return False | ||
| agent_names = batch.non_tensor_batch.get("agent_name") | ||
| if agent_names is None: | ||
| return False | ||
| for start in range(0, len(batch), rollout_n): | ||
| group = agent_names[start : start + rollout_n] | ||
| if not all(name == group[0] for name in group): | ||
| return False | ||
| return True |
There was a problem hiding this comment.
The current logic in _can_use_batched_path only verifies that the agent_name is uniform within each group of rollout_n rows. However, _run_agent_loop_batched (line 342) assumes that the entire group shares the same raw_prompt and raw_negative_prompt by only tokenizing the first row's inputs. If the input batch is not correctly interleaved by the trainer, this could lead to silent correctness issues where samples are generated using the wrong prompt. Please add a check to ensure that raw_prompt and raw_negative_prompt are also identical within each group before enabling the batched path.
| for diffusion_output, row_kwargs in zip(diffusion_outputs, row_kwargs_list, strict=True): | ||
| # Each sample gets its own metrics dict so per-row timings don't alias | ||
| # the same underlying dict and so num_preempted is per-sample. | ||
| sample_metrics = dict(metrics) | ||
| if sample_metrics.get("num_preempted") is None: | ||
| sample_metrics["num_preempted"] = ( | ||
| diffusion_output.num_preempted if diffusion_output.num_preempted is not None else -1 | ||
| ) | ||
| agent_output = DiffusionAgentLoopOutput( | ||
| prompt_ids=prompt_ids, | ||
| response_diffusion_output=diffusion_output.diffusion_output, | ||
| response_logprobs=diffusion_output.log_probs, | ||
| num_turns=2, | ||
| metrics=sample_metrics, | ||
| extra_fields=diffusion_output.extra_fields, | ||
| ) | ||
| results.append(await self._agent_loop_postprocess(agent_output, **row_kwargs)) |
There was a problem hiding this comment.
The post-processing of batched diffusion outputs is currently performed sequentially in a loop. This includes _agent_loop_postprocess, which may involve remote calls for reward computation (line 428). In the legacy non-batched path, these operations were parallelized via asyncio.gather. To avoid a performance bottleneck in the agent loop, consider using asyncio.gather to parallelize the post-processing of all samples in the group.
| for diffusion_output, row_kwargs in zip(diffusion_outputs, row_kwargs_list, strict=True): | |
| # Each sample gets its own metrics dict so per-row timings don't alias | |
| # the same underlying dict and so num_preempted is per-sample. | |
| sample_metrics = dict(metrics) | |
| if sample_metrics.get("num_preempted") is None: | |
| sample_metrics["num_preempted"] = ( | |
| diffusion_output.num_preempted if diffusion_output.num_preempted is not None else -1 | |
| ) | |
| agent_output = DiffusionAgentLoopOutput( | |
| prompt_ids=prompt_ids, | |
| response_diffusion_output=diffusion_output.diffusion_output, | |
| response_logprobs=diffusion_output.log_probs, | |
| num_turns=2, | |
| metrics=sample_metrics, | |
| extra_fields=diffusion_output.extra_fields, | |
| ) | |
| results.append(await self._agent_loop_postprocess(agent_output, **row_kwargs)) | |
| # Parallelize post-processing to match the concurrency of the legacy path | |
| postprocess_tasks = [] | |
| for diffusion_output, row_kwargs in zip(diffusion_outputs, row_kwargs_list, strict=True): | |
| sample_metrics = dict(metrics) | |
| if sample_metrics.get("num_preempted") is None: | |
| sample_metrics["num_preempted"] = ( | |
| diffusion_output.num_preempted if diffusion_output.num_preempted is not None else -1 | |
| ) | |
| agent_output = DiffusionAgentLoopOutput( | |
| prompt_ids=prompt_ids, | |
| response_diffusion_output=diffusion_output.diffusion_output, | |
| response_logprobs=diffusion_output.log_probs, | |
| num_turns=2, | |
| metrics=sample_metrics, | |
| extra_fields=diffusion_output.extra_fields, | |
| ) | |
| postprocess_tasks.append(self._agent_loop_postprocess(agent_output, **row_kwargs)) | |
| return await asyncio.gather(*postprocess_tasks) |
| prompt_ids=prompt_ids, | ||
| response_diffusion_output=diffusion_output.diffusion_output, | ||
| response_logprobs=diffusion_output.log_probs, | ||
| num_turns=2, |
There was a problem hiding this comment.
The num_turns field is hardcoded to 2 here. While this is correct for single-turn agent loops (system + user), it bypasses the logic inside agent_loop.run which normally determines this value. If a multi-turn agent loop is used with batched diffusion, this will produce incorrect metadata. Consider deriving this value from the agent_loop instance or documenting that batched diffusion currently only supports single-turn interactions.
| async def _server_generate_batched( | ||
| server_manager: LLMServerClient, | ||
| *, | ||
| request_id: str, | ||
| prompt_ids: list[int], | ||
| sampling_params: dict[str, Any], | ||
| num_outputs_per_prompt: int, | ||
| image_data: Optional[list[Any]] = None, | ||
| video_data: Optional[list[Any]] = None, | ||
| negative_prompt_ids: Optional[list[int]] = None, | ||
| ) -> list[DiffusionOutput]: | ||
| """Invoke ``vLLMOmniHttpServer.generate_batched`` via the LLM server load balancer. | ||
|
|
||
| ``LLMServerClient`` in upstream ``verl`` only exposes ``generate``, so this | ||
| helper acquires a server handle via the existing load-balancer pair and | ||
| calls the new ``generate_batched`` method directly on the Ray actor. The | ||
| sticky-session behavior of ``LLMServerClient.generate`` is intentionally | ||
| bypassed here: each diffusion request is independent and benefits more | ||
| from least-loaded routing than prefix-cache stickiness. | ||
| """ | ||
| server_id, server = await server_manager._acquire_server(request_id) | ||
| try: | ||
| return await server.generate_batched.remote( | ||
| request_id=request_id, | ||
| prompt_ids=prompt_ids, | ||
| sampling_params=sampling_params, | ||
| num_outputs_per_prompt=num_outputs_per_prompt, | ||
| image_data=image_data, | ||
| video_data=video_data, | ||
| negative_prompt_ids=negative_prompt_ids, | ||
| ) | ||
| finally: | ||
| server_manager._release_server(server_id) |
There was a problem hiding this comment.
The _server_generate_batched helper relies on private methods of LLMServerClient (_acquire_server and _release_server). While necessary because the upstream verl client does not yet expose a batched API, this creates a maintenance risk if the internal implementation of LLMServerClient changes. It would be safer to eventually upstream the generate_batched capability to verl's LLMServerClient.
9983d0b to
abc5a32
Compare
…O agent loop
Adds end-to-end support for running a single B=N transformer forward per
prompt in the vllm-omni diffusion rollout, instead of fanning out N
concurrent B=1 requests that AsyncOmniDiffusion's
ThreadPoolExecutor(max_workers=1) then serializes.
The feature is gated by a single config flag,
``actor_rollout_ref.rollout.enable_batched_diffusion`` (default: True).
When the flag is on and ``rollout.n > 1``, the diffusion agent loop
collapses each group of ``rollout.n`` interleaved rows into one
``generate_batched`` engine request with
``num_outputs_per_prompt = rollout.n``. When off, the legacy per-sample
fan-out path is used unchanged.
Server side
- ``vLLMOmniHttpServer.generate_batched`` (new Ray-callable): submits one
request with ``OmniDiffusionSamplingParams.num_outputs_per_prompt = N``
and unpacks the resulting batched ``OmniRequestOutput`` (PIL image list
+ batched custom-output tensors) into ``list[DiffusionOutput]``.
- ``vLLMOmniHttpServer.generate`` is now a thin wrapper around the new
shared ``_generate_engine_call`` helper with N=1 and returns a single
``DiffusionOutput`` exactly as before (fully backwards compatible).
Agent loop side
- ``DiffusionAgentLoopWorker.generate_sequences`` now checks
``_can_use_batched_path`` and dispatches to ``_run_agent_loops_batched``
when the flag is on, ``rollout.n > 1``, and the batch length divides
cleanly by N (the trainer always ``DataProto.repeat(... interleave=True)``s
before calling us, so adjacent rows share the same prompt).
- ``_run_agent_loop_batched`` re-uses the registered agent loop for
tokenization / vision-info extraction, issues one
``server.generate_batched.remote`` call, and post-processes each of the
N returned samples with the existing ``_agent_loop_postprocess`` so
reward computation / extra-field handling is unchanged.
- ``_server_generate_batched`` is a small helper that goes through
``LLMServerClient``'s ``_acquire_server`` / ``_release_server`` pair so
load balancing still works without modifying upstream verl.
Config / docs
- ``DiffusionRolloutConfig.enable_batched_diffusion: bool = True``
(mirrored in ``diffusion_rollout.yaml`` and regenerated in
``_generated_diffusion_trainer.yaml``).
- ``examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh`` carries an
inline header explaining the flag plus an explicit
``enable_batched_diffusion=True`` line. The two sibling scripts
(``..._lora_sp2.sh``, ``..._lora_async_reward.sh``) automatically pick
up the flag from the dataclass default.
Tests
- ``tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py``
covers:
- One batched call returns ``num_outputs_per_prompt`` valid per-sample
``DiffusionOutput`` records (shapes, pixel range, log-probs).
- ``generate_batched(num_outputs_per_prompt=1)`` produces the same
shape as the legacy ``generate``.
- ``num_outputs_per_prompt < 1`` raises ``ValueError`` before touching
the engine.
Pre-commit
- ``pre-commit run --files`` passes for every touched path. The
unrelated ``check-naming-conventions`` failure (its grep excludes
``venv`` but not ``.venv`` site-packages) is pre-existing and skipped
for this commit only.
AI assistance
- This change was prepared with AI assistance; the submitter has
reviewed every line and verified the new batched path via the
companion benchmark in PR #2 (vllm-omni batched 0.655 img/s vs
vllm-omni concurrent 0.403 img/s on Qwen-Image @ 512x512, n=16).
Co-authored-by: GitHub Copilot
Signed-off-by: samithuang <285365963@qq.com>
abc5a32 to
3a739f7
Compare
Summary
End-to-end FlowGRPO integration for batched (B = N) diffusion rollout, so a single
vllm-omnirequest per prompt replaces the existingrollout.nconcurrent B = 1 requests thatAsyncOmniDiffusion'sThreadPoolExecutor(max_workers=1)would otherwise serialize. Gated by a single config flag, default on.This PR supersedes the earlier "expose API only" change in #2: same server-side feature, without the standalone benchmark script, plus the agent-loop wiring, a config knob, an example doc, and a focused test.
How to enable in end-to-end FlowGRPO training
The new config knob is:
True): whenrollout.n > 1, each prompt group ofrollout.ninterleaved rows is collapsed into onegenerate_batched(num_outputs_per_prompt = rollout.n)engine call, i.e. a single B =rollout.ntransformer forward.False): the agent loop falls back to the legacy per-sample fan-out (rollout.nconcurrent B = 1 requests).examples/flowgrpo_trainer/run_qwen_image_ocr_lora.shshows the flag inline; the two sibling LoRA scripts (..._lora_sp2.sh,..._lora_async_reward.sh) automatically pick up the dataclass default.Changes
verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.pygenerate_batched(num_outputs_per_prompt: int)returninglist[DiffusionOutput]._generate_engine_callhelper shared withgenerate._unpack_batched_resultslices the batchedOmniRequestOutput(PIL list + batched custom-output tensors: latents / log-probs / prompt-embeds / masks) into per-sampleDiffusionOutputobjects.generateis now a thin wrapper around_generate_engine_call(num_outputs_per_prompt = 1)→ fully backwards compatible.verl_omni/agent_loop/diffusion_agent_loop.pygenerate_sequencesdispatches to_run_agent_loops_batchedwhen_can_use_batched_pathis true._can_use_batched_pathactivates only when: flag is on,rollout.name == "vllm_omni",rollout_n > 1,len(batch) % rollout_n == 0, and every group ofrollout_nrows shares the sameagent_name._run_agent_loop_batchedre-uses the registered agent loop solely forprocess_vision_info/apply_chat_template, issues onegenerate_batchedcall per group, and post-processes each of the N returned samples through the existing_agent_loop_postprocessso reward / extra-field handling is unchanged._server_generate_batchedis a small module-level helper that goes throughLLMServerClient._acquire_server/_release_serverso load balancing keeps working without patching upstream verl.verl_omni/workers/config/diffusion/rollout.py,verl_omni/trainer/config/diffusion/rollout/diffusion_rollout.yaml,verl_omni/trainer/config/_generated_diffusion_trainer.yamlenable_batched_diffusion: bool = Truefield with inline docs. Generated yaml regenerated viascripts/generate_trainer_config.sh(pre-commit'sautogen-trainer-cfghook passes).examples/flowgrpo_trainer/run_qwen_image_ocr_lora.shenable_batched_diffusion=Trueline for visibility.tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py(new)num_outputs_per_promptvalidDiffusionOutputs (shapes, pixel range, log-probs).generate_batched(num_outputs_per_prompt = 1)matches the shape of the legacygenerate.num_outputs_per_prompt < 1raisesValueErrorbefore touching the engine.Why default-on is safe
_can_use_batched_pathis conservative: it only triggers when every precondition holds (flag, engine name, divisibility, agent-name uniformity). Any deviation falls back to the legacy path.DataProtoproduced by_postprocessis shape-identical to the legacy path; therepeat(interleave=True)row layout is preserved sample-by-sample, so reward computation, log-prob handling, and trainer code (ray_diffusion_trainer.pylines 931 / 946) require no changes.tests/agent_loop/test_diffusion_agent_loop.py.Compatibility
generatestill returnsDiffusionOutputand runs as B = 1.vllm_omni,verl, trainer, or worker shape contracts.Test plan
pre-commit run --files <all touched paths>— all hooks pass (ruff, ruff-format, mypy, license, docstring coverage, doc-time-info, device-api-usage, dataproto-usage, validate-structure, compileall, autogen-trainer-cfg). The pre-existingcheck-naming-conventionsfailure (its grep excludesvenvbut not.venvsite-packages installed at the repo root) is unrelated to this PR and was skipped for this commit only.python3 -c "import ast; ast.parse(open(p).read())"for each).pytest tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py -v -s— requires a GPU host with the tiny Qwen-Image fixture (~/models/tiny-random/Qwen-Image). Reviewer should run on their workstation before merging.examples/flowgrpo_trainer/run_qwen_image_ocr_lora.shwithenable_batched_diffusion=Trueto confirm trainer parity.Benchmarked perf (from PR #2)
Measured Qwen-Image, 512×512,
num_inference_steps=50,rollout.n=16,num_prompts=2→ 32 images, bf16, FA3 for vllm-omni / TORCH_SDPA for diffusers:vllm-omni in the new batched mode is +62% vs the legacy concurrent path and +24% vs
diffuserson this workload.Duplicate-work check
Per
AGENTS.md:gh api search/issues -f q="repo:verl-project/verl-omni is:pr is:open vllm-omni rollout batch"→ 4 hits ([diffusion, rollout, trainer] feat: add BAGEL FlowGRPO support verl-project/verl-omni#66 BAGEL FlowGRPO, [Bugfix] Enable step-wise execution verl-project/verl-omni#81 step-wise execution, [Algo] DPO (online) training with SD3.5-medium verl-project/verl-omni#77 SD3.5 DPO, [trainer, diffusion] feat: add z-image support for flowgrpo training verl-project/verl-omni#76 z-image FlowGRPO). None of them touch the diffusion server's request-batching API or the FlowGRPO agent-loop dispatch.gh pr list --repo SamitHuang/verl-omni --state open→ only PR [rollout, vllm_omni] feat: add batched diffusion generate API (one B=N forward instead of N B=1) #2 (the bench-script counterpart). This PR replaces [rollout, vllm_omni] feat: add batched diffusion generate API (one B=N forward instead of N B=1) #2 with the production-grade wiring + tests and no benchmark script, per reviewer request.AI assistance disclosure
This change was prepared with AI assistance. The submitter has reviewed every modified line, defends the design end-to-end, and verified the new batched path via the benchmark companion in PR #2.