[rollout, vllm] feat: Add BAGEL RL rollout support via vLLMOmniHttpServer#5947

Open
timzsu wants to merge 10 commits into
verl-project:mainfrom
timzsu:bagel-pipeline

@timzsu timzsu commented Apr 9, 2026

What does this PR do?

Add multi-stage BAGEL (thinker + DiT) RL rollout support to verl's vLLMOmniHttpServer, following the existing Qwen-Image pattern. This enables RL training for BAGEL image generation models through verl's diffusion rollout pipeline.

Depends on vllm-omni features:

Checklist Before Starting

Test

# BAGEL (1 GPU, tiny-random model ~376MB)
BAGEL_STAGE_CONFIG=/path/to/bagel_sharedmemory_ci.yaml \
  pytest tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py -v -s

# Qwen-Image (1 GPU, tiny-random model ~30MB)
pytest tests/workers/rollout/rollout_vllm/test_vllm_omni_generate.py -v -s

All tests pass. All tests in a module now share one engine, which makes the suites about 3.7x faster for BAGEL and 2.5x faster for Qwen-Image in local runs:

  • BAGEL (4 tests, ~92s): test_generate, test_generate_with_logprobs, test_generate_concurrent, test_generate_with_lora
  • Qwen-Image (3 tests, ~26s): test_generate, test_generate_with_logprobs, test_generate_concurrent

Both suites run in parallel on separate GPUs.

API and Usage Example

# Config (rollout_cfg.engine_kwargs)
"engine_kwargs": {
    "vllm_omni": {
        "custom_pipeline": "examples.flowgrpo_trainer.vllm_omni.pipeline_bagel.BagelPipelineWithLogProb",
        "stage_configs_path": "/path/to/bagel_sharedmemory_ci.yaml",
    }
}

# Generate with RL artifacts
output = server.generate(
    prompt_ids=token_ids,
    sampling_params={
        "num_inference_steps": 10,
        "noise_level": 0.7,
        "sde_type": "sde",
        "logprobs": True,
    },
    request_id="req_001",
)
# output.log_probs, output.extra_fields["all_latents"], output.extra_fields["all_timesteps"]

# Generate with LoRA
from vllm_omni.lora.request import LoRARequest
output = server.generate(
    prompt_ids=token_ids,
    sampling_params={"num_inference_steps": 10},
    request_id="req_002",
    lora_request=LoRARequest(lora_name="policy", lora_int_id=1, lora_path="/path/to/adapter"),
    lora_scale=1.0,
)

Design & Code Changes

New files:

  • examples/flowgrpo_trainer/vllm_omni/pipeline_bagel.py — Custom vllm-omni pipeline for BAGEL RL rollouts. Wraps FlowMatchSDEDiscreteScheduler with _BagelSchedulerAdapter that bridges BAGEL's 4-arg step(v_t, sigma, x_t, dt) to diffusers' 3-arg convention. Computes per-request shifted sigmas matching BAGEL's internal schedule. Returns all_latents, all_log_probs, all_timesteps in custom_output.

  • tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py — E2E tests through verl's rollout server, aligned with Qwen-Image test structure. Uses zhengyuansu/bagel-tiny-random (~376MB) for fast CI.
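The adapter idea described for pipeline_bagel.py can be sketched as follows. This is a minimal, dependency-free sketch and not the PR's code: `ToyEulerScheduler`, `StepAdapter`, and `shifted_sigmas` are hypothetical names, the deterministic Euler update stands in for FlowMatchSDEDiscreteScheduler (no SDE noise, no log-probabilities), and the shift formula `shift * s / (1 + (shift - 1) * s)` is the common flow-matching time shift, assumed here to correspond to BAGEL's schedule.

```python
def shifted_sigmas(num_steps, shift):
    """Linear sigmas from 1 to 0, warped by the assumed time-shift formula."""
    raw = [1.0 - i / num_steps for i in range(num_steps + 1)]
    return [shift * s / (1.0 + (shift - 1.0) * s) for s in raw]


class ToyEulerScheduler:
    """Hypothetical 3-arg scheduler: step(model_output, index, sample)."""

    def __init__(self, sigmas):
        self.sigmas = list(sigmas)

    def step(self, model_output, index, sample):
        # Deterministic Euler update: x_next = x + (sigma_next - sigma) * v
        d_sigma = self.sigmas[index + 1] - self.sigmas[index]
        return [x + d_sigma * v for x, v in zip(sample, model_output)]


class StepAdapter:
    """Exposes a 4-arg step(v_t, sigma, x_t, dt) on top of the 3-arg scheduler."""

    def __init__(self, scheduler):
        self.scheduler = scheduler

    def step(self, v_t, sigma, x_t, dt):
        # Find the schedule index nearest to the sigma the caller passed in.
        sigmas = self.scheduler.sigmas
        index = min(range(len(sigmas) - 1), key=lambda i: abs(sigmas[i] - sigma))
        # dt is implied by the schedule, so it is only sanity-checked here.
        assert abs((sigmas[index] - sigmas[index + 1]) - dt) < 1e-6
        return self.scheduler.step(v_t, index, x_t)


sigmas = shifted_sigmas(num_steps=4, shift=3.0)
adapter = StepAdapter(ToyEulerScheduler(sigmas))
x = [1.0, -2.0]
for i in range(4):
    v = list(x)  # toy velocity field v(x) = x
    x = adapter.step(v, sigmas[i], x, sigmas[i] - sigmas[i + 1])
```

The adapter keeps the caller's 4-arg signature unchanged and derives the schedule index from `sigma`, so the wrapped scheduler never needs to know about `dt`.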

Modified files:

  • verl/workers/rollout/vllm_rollout/vllm_omni_async_server.py — Multi-stage support:

    • run_server: inject stage_configs_path from engine_kwargs into OmniEngineArgs (single unified code path for both single-stage and multi-stage).
    • generate: multi-stage sampling params (defaults for non-diffusion stages, caller params for diffusion stage), prompt_token_ids + modalities for vLLM input processor, caller-supplied lora_request/lora_scale.
  • examples/flowgrpo_trainer/vllm_omni/pipeline_qwenimage.py — Squeeze batch dim in custom_output (batch=1 in rollout) so both pipelines return unbatched tensors. Server passes through without shape manipulation.
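The per-stage handling described for generate can be illustrated with a small sketch. It makes several assumptions: plain dicts stand in for the engine's sampling-params objects, the diffusion stage is taken to be the last stage, and `build_stage_params`, `build_prompt`, and the default values are hypothetical. Only the prompt keys `prompt_token_ids` and `modalities` come from the PR itself.

```python
import copy


def build_stage_params(default_params_list, caller_params):
    """Return one params dict per stage.

    Non-diffusion stages keep their defaults; only the last (diffusion)
    stage is overridden with the caller's sampling params.
    """
    stage_params = [copy.deepcopy(p) for p in default_params_list]
    stage_params[-1].update(caller_params)
    return stage_params


def build_prompt(prompt_token_ids):
    """Prompt dict with the keys the vLLM input processor is said to require."""
    return {"prompt_token_ids": list(prompt_token_ids), "modalities": ["image"]}


# Hypothetical defaults for a two-stage (thinker + DiT) pipeline.
defaults = [{"max_tokens": 1}, {"num_inference_steps": 50, "logprobs": False}]
params = build_stage_params(defaults, {"num_inference_steps": 10, "logprobs": True})
prompt = build_prompt([101, 2023, 102])
```

Copying the defaults before updating them keeps one request's overrides from leaking into the next request.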

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation.
  • Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: E2E test requires 1 GPU + tiny-random model. Test file provided at tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py, configurable via BAGEL_STAGE_CONFIG env var.
  • Once your PR is ready for CI, send a message in the ci-request channel in the verl Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
  • If your PR is related to the recipe submodule, please also update the reference to the submodule commit via git submodule update --remote or cd recipe && git pull origin main.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for multi-stage pipelines, specifically for the BAGEL model, within the vLLM-Omni rollout infrastructure. Key changes include the implementation of a custom BAGEL pipeline with an SDE scheduler for RL rollouts, updates to the asynchronous server to handle multi-stage engine configurations, and enhancements to the generation logic to support LoRA adapters and correctly unbatch outputs for different pipeline types. I have no feedback to provide as there were no review comments.

Author

timzsu commented Apr 9, 2026

@SamitHuang @princepride This PR is related to RFC vllm-project/vllm-omni#1904. Suggestions for improvement are welcome.

CLAassistant commented Apr 9, 2026

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you all sign our Contributor License Agreement before we can accept your contribution.
1 out of 2 committers have signed the CLA.

✅ timzsu
❌ princepride

# Target it specifically to avoid errors on the LLM stage.
if is_multi_stage:
    diffusion_stage_id = len(default_params_list) - 1
    results = await self.engine.collective_rpc("list_loras", stage_ids=[diffusion_stage_id])

Collaborator

This assumes that only the diffusion stage uses LoRA in a multi-stage pipeline. If a user trains BAGEL's LLM/thinker stage with LoRA, won't this hardcoded behavior fail to locate or apply the adapter?

Author

I have reverted this to a direct call to list_loras, which covers all stages.

# Single-stage pipelines (Qwen-Image) batch outputs with a leading batch
# dim that should be stripped. Multi-stage (BAGEL) returns un-batched tensors.
def _unbatch(v):
    if v is None or is_multi_stage:

Collaborator

Relying on is_multi_stage to decide whether _unbatch(v) strips the batch dimension is fragile: it breaks if a future single-stage or multi-stage model changes its output shape. Could you check the actual tensor dimensions, or use an explicit configuration flag for batching behavior, instead of tying it implicitly to the stage count?

Author

Since we currently have only two pipelines, I changed the server to require unbatched tensors. The Qwen-Image pipeline now produces the unbatched shape directly.
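One way to make the "require the unbatched shape" contract explicit is to validate the rank of each output rather than infer it from the stage count. The sketch below is illustrative only: `require_unbatched` and the example shapes are hypothetical, and it works on shape tuples (such as `tensor.shape`) so that it needs no tensor library.

```python
def require_unbatched(name, shape, expected_ndim):
    """Validate that a pipeline output has no leading batch dimension.

    `shape` is a tuple such as tensor.shape; `expected_ndim` is the rank the
    server expects for an unbatched tensor. This is an explicit contract,
    not a guess based on how many stages the pipeline has.
    """
    shape = tuple(shape)
    if len(shape) == expected_ndim:
        return shape
    if len(shape) == expected_ndim + 1 and shape[0] == 1:
        raise ValueError(
            f"{name}: got batched shape {shape}; the pipeline should "
            "squeeze the batch dim before returning"
        )
    raise ValueError(f"{name}: unexpected shape {shape}")
```

For example, `require_unbatched("all_latents", (11, 16, 64), expected_ndim=3)` returns the shape unchanged, while a leading batch dimension of 1 raises an error that points at the pipeline as the place to fix it.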

# -----------------------------------------------------------------------

async def run_server(self, args: argparse.Namespace):
    engine_args = OmniEngineArgs.from_cli_args(args)

Contributor

I think it is better to build all engine_args through OmniEngineArgs.from_cli_args.

Author

I agree. I have created a PR (vllm-project/vllm-omni#2684) that exposes the stage configs in the CLI, which simplifies the verl side.

Contributor

zhtmike commented Apr 10, 2026

@knlnguyen1802 please take a look

custom_prompt = req.prompts[0] if req.prompts else {}
if isinstance(custom_prompt, dict):
    prompt_ids = custom_prompt.get("prompt_ids", prompt_ids)
    prompt_ids = custom_prompt.get("prompt_token_ids", prompt_ids)

Contributor

@knlnguyen1802 knlnguyen1802 Apr 10, 2026

I think we do not need to rename this. Keeping the old name stays compatible with the Qwen-Image pipeline and might make debugging easier.

Author

vLLM uses prompt_token_ids as the input key, so if we keep the old name, the engine has to handle both prompt_ids and prompt_token_ids. From my perspective, that would make the logic messier. Do you think we should keep the original name?

Contributor

I prefer to keep it consistent. If you want to change it to prompt_token_ids, it is better to change every prompt_ids to prompt_token_ids.

Author

@timzsu timzsu Apr 10, 2026

I have systematically renamed every prompt_ids to prompt_token_ids in the vLLM rollout path.
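To make the trade-off in this thread concrete, here is a small sketch of reading the token ids from a prompt dict under the unified key. The function name `read_prompt_token_ids` and the `allow_legacy_key` switch are hypothetical; the legacy branch shows the kind of dual-key handling that the rename is meant to avoid.

```python
def read_prompt_token_ids(custom_prompt, allow_legacy_key=False):
    """Read token ids from a prompt dict under the unified key.

    `allow_legacy_key` is a hypothetical migration switch: when True, the
    old "prompt_ids" key is still accepted, which is exactly the extra
    branching that a full rename removes.
    """
    if "prompt_token_ids" in custom_prompt:
        return custom_prompt["prompt_token_ids"]
    if allow_legacy_key and "prompt_ids" in custom_prompt:
        return custom_prompt["prompt_ids"]
    raise KeyError("prompt_token_ids")
```

With every caller renamed, the default `allow_legacy_key=False` path is the only one needed, and a stale `prompt_ids` key fails loudly instead of being ignored.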

Author

@knlnguyen1802 Can you please have a look at the current version?

# processor which requires "prompt_token_ids" and "modalities".
# Single-stage (e.g. Qwen-Image) reads "prompt_ids" directly.
if len(default_params_list) > 1:
    custom_prompt: OmniCustomPrompt = {"prompt_token_ids": prompt_ids, "modalities": ["image"]}

Contributor

I suggest unifying prompt_token_ids and prompt_ids. If you decide to switch to prompt_token_ids, please change the Qwen-Image pipeline as well :)

Author

I have systematically renamed every prompt_ids to prompt_token_ids in the vLLM rollout path :)

@SamitHuang
Collaborator

test_generate_with_lora is a new test compared to QwenImage, how is the generation precision with lora checked? can we add the precision check in the test as well?

Author

timzsu commented Apr 13, 2026

test_generate_with_lora is a new test compared to QwenImage, how is the generation precision with lora checked? can we add the precision check in the test as well?

@SamitHuang May I ask what you mean by a precision check? The test currently checks that the LoRA introduces a noticeable perturbation without corrupting the image (the difference between images generated with and without LoRA is bounded), which I think is the best I can do with a random LoRA. A true precision check would probably need a real LoRA adapter. Any recommendations?
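The bounded-perturbation check described above could look roughly like the sketch below. It is not the test in this PR: images are represented as flat lists of floats in [0, 1] to avoid any dependency, and the function names and both thresholds are hypothetical values chosen for illustration.

```python
def mean_abs_diff(image_a, image_b):
    """Mean absolute per-pixel difference between two equally sized images,
    given as flat lists of floats in [0, 1]."""
    if len(image_a) != len(image_b):
        raise ValueError("images must have the same number of pixels")
    return sum(abs(a - b) for a, b in zip(image_a, image_b)) / len(image_a)


def lora_effect_is_bounded(base_image, lora_image, min_diff=1e-3, max_diff=0.5):
    """True when LoRA changed the output (diff > min_diff) but did not
    destroy it (diff < max_diff). Both thresholds are illustrative."""
    diff = mean_abs_diff(base_image, lora_image)
    return min_diff < diff < max_diff


# Toy 4-pixel "images": the LoRA output is shifted slightly from the base.
base = [0.10, 0.50, 0.90, 0.30]
with_lora = [0.15, 0.45, 0.85, 0.35]
bounded = lora_effect_is_bounded(base, with_lora)
```

The lower bound catches an adapter that was silently not applied, and the upper bound catches an output that was corrupted. Neither bound says anything about numerical precision against a reference, which is the gap raised in the review.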

@SamitHuang
Collaborator

BAGEL (4 tests, ~345s) takes about 5x as long as Qwen-Image, which means it may take 3m44s x 5 on the CI server (https://github.com/verl-project/verl/actions/runs/23841025230/job/69496889448). Can we reduce this test time further?

Collaborator

SamitHuang commented Apr 13, 2026

test_generate_with_lora is a new test compared to QwenImage, how is the generation precision with lora checked? can we add the precision check in the test as well?

@SamitHuang May I ask what you mean by a precision check? The test currently checks that the LoRA introduces a noticeable perturbation without corrupting the image (the difference between images generated with and without LoRA is bounded), which I think is the best I can do with a random LoRA. A true precision check would probably need a real LoRA adapter. Any recommendations?

Can you provide a deterministic check with a random checkpoint for reproducibility? The ground truth should come from code that has been verified against a real LoRA adapter in a precision check.

Author

timzsu commented Apr 13, 2026

BAGEL (4 tests, ~345s) takes about 5x as long as Qwen-Image, which means it may take 3m44s x 5 on the CI server (https://github.com/verl-project/verl/actions/runs/23841025230/job/69496889448). Can we reduce this test time further?

@SamitHuang Previously, each test started its own engine. Now all tests in a module share one engine, which makes the suites about 3.7x faster for BAGEL and 2.5x faster for Qwen-Image in my local runs.
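The engine-sharing change can be illustrated with a dependency-free sketch. The PR does this with a module-scoped pytest fixture; the sketch below swaps in the standard library's `unittest` with `setUpClass`, which gives the same "start once, reuse across tests" behavior. `FakeEngine` is a hypothetical stand-in for the expensive rollout server, and the test bodies are placeholders.

```python
import io
import unittest


class FakeEngine:
    """Hypothetical stand-in for an expensive-to-start rollout server."""

    starts = 0  # counts how many times an engine was constructed

    def __init__(self):
        FakeEngine.starts += 1
        self.running = True

    def generate(self, prompt_token_ids):
        return {"num_tokens": len(prompt_token_ids)}

    def shutdown(self):
        self.running = False


class SharedEngineTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        # Runs once for the whole class, not once per test method.
        cls.engine = FakeEngine()

    @classmethod
    def tearDownClass(cls):
        cls.engine.shutdown()

    def test_generate(self):
        self.assertEqual(self.engine.generate([1, 2, 3])["num_tokens"], 3)

    def test_generate_concurrent(self):
        outputs = [self.engine.generate([i]) for i in range(4)]
        self.assertEqual(len(outputs), 4)


def run_suite():
    """Run the tests in-process and return the unittest result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(SharedEngineTests)
    return unittest.TextTestRunner(stream=io.StringIO(), verbosity=0).run(suite)
```

The trade-off is that tests now share engine state, so a test that leaves the engine in a bad state can affect the tests that run after it.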

Author

timzsu commented Apr 13, 2026

test_generate_with_lora is a new test compared to QwenImage, how is the generation precision with lora checked? can we add the precision check in the test as well?

@SamitHuang May I ask what you mean by a precision check? The test currently checks that the LoRA introduces a noticeable perturbation without corrupting the image (the difference between images generated with and without LoRA is bounded), which I think is the best I can do with a random LoRA. A true precision check would probably need a real LoRA adapter. Any recommendations?

Can you provide a deterministic check with a random checkpoint for reproducibility? The ground truth should come from code that has been verified against a real LoRA adapter in a precision check.

I didn't find an open-sourced real LoRA adapter for BAGEL. Do you have one?

timzsu and others added 5 commits April 17, 2026 12:18
…ttpServer

Add multi-stage BAGEL (thinker + DiT) integration to verl's rollout
server, following the existing Qwen-Image pattern. Includes custom
vllm-omni pipeline with SDE scheduler for log-probability recording,
multi-stage engine initialization, per-request LoRA on the diffusion
stage, and E2E tests with synthetic LoRA adapters.

Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu>
- Unify run_server: single OmniEngineArgs.from_cli_args path, inject
  stage_configs_path and custom_pipeline_args after parsing
- LoRA: use engine.list_loras() for all stages instead of targeting
  diffusion stage specifically
- Remove _unbatch from server: squeeze batch dim in Qwen-Image pipeline
  instead, so both pipelines return unbatched tensors
- Use len(default_params_list) > 1 for multi-stage checks instead of
  separate flag
- Update BAGEL test model path

Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…lout path

Unify naming with vLLM convention across the rollout server, pipelines,
and tests. This eliminates the need for per-pipeline dict key branching.

Signed-off-by: Zhengyuan Su <su.zhengyuan@u.nus.edu>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
timzsu and others added 2 commits April 17, 2026 12:18
Change init_server fixture from function scope to module scope for both
BAGEL and QwenImage rollout tests. The vllm-omni server now starts once
per test module instead of once per test function.

Results (local, 2x RTX 6000 Ada):
- BAGEL (4 tests): 343s → 92s (3.7x faster)
- QwenImage (3 tests): 63s → 25s (2.5x faster)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
princepride and others added 3 commits April 20, 2026 09:29
- Add DiffusionModelBase implementation for Bagel (bagel.py) with custom
  model loading, scheduler, and forward pass
- Add build_module hook in DiffusionModelBase and DiffusersFSDPEngine to
  support non-standard model loading (e.g. Bagel's BagelForTraining)
- Fix scheduler device/precision mismatch in FlowMatchSDEDiscreteScheduler
  (index_for_timestep nearest-neighbor, move sigmas to sample device)
- Fix pipeline_bagel to stack trajectory tensors for proper batching
- Add system prompt and negative prompt to OCR training data
- Add error handling in reward_fn for vLLM reward model failures
- Guard prompt_embeds access for models that don't produce embeddings
- Add profiler support to diffusion trainer (config + start/stop hooks)
- Fix prompt_ids -> prompt_token_ids rename in agent loop server call
- Update run_bagel_flowgrpo.sh with tuned hyperparameters

Co-authored-by: Claude
Co-authored-by: Cursor <cursoragent@cursor.com>
Enable Bagel FlowGRPO LoRA training with vLLM-omni rollout

6 participants