[rollout, vllm_omni] feat: add batched diffusion generate API (one B=N forward instead of N B=1)#2

Open
SamitHuang wants to merge 4 commits into main from feat/vllm-omni-diffusion-batched-generate

@SamitHuang
Owner

Summary

Adds vLLMOmniHttpServer.generate_batched so a single Ray request can ask the vLLM-Omni diffusion engine to produce num_outputs_per_prompt samples in one QwenImagePipelineWithLogProb forward pass — the same way diffusers.QwenImagePipeline handles num_images_per_prompt=N.

Existing generate behavior is preserved bit-for-bit: it now delegates to a shared _generate_engine_call helper with num_outputs_per_prompt=1 and returns the first DiffusionOutput.
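
The delegation described above can be pictured with a minimal sketch. Only the method names and the `num_outputs_per_prompt` parameter come from this PR; `ToyServer`, the string placeholders standing in for `DiffusionOutput`, and the helper body are hypothetical, and the methods are synchronous purely for brevity.

```python
class ToyServer:
    """Hypothetical stand-in for vLLMOmniHttpServer; not the real class."""

    def _generate_engine_call(self, prompt: str, num_outputs_per_prompt: int) -> list[str]:
        # Assumption: the real helper submits one engine request and returns
        # one output per requested sample. Here we only build labels.
        return [f"{prompt}#{i}" for i in range(num_outputs_per_prompt)]

    def generate(self, prompt: str) -> str:
        # Legacy entry point: always B=1, returns the first (only) output.
        outputs = self._generate_engine_call(prompt, num_outputs_per_prompt=1)
        return outputs[0]

    def generate_batched(self, prompt: str, num_outputs_per_prompt: int) -> list[str]:
        # New entry point: one request, N outputs.
        return self._generate_engine_call(prompt, num_outputs_per_prompt)


server = ToyServer()
single = server.generate("a cat")
batch = server.generate_batched("a cat", 4)
```

Because both entry points share one helper, the single-sample path cannot drift from the batched one except through the value of `num_outputs_per_prompt`.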

Motivation

In the current FlowGRPO Qwen-Image rollout (e.g. examples/flowgrpo_trainer/run_qwen_image_ocr.sh with rollout.n=16):

  • The agent loop fans out N concurrent generate calls per prompt.
  • Each call submits a B=1 request to vllm-omni.
  • AsyncOmniDiffusion then serializes them via ThreadPoolExecutor(max_workers=1), so the GPU effectively runs N sequential B=1 diffusion forwards.
  • Plain diffusers runs a single B=N forward and ends up faster on the same workload despite using a slower attention backend.

generate_batched lets the rollout client issue a single request that asks the pipeline to do one B=N transformer forward, then receive N DiffusionOutput objects back.
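
The serialization effect can be reproduced with a toy model. This is a sketch under stated assumptions: `ToyEngine` and its methods are hypothetical and only imitate the single-worker `ThreadPoolExecutor` pattern named above; it records how many forwards run and at what batch size, not real timings.

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor


class ToyEngine:
    """Hypothetical engine whose forwards all run on one worker thread."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)
        self.forward_batch_sizes: list[int] = []

    def _forward(self, batch_size: int) -> list[str]:
        # Record the batch size of every forward that actually executes.
        self.forward_batch_sizes.append(batch_size)
        return [f"sample-{i}" for i in range(batch_size)]

    async def generate(self, batch_size: int = 1) -> list[str]:
        loop = asyncio.get_running_loop()
        # With max_workers=1, queued forwards execute strictly one at a time.
        return await loop.run_in_executor(self._pool, self._forward, batch_size)


async def concurrent_fanout(engine: ToyEngine, n: int) -> list[str]:
    # N concurrent B=1 requests: concurrent at the caller, serial in the pool.
    results = await asyncio.gather(*(engine.generate(1) for _ in range(n)))
    return [r[0] for r in results]


async def batched(engine: ToyEngine, n: int) -> list[str]:
    # One request that asks for all N samples in a single forward.
    return await engine.generate(n)


fanout_engine, batched_engine = ToyEngine(), ToyEngine()
fanout_out = asyncio.run(concurrent_fanout(fanout_engine, 16))
batched_out = asyncio.run(batched(batched_engine, 16))
```

Both paths return 16 samples, but the fan-out engine logs sixteen B=1 forwards while the batched engine logs a single B=16 forward.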

Changes

  • verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.py

    • Factored shared engine-call logic into private _generate_engine_call (takes num_outputs_per_prompt).
    • generate is now a thin wrapper around it with num_outputs_per_prompt=1 → fully backwards compatible.
    • New generate_batched returns list[DiffusionOutput] by setting OmniDiffusionSamplingParams.num_outputs_per_prompt=N.
    • New _unpack_batched_result slices the batched OmniRequestOutput (PIL image list + batched custom-output tensors: latents / log-probs / prompt-embeds / masks) into per-sample DiffusionOutputs, preserving the contract of the existing single-image path.
    • Ruff isort re-grouped the vllm_omni.* / vllm.* imports below verl.* (auto-applied by the project's pre-commit config).
  • scripts/bench_qwen_image_rollout.py (new)

    • Apples-to-apples T2I benchmark comparing vLLM-Omni and diffusers.QwenImagePipeline at 512×512.
    • Workload, dtype, max_sequence_length, num_inference_steps, and true_cfg_scale lifted from examples/flowgrpo_trainer/run_qwen_image_ocr.sh; resolution overridden to 512 per the original task.
    • --vllm-omni-mode {concurrent,batched} toggles between the legacy N × B=1 submission pattern and the new 1 × B=N path through generate_batched.
    • Detects and logs the active diffusion attention backend for each engine (vllm-omni: DIFFUSION_ATTENTION_BACKEND selection; diffusers: _AttentionBackendRegistry._active_backend + SDPA dispatch info).
    • Reports setup / warmup / steady-state throughput and peak GPU memory; supports --backend {vllm_omni,diffusers,both} and --output bench_results.json.
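
The per-sample slicing attributed to `_unpack_batched_result` above can be sketched as follows. This is an illustration only: `ToySample` and `unpack_batched` are hypothetical names, the real `DiffusionOutput` fields may differ, and plain lists stand in for tensors (anything indexable along its first axis would work the same way).

```python
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass
class ToySample:
    """Hypothetical per-sample record; not the real DiffusionOutput."""

    image: Any
    latents: Any
    log_probs: Any


def unpack_batched(
    images: Sequence[Any], latents: Sequence[Any], log_probs: Sequence[Any]
) -> list[ToySample]:
    n = len(images)
    # Every batched field must carry the same leading (batch) dimension.
    if len(latents) != n or len(log_probs) != n:
        raise ValueError("batched fields disagree on batch size")
    # Indexing position i along the batch axis yields sample i's slice.
    return [ToySample(images[i], latents[i], log_probs[i]) for i in range(n)]


samples = unpack_batched(
    images=["img0", "img1"],
    latents=[[0.1, 0.2], [0.3, 0.4]],
    log_probs=[[-1.0], [-2.0]],
)
```

Checking the batch dimension up front turns a silent misalignment between images and tensors into an immediate error.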

Compatibility

  • No CLI / config / public-method signature changes for existing call sites. generate still returns DiffusionOutput and still runs as B=1.
  • generate_batched is a new method; existing callers are not affected.
  • No changes to vllm_omni, verl, trainer, or worker configs.

Test plan

  • pre-commit run --files verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.py scripts/bench_qwen_image_rollout.py: all hooks pass except check-naming-conventions, which already fails on the unchanged tree (its grep excludes venv but not the .venv site-packages installed at the repo root). That hook was skipped for this commit only; the failure is unrelated to this PR.

  • Benchmark sanity run on this workstation (Qwen-Image, 512×512, num_inference_steps=50, rollout_n=16, num_prompts=2 → 32 images, bf16, FA3 for vllm-omni / TORCH_SDPA for diffusers):

    | Backend | Throughput | Note |
    |---|---|---|
    | vllm-omni --vllm-omni-mode concurrent | 0.403 img/s | 16 × B=1, matches current agent-loop path |
    | vllm-omni --vllm-omni-mode batched | 0.655 img/s | 1 × B=16 via new generate_batched |
    | diffusers | 0.528 img/s | 1 × B=16 |

    vllm-omni batched is +62% vs vllm-omni concurrent and +24% vs diffusers on this workload.

  • Wire generate_batched into the FlowGRPO agent loop / rollout client (follow-up PR; this change only exposes the API).

Duplicate-work check

Per AGENTS.md, searched open PRs on verl-project/verl-omni:

AI assistance disclosure

This change was prepared with AI assistance. The submitter has reviewed every modified line, defends the design end-to-end, and ran the benchmark above to confirm the speedup.

Signed-off-by: SamitHuang <285365963@qq.com>
…N forward instead of N B=1)

Adds ``vLLMOmniHttpServer.generate_batched`` so a single Ray request can ask
the vLLM-Omni diffusion engine to produce ``num_outputs_per_prompt`` samples
in one ``QwenImagePipelineWithLogProb`` forward pass, mirroring how
``diffusers.QwenImagePipeline`` handles ``num_images_per_prompt``.

Why this matters
- The existing ``generate`` path always submits a B=1 request. When the
  rollout client (e.g. FlowGRPO with rollout.n=16) fans out ``N`` concurrent
  ``generate`` calls per prompt, those calls are serialized by
  ``AsyncOmniDiffusion``'s ``ThreadPoolExecutor(max_workers=1)``, so the GPU
  effectively runs ``N`` sequential B=1 forwards.
- ``diffusers`` runs a single B=N forward and is therefore faster for the
  same workload despite using a slower attention backend.

What changed
- ``vllm_omni_async_server.py``:
  - Factored the shared engine-call body into ``_generate_engine_call``
    (private helper, takes ``num_outputs_per_prompt``).
  - ``generate`` now delegates to it with ``num_outputs_per_prompt=1`` and
    returns the first ``DiffusionOutput`` (fully backwards-compatible).
  - New ``generate_batched`` returns ``list[DiffusionOutput]`` by setting
    ``OmniDiffusionSamplingParams.num_outputs_per_prompt=N``.
  - ``_unpack_batched_result`` slices the batched ``OmniRequestOutput``
    (images list + batched custom-output tensors for latents / log-probs /
    prompt-embeds) into per-sample ``DiffusionOutput`` objects.

- ``scripts/bench_qwen_image_rollout.py`` (new):
  - Apples-to-apples T2I benchmark comparing vLLM-Omni and ``diffusers``
    QwenImagePipeline at 512x512, with knobs lifted from
    ``examples/flowgrpo_trainer/run_qwen_image_ocr.sh``.
  - ``--vllm-omni-mode {concurrent,batched}`` toggles between the existing
    N x B=1 path and the new 1 x B=N path.
  - Logs the attention backend in use for each engine and reports setup /
    warmup / steady-state throughput and peak GPU memory.
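
For orientation, the benchmark flags named in this PR could be declared roughly like this. It is a sketch, not the script's actual parser: the flag names and choices come from the PR description, while the defaults and help text are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Toy parser mirroring the benchmark flags")
    # Which engine(s) to benchmark. Default is an assumption.
    parser.add_argument("--backend", choices=["vllm_omni", "diffusers", "both"], default="both")
    # Submission pattern for vLLM-Omni: N x B=1 or 1 x B=N. Default is an assumption.
    parser.add_argument("--vllm-omni-mode", choices=["concurrent", "batched"], default="concurrent")
    # Optional path for a JSON results file, e.g. bench_results.json.
    parser.add_argument("--output", default=None)
    return parser


args = build_parser().parse_args(["--backend", "vllm_omni", "--vllm-omni-mode", "batched"])
```

`argparse` maps `--vllm-omni-mode` to the attribute `args.vllm_omni_mode`, so the two submission patterns can be selected with a plain string comparison.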

Measured on this workstation (Qwen-Image, 512x512, num_inference_steps=50,
rollout_n=16, num_prompts=2 -> 32 images, bf16, FA3 / TORCH_SDPA):

  - vllm-omni concurrent (16 x B=1): 0.403 img/s
  - vllm-omni batched    (1  x B=16): 0.655 img/s  (+62 percent vs concurrent)
  - diffusers (1 x B=16):             0.528 img/s
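
The relative gains follow directly from the throughput figures above. The short check below recomputes them; `percent_gain` is a helper introduced here for illustration.

```python
# Throughput in images per second, copied from the measurements above.
throughput = {"concurrent": 0.403, "batched": 0.655, "diffusers": 0.528}


def percent_gain(new: float, old: float) -> float:
    """Relative improvement of `new` over `old`, in percent."""
    return (new / old - 1.0) * 100.0


gain_vs_concurrent = percent_gain(throughput["batched"], throughput["concurrent"])
gain_vs_diffusers = percent_gain(throughput["batched"], throughput["diffusers"])
print(f"batched vs concurrent: +{gain_vs_concurrent:.1f}%")
print(f"batched vs diffusers:  +{gain_vs_diffusers:.1f}%")
```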

Pre-commit hooks
- Ran ``pre-commit run --files`` on both touched paths. All hooks pass
  except ``check-naming-conventions``, which fails on the *unchanged* tree
  because the underlying ``grep --exclude-dir=venv`` does not cover ``.venv``
  installed at the repo root; that failure is unrelated to this PR.

AI assistance
- This change was prepared with AI assistance; the submitter has reviewed
  every line and run the benchmark above end-to-end.

Co-authored-by: GitHub Copilot
Signed-off-by: samithuang <285365963@qq.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a performance benchmarking script for Qwen-Image text-to-image rollout and implements batched generation support in the vLLMOmniHttpServer to allow multiple outputs per prompt in a single forward pass. Feedback identifies a potential division-by-zero error in the benchmarking script's summarization logic when no iterations are executed.

Comment on lines +170 to +171
mean_t = sum(times) / len(times)
med_t = statistics.median(times)

medium

The _summarize function is susceptible to a ZeroDivisionError and a statistics.StatisticsError if the times list is empty. This can occur if the user sets --iters 0 via the command line. Adding a guard clause to handle empty results would make the benchmark script more robust.

Suggested change
-mean_t = sum(times) / len(times)
-med_t = statistics.median(times)
+if not times:
+    return 0.0, 0.0, 0.0, 0.0
+mean_t = sum(times) / len(times)
+med_t = statistics.median(times)
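
A self-contained version of the guarded summary might look like the sketch below. The review only shows that four values are returned; treating them as mean, median, min, and max is an assumption made here, and `summarize` is a hypothetical stand-in for the script's `_summarize`.

```python
import statistics


def summarize(times: list[float]) -> tuple[float, float, float, float]:
    """Return (mean, median, min, max); all zeros when there are no samples."""
    if not times:
        # Without this guard, len(times) == 0 raises ZeroDivisionError and
        # statistics.median([]) raises statistics.StatisticsError.
        return 0.0, 0.0, 0.0, 0.0
    mean_t = sum(times) / len(times)
    med_t = statistics.median(times)
    return mean_t, med_t, min(times), max(times)


empty_summary = summarize([])
normal_summary = summarize([1.0, 2.0, 3.0])
```

An alternative design would reject `--iters 0` at argument-parsing time, which avoids reporting a misleading all-zero row.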

SamitHuang added a commit that referenced this pull request May 14, 2026
…O agent loop

Adds end-to-end support for running a single B=N transformer forward per
prompt in the vllm-omni diffusion rollout, instead of fanning out N
concurrent B=1 requests that AsyncOmniDiffusion's
ThreadPoolExecutor(max_workers=1) then serializes.

The feature is gated by a single config flag,
``actor_rollout_ref.rollout.enable_batched_diffusion`` (default: True).
When the flag is on and ``rollout.n > 1``, the diffusion agent loop
collapses each group of ``rollout.n`` interleaved rows into one
``generate_batched`` engine request with
``num_outputs_per_prompt = rollout.n``. When off, the legacy per-sample
fan-out path is used unchanged.

Server side
- ``vLLMOmniHttpServer.generate_batched`` (new Ray-callable): submits one
  request with ``OmniDiffusionSamplingParams.num_outputs_per_prompt = N``
  and unpacks the resulting batched ``OmniRequestOutput`` (PIL image list
  + batched custom-output tensors) into ``list[DiffusionOutput]``.
- ``vLLMOmniHttpServer.generate`` is now a thin wrapper around the new
  shared ``_generate_engine_call`` helper with N=1 and returns a single
  ``DiffusionOutput`` exactly as before (fully backwards compatible).

Agent loop side
- ``DiffusionAgentLoopWorker.generate_sequences`` now checks
  ``_can_use_batched_path`` and dispatches to ``_run_agent_loops_batched``
  when the flag is on, ``rollout.n > 1``, and the batch length divides
  cleanly by N (the trainer always calls
  ``DataProto.repeat(... interleave=True)`` before invoking the worker, so
  adjacent rows share the same prompt).
- ``_run_agent_loop_batched`` re-uses the registered agent loop for
  tokenization / vision-info extraction, issues one
  ``server.generate_batched.remote`` call, and post-processes each of the
  N returned samples with the existing ``_agent_loop_postprocess`` so
  reward computation / extra-field handling is unchanged.
- ``_server_generate_batched`` is a small helper that goes through
  ``LLMServerClient``'s ``_acquire_server`` / ``_release_server`` pair so
  load balancing still works without modifying upstream verl.
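
The dispatch condition and the grouping of interleaved rows can be sketched as below. The three conditions come from the description above; the function names and the list-of-strings batch are hypothetical simplifications of the real ``DataProto`` handling.

```python
def can_use_batched_path(enabled: bool, n: int, batch_len: int) -> bool:
    # Flag on, more than one sample per prompt, and rows divide evenly by n.
    return enabled and n > 1 and batch_len % n == 0


def group_interleaved(rows: list, n: int) -> list[list]:
    # With interleaved repetition, each run of n adjacent rows shares a prompt.
    return [rows[i : i + n] for i in range(0, len(rows), n)]


rows = ["p0", "p0", "p0", "p1", "p1", "p1"]
n = 3
if can_use_batched_path(True, n, len(rows)):
    # One batched engine request per group of n rows.
    groups = group_interleaved(rows, n)
else:
    # Legacy fan-out: one request per row.
    groups = [[row] for row in rows]
```

Here six rows with ``n = 3`` collapse into two groups, so the engine would see two requests instead of six.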

Config / docs
- ``DiffusionRolloutConfig.enable_batched_diffusion: bool = True``
  (mirrored in ``diffusion_rollout.yaml`` and regenerated in
  ``_generated_diffusion_trainer.yaml``).
- ``examples/flowgrpo_trainer/run_qwen_image_ocr_lora.sh`` carries an
  inline header explaining the flag plus an explicit
  ``enable_batched_diffusion=True`` line. The two sibling scripts
  (``..._lora_sp2.sh``, ``..._lora_async_reward.sh``) automatically pick
  up the flag from the dataclass default.

Tests
- ``tests/workers/rollout/rollout_vllm/test_vllm_omni_generate_batched.py``
  covers:
  - One batched call returns ``num_outputs_per_prompt`` valid per-sample
    ``DiffusionOutput`` records (shapes, pixel range, log-probs).
  - ``generate_batched(num_outputs_per_prompt=1)`` produces the same
    shape as the legacy ``generate``.
  - ``num_outputs_per_prompt < 1`` raises ``ValueError`` before touching
    the engine.
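
The last case, rejecting an invalid count before the engine is touched, can be checked with a call-counting fake. This is a sketch of the testing idea only; ``CountingEngine`` and this free-standing ``generate_batched`` are hypothetical and do not reproduce the real test file.

```python
class CountingEngine:
    """Hypothetical fake engine that counts how often it is invoked."""

    def __init__(self) -> None:
        self.calls = 0

    def run(self, n: int) -> list[int]:
        self.calls += 1
        return list(range(n))


def generate_batched(engine: CountingEngine, num_outputs_per_prompt: int) -> list[int]:
    # Validate first so an invalid request never reaches the engine.
    if num_outputs_per_prompt < 1:
        raise ValueError("num_outputs_per_prompt must be >= 1")
    return engine.run(num_outputs_per_prompt)


engine = CountingEngine()
try:
    generate_batched(engine, 0)
    rejected = False
except ValueError:
    rejected = True
calls_after_rejection = engine.calls
```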

Pre-commit
- ``pre-commit run --files`` passes for every touched path. The
  unrelated ``check-naming-conventions`` failure (its grep excludes
  ``venv`` but not ``.venv`` site-packages) is pre-existing and skipped
  for this commit only.

AI assistance
- This change was prepared with AI assistance; the submitter has
  reviewed every line and verified the new batched path via the
  companion benchmark in PR #2 (vllm-omni batched 0.655 img/s vs
  vllm-omni concurrent 0.403 img/s on Qwen-Image @ 512x512, n=16).

Co-authored-by: GitHub Copilot
Signed-off-by: samithuang <285365963@qq.com>