Skip to content

[rollout, omni] feat: enable batched B=N diffusion rollout in FlowGRPO agent loop#3

Open
SamitHuang wants to merge 1 commit into
mainfrom
feat/diffusion-batched-flowgrpo
Open

[rollout, omni] feat: enable batched B=N diffusion rollout in FlowGRPO agent loop#3
SamitHuang wants to merge 1 commit into
mainfrom
feat/diffusion-batched-flowgrpo

Conversation

@SamitHuang
Copy link
Copy Markdown
Owner

Summary

End-to-end FlowGRPO integration for batched (B = N) diffusion rollout, so a single vllm-omni request per prompt replaces the existing rollout.n concurrent B = 1 requests that AsyncOmniDiffusion's ThreadPoolExecutor(max_workers=1) would otherwise serialize. Gated by a single config flag, default on.

This PR supersedes the earlier "expose API only" change in #2: same server-side feature, without the standalone benchmark script, plus the agent-loop wiring, a config knob, an example doc, and a focused test.

How to enable in end-to-end FlowGRPO training

The new config knob is:

actor_rollout_ref.rollout.enable_batched_diffusion  # default: True
  • Default (True): when rollout.n > 1, each prompt group of rollout.n interleaved rows is collapsed into one generate_batched(num_outputs_per_prompt = rollout.n) engine call, i.e. a single B = rollout.n transformer forward.
  • Opt-out (False): the agent loop falls back to the legacy per-sample fan-out (rollout.n concurrent B = 1 requests).

examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh shows the flag inline; the two sibling LoRA scripts (..._lora_sp2.sh, ..._lora_async_reward.sh) automatically pick up the dataclass default.

