[diffusion, rollout, trainer] feat: add BAGEL FlowGRPO support#66

Open
timzsu wants to merge 18 commits into
verl-project:mainfrom
timzsu:port-pr5947-bagel

Conversation

Collaborator

@timzsu timzsu commented May 10, 2026

What does this PR do?

Ports BAGEL FlowGRPO support from verl-project/verl#5947 into verl-omni and aligns the integration with the existing Qwen image FlowGRPO path.

This PR adds BAGEL-specific pipeline adapters, rollout tests, example FlowGRPO training config/scripts, and shared diffusion rollout/training plumbing for models that do not use the Qwen image prompt-embedding path.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, vllm_omni, rollout, trainer, ci, training_utils, recipe, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, diffusion, omni, tests, docker
    • If this PR involves multiple modules, separate them with , like [diffusion, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][diffusion, fsdp] feat: new rollout scheduler

Test

GPU rollout tests:

python -m pytest \
  tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py \
  tests/workers/rollout/rollout_vllm/test_vllm_omni_generate.py::test_generate

Result:

5 passed, 25 warnings in 43.40s

API and Usage Example

Design & Code Changes

  • Add verl_omni.pipelines.bagel_flow_grpo with BAGEL model loading, rollout, and diffusers training adapters.
  • Register BAGEL pipeline exports alongside the existing Qwen image FlowGRPO pipeline.
  • Add BAGEL FlowGRPO example config, reward function, training script, and local smoke script under examples/flowgrpo_trainer.
  • Extend shared diffusion model/rollout paths so BAGEL can provide model-specific inputs and outputs without breaking Qwen image behavior.
  • Update FSDP diffusers engine input preparation to support optional prompt embeddings while preserving the newer Ulysses sequence-parallel padding path from current main.
  • Add rollout coverage for BAGEL generation, scheduler behavior, LoRA generation, and the existing Qwen generate path.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

  • Read the Contribute Guide.
  • Apply pre-commit checks: pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
  • Add / Update the documentation.
  • Add unit or end-to-end test(s) to the CI workflow to cover all the code. If not feasible, explain why: ...

Training diagrams:

[Two screenshots, captured 2026-05-13 at 8:22:57 PM and 8:23:09 PM]

@timzsu timzsu requested review from SamitHuang and zhtmike as code owners May 10, 2026 07:42
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for the BAGEL model (a Mixture-of-Transformer-Experts architecture) within the FlowGRPO training pipeline. Key additions include the BagelForTraining module, corresponding training and rollout adapters, and example configurations and scripts for OCR-based reward training. The PR also integrates a global profiling system into the diffusion trainer and updates various components to handle multi-stage model configurations and renamed prompt parameters. Feedback focuses on improving the efficiency of the Bagel forward pass by avoiding per-sample loops, ensuring correct handling of multiple samples per prompt in trajectory metadata, avoiding hardcoded token IDs, and optimizing network session management in the reward function.

Comment thread verl_omni/pipelines/bagel_flow_grpo/bagel_model.py Outdated
Comment thread verl_omni/pipelines/bagel_flow_grpo/bagel_model.py
Comment thread examples/flowgrpo_trainer/reward_fn.py Outdated
Collaborator

@zhtmike zhtmike left a comment

Thank you for your PR!
Looks good, with a few comments.

By the way, can you show the reward curve in your PR description?

Comment thread examples/flowgrpo_trainer/reward_fn.py Outdated
Comment thread examples/flowgrpo_trainer/test_bagel_train.py Outdated
Comment thread examples/flowgrpo_trainer/run_bagel_flowgrpo.sh
Comment thread tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py
Comment thread verl_omni/pipelines/bagel_flow_grpo/bagel_model.py
Comment thread verl_omni/trainer/diffusion/ray_diffusion_trainer.py Outdated
Comment thread verl_omni/trainer/diffusion/ray_diffusion_trainer.py Outdated
Comment thread verl_omni/workers/engine/fsdp/diffusers_impl.py Outdated
Comment thread verl_omni/workers/rollout/vllm_rollout/utils.py Outdated
Comment thread verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.py
Collaborator

zhtmike commented May 10, 2026

@knlnguyen1802 Please take a look at the vllm-omni-related changes. Thanks!

@zhtmike zhtmike requested a review from knlnguyen1802 May 10, 2026 12:11
princepride and others added 10 commits May 11, 2026 09:44
Signed-off-by: Wang, Zhipeng | RASIA <zhipeng.wang@rakuten.com>
* fsdp/diffusers_impl: extract registry-based custom model loading into
  ``_build_module_from_registry`` helper; ``_build_module`` now simply
  delegates to it and falls back to ``AutoModel`` when no custom loader
  is registered.
* vllm_rollout/utils: drop the ``VERL_OMNI_ENABLE_WORKER_DEATH_SIGNAL``
  env gate and always call ``set_death_signal()`` (restores the original
  upstream behavior).

Signed-off-by: Wang, Zhipeng | RASIA <zhipeng.wang@rakuten.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Wang, Zhipeng | RASIA <zhipeng.wang@rakuten.com>
Signed-off-by: 汪志鹏 <wangzhipeng628@gmail.com>
Signed-off-by: princepride <wangzhipeng628@gmail.com>
- Drop examples/flowgrpo_trainer/reward_fn.py: superseded by
  verl_omni/utils/reward_score/genrm_ocr.py.
- Drop examples/flowgrpo_trainer/test_bagel_train.py: private FSDP +
  CFG smoke test, not maintained as a recipe.
- run_bagel_flowgrpo.sh: point reward_path at the new genrm_ocr.py.
- examples/flowgrpo_trainer/README.md: add a "BAGEL recipe" section
  describing prerequisites, launch command and what differs from the
  Qwen-Image recipe.
- tests/.../test_vllm_omni_bagel_generate.py: collapse the three
  non-LoRA tests (test_generate / test_generate_with_logprobs /
  test_generate_concurrent) into one concurrent SDE+logprobs test.
  The LoRA test is kept separate since it exercises a distinct
  adapter-loading code path.
- workers/engine/fsdp/diffusers_impl.py::_build_module_from_registry:
  add a docstring warning that hooks (attention processors,
  gradient-checkpointing, LoRA, dtype upcast) may be partially
  effective or silently inactive on custom-loaded modules, plus a
  TODO to migrate registered architectures into a first-class
  training engine and drop this escape hatch. Emit a runtime warning
  log line when a custom loader is taken.

Signed-off-by: princepride <wangzhipeng628@gmail.com>
Comment thread verl_omni/workers/engine/fsdp/diffusers_impl.py
Signed-off-by: princepride <wangzhipeng628@gmail.com>
@SamitHuang SamitHuang mentioned this pull request May 12, 2026
27 tasks
Signed-off-by: princepride <wangzhipeng628@gmail.com>
Conflicts:
- examples/flowgrpo_trainer/README.md: kept upstream's new Ulysses-SP and
  full-weight Qwen-Image variant blurbs together with our BAGEL recipe
  section.

Additional fix:
- verl_omni/pipelines/bagel_flow_grpo/diffusers_training_adapter.py:
  ``forward_and_sample_previous_step`` now returns the new 4-tuple
  ``(log_prob, prev_sample_mean, std_dev_t, sqrt_dt)`` to match the
  GRPO-Guard plumbing introduced upstream in verl-project#48 (BAGEL still trains
  with ``loss_mode=flow_grpo`` so ``sqrt_dt`` is unused, but the engine
  layer now unpacks 4-tuples unconditionally).

Co-authored-by: GitHub Copilot
Signed-off-by: princepride <wangzhipeng628@gmail.com>
- examples/flowgrpo_trainer/run_bagel_flowgrpo_local.sh: removed.
  Personal workstation launcher (hard-coded /proj-tango-pvc paths,
  debug-only env vars). Kept locally via .git/info/exclude, the same
  way test_bagel_train.py is handled per reviewer feedback.
- .pre-commit-config.yaml: reverted; adding ``.venv`` to the
  check-naming-conventions grep is unrelated to BAGEL FlowGRPO.
- tests/workers/rollout/rollout_vllm/test_vllm_omni_generate.py:
  reverted; switching the Qwen-Image fixture to ``scope='module'`` is
  an unrelated test optimization.

Per AGENTS.md "No low-value busywork PRs" — these mechanical/personal
changes should land in their own PR if needed.

Co-authored-by: GitHub Copilot
Signed-off-by: princepride <wangzhipeng628@gmail.com>
Collaborator

Two coupled changes, both required to keep Qwen-Image working alongside BAGEL:

  1. Rename prompt_ids → prompt_token_ids. vllm-omni 0.20+'s OmniCustomPrompt standardizes on prompt_token_ids (matching vLLM's TokensPrompt). The server-side patch writes that key now; without the matching rename here, custom_prompt.get("prompt_ids", ...) is None and forward() silently falls into the warmup/dummy branch, returning an empty DiffusionOutput, so Qwen-Image rollout would degrade to empty batches without raising.

  2. Move the [0] batch-dim squeeze from server to pipeline. BAGEL is multi-stage and its custom_output isn't shaped [1, ...], so the server can't blindly index [0]. New contract: each pipeline returns per-sample tensors; the server passes them through. Net shape for Qwen-Image consumers is unchanged.
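
A small sketch of the two points above, under stated assumptions: plain dicts and nested lists stand in for the real prompt object and tensors, and the helper names (`read_tokens_old`, `read_tokens_new`, `pipeline_per_sample`, `server_passthrough`) are hypothetical and not part of vllm-omni or this PR.

```python
from typing import Optional


def read_tokens_old(custom_prompt: dict) -> Optional[list]:
    """Pre-rename lookup: returns None once the server writes the new key."""
    return custom_prompt.get("prompt_ids")


def read_tokens_new(custom_prompt: dict) -> Optional[list]:
    """Post-rename lookup, matching the key the server now writes."""
    return custom_prompt.get("prompt_token_ids")


# Payload as the patched server would write it (illustrative values).
payload = {"prompt_token_ids": [11, 22, 33]}
old_result = read_tokens_old(payload)  # None, so a caller would take the dummy branch
new_result = read_tokens_new(payload)


def pipeline_per_sample(batched: list) -> list:
    """Pipeline side: drop a leading batch dimension of size 1 (like tensor[0])."""
    if len(batched) != 1:
        raise ValueError("expected a leading batch dimension of size 1")
    return batched[0]


def server_passthrough(per_sample: list) -> list:
    """Server side under the new contract: no indexing, just pass through."""
    return per_sample


qwen_like = [[0.1, 0.2, 0.3]]  # nested list standing in for a [1, 3] tensor
result = server_passthrough(pipeline_per_sample(qwen_like))
```

The failure mode in point 1 is quiet precisely because `dict.get` returns `None` rather than raising; an explicit check that raises on a missing key would surface a mismatch earlier.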

@princepride
Collaborator

@zhtmike @SamitHuang PTAL

Collaborator

@zhtmike zhtmike left a comment

Looks good, with a few small suggestions,
and one question about the modification to verl_omni/trainer/diffusion/ray_diffusion_trainer.py.


extra_fields["raw_prompt"] = kwargs["raw_prompt"]

# ``return_attention_mask=True`` is required by token-aware adapters (e.g. BAGEL).
Collaborator

Suggested change
# ``return_attention_mask=True`` is required by token-aware adapters (e.g. BAGEL).

Comment on lines -382 to +422
-    def _compute_reward_colocate(self, batch: DataProto) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor:
-        """
-        compute reward use colocate reward model
+    def _compute_reward_colocate(self, batch: DataProto) -> DataProto:
+        """Compute per-sample diffusion reward via the colocated reward loop.
+
+        Bypasses ``RewardLoopManager.compute_rm_score`` (LLM-only: assumes
+        ``responses`` has a token axis and reads ``attention_mask``) and
+        assembles a ``[B, 1]`` ``rm_scores`` tensor directly.
         """
         assert self.reward_loop_manager is not None, "RewardLoopManager is None"
-        batch_reward = self.reward_loop_manager.compute_rm_score(batch)
-        return batch_reward
+        manager = self.reward_loop_manager
+
+        if manager.reward_model_manager is not None:
+            manager.reward_model_manager.wake_up()
+
+        chunks = batch.chunk(len(manager.reward_loop_workers))
+        outputs = ray.get(
+            [
+                worker.compute_score_batch.remote(chunk)
+                for worker, chunk in zip(manager.reward_loop_workers, chunks, strict=True)
+            ]
+        )
+        outputs_flat = [item for sublist in outputs for item in sublist]
+
+        scores = [item["reward_score"] for item in outputs_flat]
+        rm_scores = torch.tensor(scores, dtype=torch.float32).unsqueeze(-1)
+        reward_batch = TensorDict({"rm_scores": rm_scores}, batch_size=len(batch))
+
+        reward_extra_infos = [output.get("reward_extra_info", {}) for output in outputs_flat]
+        reward_extra_keys = list(reward_extra_infos[0].keys()) if reward_extra_infos else []
+        non_tensor_batch = {
+            key: np.array([info[key] for info in reward_extra_infos]) for key in reward_extra_keys
+        }
+
+        if manager.reward_model_manager is not None:
+            manager.reward_model_manager.sleep()
+
+        return DataProto(
+            batch=reward_batch,
+            non_tensor_batch=non_tensor_batch,
+            meta_info={"reward_extra_keys": reward_extra_keys},
+        )
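
As a rough illustration of the assembly step in the new method above, the following pure-Python sketch flattens per-worker outputs and builds a `[B, 1]`-shaped score list plus per-key extra info. It assumes each worker returns a list of dicts with a `reward_score` entry and an optional `reward_extra_info` dict; nested lists stand in for the tensor and NumPy arrays, and `assemble_rewards` and the `ocr_match` key are hypothetical names used only here.

```python
from typing import Any


def assemble_rewards(worker_outputs: list[list[dict[str, Any]]]):
    """Flatten per-worker results into [B, 1] scores and per-key extra info."""
    # Concatenate worker results in order, mirroring ``outputs_flat``.
    flat = [item for sublist in worker_outputs for item in sublist]

    # One score per sample, each wrapped in a list to mimic ``unsqueeze(-1)``.
    rm_scores = [[float(item["reward_score"])] for item in flat]

    # Extra-info keys are taken from the first sample only.
    extra_infos = [item.get("reward_extra_info", {}) for item in flat]
    extra_keys = list(extra_infos[0].keys()) if extra_infos else []
    extras = {key: [info[key] for info in extra_infos] for key in extra_keys}
    return rm_scores, extras, extra_keys


# Two workers, one sample each (illustrative values).
outputs = [
    [{"reward_score": 0.5, "reward_extra_info": {"ocr_match": True}}],
    [{"reward_score": 1.0, "reward_extra_info": {"ocr_match": False}}],
]
rm_scores, extras, extra_keys = assemble_rewards(outputs)
```

Because the keys come from the first sample, a later sample that lacks one of them would raise a `KeyError` in this sketch; whether that can occur in practice depends on the reward function in use.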
Collaborator

what happens here?

(attention processors, gradient checkpointing, LoRA, dtype upcast)
may be silently inactive on the returned module.

TODO: drop this function once the model is integrated into a
Collaborator

Suggested change
TODO: drop this function once the model is integrated into a
# TODO (princepride): drop this function once the model is integrated into a

    def __new__(cls, **kwargs):
        set_death_signal()

        # Do NOT call verl's ``set_death_signal``: ``PR_SET_PDEATHSIG`` is
Collaborator

@knlnguyen1802 please take a look

Collaborator

This is fixed on the vllm-omni main branch.

Comment on lines +83 to +85
    def _preprocess_engine_kwargs(self, engine_kwargs: dict) -> None:
        # No-op: ``deploy_config`` is a vllm-omni CLI flag and must reach the parser.
        return
Collaborator

Suggested change
    def _preprocess_engine_kwargs(self, engine_kwargs: dict) -> None:
        # No-op: ``deploy_config`` is a vllm-omni CLI flag and must reach the parser.
        return

Collaborator

It would be better to report the reference performance in the Performance Reference doc.

- Passes the deploy-config YAML to vllm-omni via
`+actor_rollout_ref.rollout.engine_kwargs.vllm_omni.deploy_config`. The
legacy `stage_configs_path` entrypoint is **not** supported: it routes
through vllm-omni 0.20's deprecated stage-args loader, which silently
Collaborator

Should we update vllm-omni version pin for 0.20 in the installation doc?

Collaborator

Should we update vllm-omni version pin for 0.20 in the installation doc?

Let us do it in a separate PR.

Collaborator

#78

Comment on lines +50 to +84
def _to_token_list(token_ids: Any) -> list[int] | None:
    if token_ids is None:
        return None
    if isinstance(token_ids, torch.Tensor):
        token_ids = token_ids.detach().cpu().tolist()
    if token_ids and isinstance(token_ids[0], list):
        token_ids = token_ids[0]
    return [int(token_id) for token_id in token_ids]


def _extract_prompt_text(decoded: str) -> str:
    if "<|im_start|>" in decoded:
        user_chunks = []
        for segment in decoded.split("<|im_start|>"):
            if not segment.startswith("user"):
                continue
            content = segment[len("user") :].lstrip("\n")
            content = content.split("<|im_end|>", 1)[0]
            user_chunks.append(content)
        if user_chunks:
            decoded = user_chunks[-1]

    for marker in _CHAT_MARKERS:
        decoded = decoded.replace(marker, "")
    return decoded.replace("<|im_start|>", "").replace("<|im_end|>", "").strip()


def _to_cpu_tensor(v):
    """Convert to a single CPU tensor, stacking a list of tensors if needed."""
    if isinstance(v, torch.Tensor):
        return v.detach().cpu()
    if isinstance(v, list):
        tensors = [x.detach().cpu() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in v]
        return torch.stack(tensors) if tensors else None
    return v
Collaborator

Please move this into a utils.py file

"""

    def __new__(cls, **kwargs):
        set_death_signal()
Collaborator

It is no longer necessary to remove this, since the bug is fixed in vllm-omni. If the removal is needed for stable runs with vllm-omni 0.20.0, please leave a TODO to add it back later.

5 participants