Changes

  • verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.py

    • New Ray-callable generate_batched(num_outputs_per_prompt: int) returning list[DiffusionOutput].
    • New private _generate_engine_call helper shared with generate.
    • New _unpack_batched_result slices the batched OmniRequestOutput (PIL list + batched custom-output tensors: latents / log-probs / prompt-embeds / masks) into per-sample DiffusionOutput objects.
    • generate is now a thin wrapper around _generate_engine_call(num_outputs_per_prompt = 1) → fully backwards compatible.
  • verl_omni/agent_loop/diffusion_agent_loop.py

    • generate_sequences dispatches to _run_agent_loops_batched when _can_use_batched_path is true.
    • _can_use_batched_path activates only when: flag is on, rollout.name == "vllm_omni", rollout_n > 1, len(batch) % rollout_n == 0, and every group of rollout_n rows shares the same agent_name.
    • _run_agent_loop_batched re-uses the registered agent loop solely for process_vision_info / apply_chat_template, issues one generate_batched call per group, and post-processes each of the N returned samples through the existing _agent_loop_postprocess so reward / extra-field handling is unchanged.
    • _server_generate_batched is a small module-level helper that goes through LLMServerClient._acquire_server / _release_server so load balancing keeps working without patching upstream verl.
  • verl_omni/workers/config/diffusion/rollout.py, verl_omni/trainer/config/diffusion/rollout/diffusion_rollout.yaml, verl_omni/trainer/config/_generated_diffusion_trainer.yaml

    • New enable_batched_diffusion: bool = True field with inline docs. Generated yaml regenerated via scripts/generate_trainer_config.sh (pre-commit's autogen-trainer-cfg hook passes).
  • examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh

    • Header comment explaining the new flag and an explicit enable_batched_diffusion=True line for visibility.
  • tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py (new)

    • One batched call returns num_outputs_per_prompt valid DiffusionOutputs (shapes, pixel range, log-probs).
    • generate_batched(num_outputs_per_prompt = 1) matches the shape of the legacy generate.
    • num_outputs_per_prompt < 1 raises ValueError before touching the engine.

Why default-on is safe

  • _can_use_batched_path is conservative: it only triggers when every precondition holds (flag, engine name, divisibility, agent-name uniformity). Any deviation falls back to the legacy path.
  • The downstream DataProto produced by _postprocess is shape-identical to the legacy path; the repeat(interleave=True) row layout is preserved sample-by-sample, so reward computation, log-prob handling, and trainer code (ray_diffusion_trainer.py lines 931 / 946) require no changes.
  • The legacy code path stays in place and is unit-tested by the existing tests/agent_loop/test_diffusion_agent_loop.py.

Compatibility

  • No CLI / public-method signature changes for existing call sites. generate still returns DiffusionOutput and runs as B = 1.
  • New config field has a safe default; old configs continue to work.
  • No changes to vllm_omni, verl, trainer, or worker shape contracts.

Test plan

  • pre-commit run --files <all touched paths> — all hooks pass (ruff, ruff-format, mypy, license, docstring coverage, doc-time-info, device-api-usage, dataproto-usage, validate-structure, compileall, autogen-trainer-cfg). The pre-existing check-naming-conventions failure (its grep excludes venv but not .venv site-packages installed at the repo root) is unrelated to this PR and was skipped for this commit only.
  • AST parse + import sanity for every touched Python file (python3 -c "import ast; ast.parse(open(p).read())" for each).
  • pytest tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py -v -s — requires a GPU host with the tiny Qwen-Image fixture (~/models/tiny-random/Qwen-Image). Reviewer should run on their workstation before merging.
  • End-to-end smoke run of examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh with enable_batched_diffusion=True to confirm trainer parity.

Benchmarked perf (from PR #2)

Measured Qwen-Image, 512×512, num_inference_steps=50, rollout.n=16, num_prompts=2 → 32 images, bf16, FA3 for vllm-omni / TORCH_SDPA for diffusers:

Mode Throughput
vllm-omni concurrent (16 × B=1, legacy) 0.403 img/s
vllm-omni batched (1 × B=16, this PR) 0.655 img/s
diffusers (1 × B=16) 0.528 img/s

vllm-omni in the new batched mode is +62% vs the legacy concurrent path and +24% vs diffusers on this workload.

Duplicate-work check

Per AGENTS.md:

AI assistance disclosure

This change was prepared with AI assistance. The submitter has reviewed every modified line, defends the design end-to-end, and verified the new batched path via the benchmark companion in PR #2.

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a batched diffusion rollout path for the vllm_omni engine, significantly improving throughput for FlowGRPO-style training by collapsing multiple requests for the same prompt into a single transformer forward pass. Key changes include the addition of generate_batched to the vLLMOmniHttpServer, logic in DiffusionAgentLoopWorker to identify and group identical prompts, and a new end-to-end test suite for batched generation. Feedback highlights a potential correctness issue where prompt identity isn't fully verified before batching, a performance bottleneck caused by sequential post-processing of results, and a maintenance risk regarding the use of private LLMServerClient methods. Additionally, the hardcoded num_turns value in the batched path may cause issues if multi-turn interactions are introduced.

Comment on lines +231 to +256
def _can_use_batched_path(self, batch: DataProto, rollout_n: int) -> bool:
"""Decide whether ``generate_sequences`` can dispatch via ``generate_batched``.

The trainer expands each prompt with ``DataProto.repeat(repeat_times=n,
interleave=True)`` before calling us, so consecutive groups of
``rollout_n`` rows share the same prompt (and same ``agent_name``).
We only collapse a group when (a) the rollout server is ``vllm_omni``,
(b) the flag is enabled, (c) ``rollout_n > 1`` and divides the batch
cleanly, and (d) every row in the group uses the same ``agent_name``.
"""
if not bool(getattr(self.rollout_config, "enable_batched_diffusion", False)):
return False
if rollout_n <= 1:
return False
if getattr(self.rollout_config, "name", None) != "vllm_omni":
return False
if len(batch) == 0 or len(batch) % rollout_n != 0:
return False
agent_names = batch.non_tensor_batch.get("agent_name")
if agent_names is None:
return False
for start in range(0, len(batch), rollout_n):
group = agent_names[start : start + rollout_n]
if not all(name == group[0] for name in group):
return False
return True
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic in _can_use_batched_path only verifies that the agent_name is uniform within each group of rollout_n rows. However, _run_agent_loop_batched (line 342) assumes that the entire group shares the same raw_prompt and raw_negative_prompt by only tokenizing the first row's inputs. If the input batch is not correctly interleaved by the trainer, this could lead to silent correctness issues where samples are generated using the wrong prompt. Please add a check to ensure that raw_prompt and raw_negative_prompt are also identical within each group before enabling the batched path.

Comment on lines +373 to +389
for diffusion_output, row_kwargs in zip(diffusion_outputs, row_kwargs_list, strict=True):
# Each sample gets its own metrics dict so per-row timings don't alias
# the same underlying dict and so num_preempted is per-sample.
sample_metrics = dict(metrics)
if sample_metrics.get("num_preempted") is None:
sample_metrics["num_preempted"] = (
diffusion_output.num_preempted if diffusion_output.num_preempted is not None else -1
)
agent_output = DiffusionAgentLoopOutput(
prompt_ids=prompt_ids,
response_diffusion_output=diffusion_output.diffusion_output,
response_logprobs=diffusion_output.log_probs,
num_turns=2,
metrics=sample_metrics,
extra_fields=diffusion_output.extra_fields,
)
results.append(await self._agent_loop_postprocess(agent_output, **row_kwargs))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The post-processing of batched diffusion outputs is currently performed sequentially in a loop. This includes _agent_loop_postprocess, which may involve remote calls for reward computation (line 428). In the legacy non-batched path, these operations were parallelized via asyncio.gather. To avoid a performance bottleneck in the agent loop, consider using asyncio.gather to parallelize the post-processing of all samples in the group.

Suggested change
for diffusion_output, row_kwargs in zip(diffusion_outputs, row_kwargs_list, strict=True):
# Each sample gets its own metrics dict so per-row timings don't alias
# the same underlying dict and so num_preempted is per-sample.
sample_metrics = dict(metrics)
if sample_metrics.get("num_preempted") is None:
sample_metrics["num_preempted"] = (
diffusion_output.num_preempted if diffusion_output.num_preempted is not None else -1
)
agent_output = DiffusionAgentLoopOutput(
prompt_ids=prompt_ids,
response_diffusion_output=diffusion_output.diffusion_output,
response_logprobs=diffusion_output.log_probs,
num_turns=2,
metrics=sample_metrics,
extra_fields=diffusion_output.extra_fields,
)
results.append(await self._agent_loop_postprocess(agent_output, **row_kwargs))
# Parallelize post-processing to match the concurrency of the legacy path
postprocess_tasks = []
for diffusion_output, row_kwargs in zip(diffusion_outputs, row_kwargs_list, strict=True):
sample_metrics = dict(metrics)
if sample_metrics.get("num_preempted") is None:
sample_metrics["num_preempted"] = (
diffusion_output.num_preempted if diffusion_output.num_preempted is not None else -1
)
agent_output = DiffusionAgentLoopOutput(
prompt_ids=prompt_ids,
response_diffusion_output=diffusion_output.diffusion_output,
response_logprobs=diffusion_output.log_probs,
num_turns=2,
metrics=sample_metrics,
extra_fields=diffusion_output.extra_fields,
)
postprocess_tasks.append(self._agent_loop_postprocess(agent_output, **row_kwargs))
return await asyncio.gather(*postprocess_tasks)

prompt_ids=prompt_ids,
response_diffusion_output=diffusion_output.diffusion_output,
response_logprobs=diffusion_output.log_probs,
num_turns=2,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The num_turns field is hardcoded to 2 here. While this is correct for single-turn agent loops (system + user), it bypasses the logic inside agent_loop.run which normally determines this value. If a multi-turn agent loop is used with batched diffusion, this will produce incorrect metadata. Consider deriving this value from the agent_loop instance or documenting that batched diffusion currently only supports single-turn interactions.

Comment on lines +50 to +82
async def _server_generate_batched(
server_manager: LLMServerClient,
*,
request_id: str,
prompt_ids: list[int],
sampling_params: dict[str, Any],
num_outputs_per_prompt: int,
image_data: Optional[list[Any]] = None,
video_data: Optional[list[Any]] = None,
negative_prompt_ids: Optional[list[int]] = None,
) -> list[DiffusionOutput]:
"""Invoke ``vLLMOmniHttpServer.generate_batched`` via the LLM server load balancer.

``LLMServerClient`` in upstream ``verl`` only exposes ``generate``, so this
helper acquires a server handle via the existing load-balancer pair and
calls the new ``generate_batched`` method directly on the Ray actor. The
sticky-session behavior of ``LLMServerClient.generate`` is intentionally
bypassed here: each diffusion request is independent and benefits more
from least-loaded routing than prefix-cache stickiness.
"""
server_id, server = await server_manager._acquire_server(request_id)
try:
return await server.generate_batched.remote(
request_id=request_id,
prompt_ids=prompt_ids,
sampling_params=sampling_params,
num_outputs_per_prompt=num_outputs_per_prompt,
image_data=image_data,
video_data=video_data,
negative_prompt_ids=negative_prompt_ids,
)
finally:
server_manager._release_server(server_id)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _server_generate_batched helper relies on private methods of LLMServerClient (_acquire_server and _release_server). While necessary because the upstream verl client does not yet expose a batched API, this creates a maintenance risk if the internal implementation of LLMServerClient changes. It would be safer to eventually upstream the generate_batched capability to verl's LLMServerClient.

@SamitHuang SamitHuang force-pushed the feat/diffusion-batched-flowgrpo branch from 9983d0b to abc5a32 Compare May 14, 2026 12:22
…O agent loop

Adds end-to-end support for running a single B=N transformer forward per
prompt in the vllm-omni diffusion rollout, instead of fanning out N
concurrent B=1 requests that AsyncOmniDiffusion's
ThreadPoolExecutor(max_workers=1) then serializes.

The feature is gated by a single config flag,
``actor_rollout_ref.rollout.enable_batched_diffusion`` (default: True).
When the flag is on and ``rollout.n > 1``, the diffusion agent loop
collapses each group of ``rollout.n`` interleaved rows into one
``generate_batched`` engine request with
``num_outputs_per_prompt = rollout.n``. When off, the legacy per-sample
fan-out path is used unchanged.

Server side
- ``vLLMOmniHttpServer.generate_batched`` (new Ray-callable): submits one
  request with ``OmniDiffusionSamplingParams.num_outputs_per_prompt = N``
  and unpacks the resulting batched ``OmniRequestOutput`` (PIL image list
  + batched custom-output tensors) into ``list[DiffusionOutput]``.
- ``vLLMOmniHttpServer.generate`` is now a thin wrapper around the new
  shared ``_generate_engine_call`` helper with N=1 and returns a single
  ``DiffusionOutput`` exactly as before (fully backwards compatible).

Agent loop side
- ``DiffusionAgentLoopWorker.generate_sequences`` now checks
  ``_can_use_batched_path`` and dispatches to ``_run_agent_loops_batched``
  when the flag is on, ``rollout.n > 1``, and the batch length divides
  cleanly by N (the trainer always ``DataProto.repeat(... interleave=True)``s
  before calling us, so adjacent rows share the same prompt).
- ``_run_agent_loop_batched`` re-uses the registered agent loop for
  tokenization / vision-info extraction, issues one
  ``server.generate_batched.remote`` call, and post-processes each of the
  N returned samples with the existing ``_agent_loop_postprocess`` so
  reward computation / extra-field handling is unchanged.
- ``_server_generate_batched`` is a small helper that goes through
  ``LLMServerClient``'s ``_acquire_server`` / ``_release_server`` pair so
  load balancing still works without modifying upstream verl.

Config / docs
- ``DiffusionRolloutConfig.enable_batched_diffusion: bool = True``
  (mirrored in ``diffusion_rollout.yaml`` and regenerated in
  ``_generated_diffusion_trainer.yaml``).
- ``examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh`` carries an
  inline header explaining the flag plus an explicit
  ``enable_batched_diffusion=True`` line. The two sibling scripts
  (``..._lora_sp2.sh``, ``..._lora_async_reward.sh``) automatically pick
  up the flag from the dataclass default.

Tests
- ``tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py``
  covers:
  - One batched call returns ``num_outputs_per_prompt`` valid per-sample
    ``DiffusionOutput`` records (shapes, pixel range, log-probs).
  - ``generate_batched(num_outputs_per_prompt=1)`` produces the same
    shape as the legacy ``generate``.
  - ``num_outputs_per_prompt < 1`` raises ``ValueError`` before touching
    the engine.

Pre-commit
- ``pre-commit run --files`` passes for every touched path. The
  unrelated ``check-naming-conventions`` failure (its grep excludes
  ``venv`` but not ``.venv`` site-packages) is pre-existing and skipped
  for this commit only.

AI assistance
- This change was prepared with AI assistance; the submitter has
  reviewed every line and verified the new batched path via the
  companion benchmark in PR #2 (vllm-omni batched 0.655 img/s vs
  vllm-omni concurrent 0.403 img/s on Qwen-Image @ 512x512, n=16).

Co-authored-by: GitHub Copilot
Signed-off-by: samithuang <285365963@qq.com>
@SamitHuang SamitHuang force-pushed the feat/diffusion-batched-flowgrpo branch from abc5a32 to 3a739f7 Compare May 14, 2026 12:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant