diff --git a/examples/flowgrpo_trainer/README.md b/examples/flowgrpo_trainer/README.md index 721608ea..a8a07829 100644 --- a/examples/flowgrpo_trainer/README.md +++ b/examples/flowgrpo_trainer/README.md @@ -1,6 +1,9 @@ # FlowGRPO Trainer -This example shows how to post-train `Qwen-Image` with FlowGRPO on an OCR-style image generation task using `vllm-omni` rollout and a visual generative reward model (`Qwen3-VL-8B-Instruct` in this example). +This example shows how to post-train `Qwen-Image` (and, in a separate +recipe, `BAGEL-7B-MoT`) with FlowGRPO on an OCR-style image generation +task using `vllm-omni` rollout and a visual generative reward model +(`Qwen3-VL-8B-Instruct` in this example). For the full installation and quickstart guide, see `docs/start/flowgrpo_quickstart.md`. For algorithm details and rule-based reward training (e.g. JPEG incompressibility), see `docs/algo/flowgrpo.md`. @@ -104,6 +107,61 @@ We have provided a script to enable non-cfg full-weight Qwen-Image OCR training. bash examples/flowgrpo_trainer/run_qwen_image_ocr.sh ``` +## BAGEL recipe + +`run_bagel_flowgrpo.sh` post-trains `BAGEL-7B-MoT` (Mixture-of-Transformers) +with the same OCR reward. BAGEL is registered through the +`verl_omni.pipelines.bagel_flow_grpo` adapter pair as the architecture +`OmniBagelForConditionalGeneration`, and the rollout uses a +single-stage vllm-omni pipeline whose schema is described in +[`bagel_deploy_config.yaml`](bagel_deploy_config.yaml). + +Prerequisites in addition to the Qwen-Image recipe: + +- A local copy of `BAGEL-7B-MoT` (HF repo `ByteDance-Seed/BAGEL-7B-MoT`). +- The same `Qwen3-VL-8B-Instruct` reward model and OCR parquet files + produced above. 
+ +Launch: + +```bash +export BAGEL_MODEL_PATH=/path/to/BAGEL-7B-MoT +export REWARD_MODEL_PATH=/path/to/Qwen3-VL-8B-Instruct +export OCR_TRAIN_PATH=$WORKSPACE/data/ocr/train.parquet +export OCR_TEST_PATH=$WORKSPACE/data/ocr/test.parquet + +bash examples/flowgrpo_trainer/run_bagel_flowgrpo.sh +``` + +Notable differences from the Qwen-Image recipe: + +- Uses `+actor_rollout_ref.model.architecture=OmniBagelForConditionalGeneration` + to bypass the `model_index.json` lookup (BAGEL ships as a single + custom checkpoint, not a `diffusers` pipeline). +- LoRA `target_modules` are the BAGEL MoT generation attention and MLP projections + (`q_proj_moe_gen`, `k_proj_moe_gen`, `v_proj_moe_gen`, `o_proj_moe_gen`, + `mlp_moe_gen.gate_proj`, `mlp_moe_gen.up_proj`, `mlp_moe_gen.down_proj`). +- Passes the deploy-config YAML to vllm-omni via + `+actor_rollout_ref.rollout.engine_kwargs.vllm_omni.deploy_config`. The + legacy `stage_configs_path` entrypoint is **not** supported: it routes + through vllm-omni 0.20's deprecated stage-args loader, which silently + kills the BAGEL `DiffusionWorker` subprocess after warmup. Always use + the `deploy_config` schema documented at + [`bagel_deploy_config.yaml`](bagel_deploy_config.yaml). +- Defaults to `trainer.n_gpus_per_node=4` with + `actor_rollout_ref.rollout.tensor_model_parallel_size=1` (4 TP=1 + rollout replicas), matching the Qwen-Image recipe. Be aware of a + TOCTOU race in vllm-omni's per-process `MASTER_PORT` picker + (`OmniDiffusionConfig.__post_init__` → `settle_port` in + [`vllm_omni/diffusion/data.py`](https://github.com/vllm-project/vllm-omni/blob/main/vllm_omni/diffusion/data.py)): + every concurrent `vLLMOmniHttpServer` Ray actor independently calls + `is_port_available(p)` and may pick the same port before any of them + actually `bind`s. Birthday-paradox collision probability is roughly 4% + at 4 actors and 18% at 8 in the default 100-port window, and is + amplified further when retries land inside the prior run's TIME_WAIT + window (≈60s). 
If a launch dies during `init_distributed_environment` + with `EADDRINUSE` on a port in 30005–30105, wait ~60s and re-launch. + ## Performance diff --git a/examples/flowgrpo_trainer/bagel_deploy_config.yaml b/examples/flowgrpo_trainer/bagel_deploy_config.yaml new file mode 100644 index 00000000..ed557cc0 --- /dev/null +++ b/examples/flowgrpo_trainer/bagel_deploy_config.yaml @@ -0,0 +1,25 @@ +# Single-stage BAGEL deploy config for FlowGRPO training with colocated workers. +# +# Uses vllm-omni 0.20+'s ``--deploy-config`` schema (``pipeline`` topology +# marker + flat ``stages`` list). The legacy ``--stage-configs-path`` schema +# (``stage_args`` + ``runtime`` block) silently kills the BAGEL +# ``DiffusionWorker`` after warmup on vllm-omni 0.20, so we don't use it. +# +# Mirrors vllm-omni's reference single-stage BAGEL config at +# ``vllm_omni/deploy/bagel_single_stage.yaml``: the DiT stage owns the full +# LLM (Qwen2-MoT), ViT, VAE, and tokenizer, so a single stage covers all +# four modalities (text2img, img2img, img2text, text2text) plus think mode. + +pipeline: bagel_single_stage +async_chunk: false + +stages: + - stage_id: 0 + max_num_batched_tokens: 32768 + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + enable_prefix_caching: false + devices: "0" + default_sampling_params: + seed: 52 diff --git a/examples/flowgrpo_trainer/run_bagel_flowgrpo.sh b/examples/flowgrpo_trainer/run_bagel_flowgrpo.sh new file mode 100644 index 00000000..89e369a1 --- /dev/null +++ b/examples/flowgrpo_trainer/run_bagel_flowgrpo.sh @@ -0,0 +1,117 @@ +# Bagel LoRA RL, vllm_omni rollout (FlowGRPO) +# +# Prerequisites: +# 1. A Bagel model (e.g. BAGEL-7B-MoT) at $BAGEL_MODEL_PATH +# 2. A vllm-omni deploy-config YAML at $BAGEL_DEPLOY_CONFIG (we ship one +# next to this script at ``bagel_deploy_config.yaml``) +# 3. ``BagelDiffusion`` registered as ``OmniBagelForConditionalGeneration`` +# via ``verl_omni.pipelines.bagel_flow_grpo`` (auto-imported) +# 4. 
A reward VLM model at $REWARD_MODEL_PATH +# 5. OCR training data at $OCR_TRAIN_PATH / $OCR_TEST_PATH +# (generate via: ``python examples/flowgrpo_trainer/data_process/qwenimage_ocr.py``) +# +# Usage: +# export BAGEL_MODEL_PATH=/path/to/BAGEL-7B-MoT +# export REWARD_MODEL_PATH=/path/to/Qwen3-VL-8B-Instruct +# bash examples/flowgrpo_trainer/run_bagel_flowgrpo.sh +# +# # Override any param via CLI: +# bash examples/flowgrpo_trainer/run_bagel_flowgrpo.sh trainer.n_gpus_per_node=8 +# +# Default uses 4 GPUs with ``tensor_model_parallel_size=1`` (4 single-GPU +# rollout replicas) to mirror the Qwen-Image recipe. Be aware of a TOCTOU +# race in vllm-omni's per-actor ``MASTER_PORT`` picker (``settle_port`` in +# ``vllm_omni/diffusion/data.py``): every concurrent ``vLLMOmniHttpServer`` +# Ray actor independently calls ``is_port_available(p)`` and may pick the +# same port before any of them ``bind()``s, with collision probability +# scaling by the number of concurrent actors (~4% at 4, ~18% at 8 in the +# default 100-port window) and amplified further when retries land inside +# the prior run's TIME_WAIT window (≈60s). If a launch dies during +# ``init_distributed_environment`` with ``EADDRINUSE`` on a port in +# 30005-30105, wait ~60s and re-launch; the upstream bug is tracked at +# vllm-project/vllm-omni#TBD. 
+ +set -x + +# --------------- Paths (override via environment) --------------- +BAGEL_MODEL_PATH=${BAGEL_MODEL_PATH:-$HOME/models/BAGEL-7B-MoT} +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BAGEL_DEPLOY_CONFIG=${BAGEL_DEPLOY_CONFIG:-$SCRIPT_DIR/bagel_deploy_config.yaml} + +REWARD_MODEL_PATH=${REWARD_MODEL_PATH:-$HOME/models/Qwen3-VL-8B-Instruct} + +ocr_train_path=${OCR_TRAIN_PATH:-$HOME/data/ocr/train.parquet} +ocr_test_path=${OCR_TEST_PATH:-$HOME/data/ocr/test.parquet} + +ENGINE=vllm_omni +REWARD_ENGINE=vllm + +reward_path=verl_omni/utils/reward_score/genrm_ocr.py + +python3 -m verl_omni.trainer.diffusion.main_flowgrpo \ + algorithm.adv_estimator=flow_grpo \ + data.train_files=$ocr_train_path \ + data.val_files=$ocr_test_path \ + data.train_batch_size=16 \ + data.max_prompt_length=256 \ + data.trust_remote_code=True \ + actor_rollout_ref.model.path=$BAGEL_MODEL_PATH \ + actor_rollout_ref.model.tokenizer_path=$BAGEL_MODEL_PATH \ + +actor_rollout_ref.model.architecture=OmniBagelForConditionalGeneration \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.model.pipeline.height=512 \ + actor_rollout_ref.model.pipeline.width=512 \ + actor_rollout_ref.model.pipeline.num_inference_steps=15 \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=128 \ + actor_rollout_ref.model.target_modules="['q_proj_moe_gen','k_proj_moe_gen','v_proj_moe_gen','o_proj_moe_gen','mlp_moe_gen.gate_proj','mlp_moe_gen.up_proj','mlp_moe_gen.down_proj']" \ + actor_rollout_ref.actor.optim.lr=1e-4 \ + actor_rollout_ref.actor.optim.weight_decay=0.0001 \ + actor_rollout_ref.actor.ppo_mini_batch_size=8 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.ppo_epochs=1 \ + actor_rollout_ref.actor.shuffle=False \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + 
actor_rollout_ref.actor.diffusion_loss.loss_mode=flow_grpo \ + actor_rollout_ref.actor.diffusion_loss.clip_ratio=1e-5 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.rollout.agent.num_workers=2 \ + actor_rollout_ref.rollout.load_format=auto \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.rollout.pipeline.num_inference_steps=15 \ + actor_rollout_ref.rollout.pipeline.max_sequence_length=256 \ + actor_rollout_ref.rollout.algo.noise_level=1.3 \ + actor_rollout_ref.rollout.algo.sde_type="sde" \ + actor_rollout_ref.rollout.algo.sde_window_size=2 \ + actor_rollout_ref.rollout.algo.sde_window_range="[0,7]" \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.rollout.val_kwargs.pipeline.num_inference_steps=15 \ + actor_rollout_ref.rollout.val_kwargs.algo.noise_level=0.0 \ + +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.deploy_config=$BAGEL_DEPLOY_CONFIG \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + reward.num_workers=1 \ + reward.reward_model.enable=True \ + reward.reward_model.model_path=$REWARD_MODEL_PATH \ + reward.reward_model.rollout.name=$REWARD_ENGINE \ + reward.reward_model.rollout.tensor_model_parallel_size=4 \ + +reward.reward_model.rollout.engine_kwargs.vllm.mm_processor_cache_gb=0 \ + reward.custom_reward_function.path=$reward_path \ + reward.custom_reward_function.name=compute_score_ocr \ + algorithm.global_std=False \ + algorithm.bypass_mode=False \ + trainer.logger='["console", "wandb"]' \ + trainer.project_name=flow_grpo \ + trainer.experiment_name=bagel_ocr_lora_orig_replica \ + trainer.log_val_generations=4 \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=4 \ + trainer.nnodes=1 \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.total_training_steps=300 "$@" 
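The port-collision figures quoted in the header comment of `run_bagel_flowgrpo.sh` can be sanity-checked with a standard birthday-problem calculation. The sketch below is not taken from vllm-omni: `collision_probability` is a hypothetical helper, and it assumes every actor picks uniformly and independently from a 100-port window, which may not match how `settle_port` actually scans candidates. Under that simplified model the result is about 5.9% for 4 actors and about 25% for 8. Both are higher than the rough figures in the comment, so all of these numbers are best read as order-of-magnitude estimates.

```python
from math import prod


def collision_probability(n_actors: int, n_ports: int) -> float:
    """Probability that at least two actors pick the same port.

    Assumption (not confirmed by the source): each actor picks one port
    uniformly at random, independently of the others, from n_ports candidates.
    """
    if n_actors > n_ports:
        # Pigeonhole: more actors than ports guarantees a collision.
        return 1.0
    # P(all picks distinct) = (N/N) * ((N-1)/N) * ... * ((N-n+1)/N)
    p_all_distinct = prod((n_ports - i) / n_ports for i in range(n_actors))
    return 1.0 - p_all_distinct


if __name__ == "__main__":
    for n in (4, 8):
        print(f"{n} actors, 100 ports: {collision_probability(n, 100):.1%}")
```

The estimate grows roughly quadratically with the number of concurrent actors, which is why going from 4 to 8 replicas raises the risk by much more than a factor of two.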
diff --git a/tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py b/tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py new file mode 100644 index 00000000..ec0c2639 --- /dev/null +++ b/tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py @@ -0,0 +1,300 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +E2E test for BAGEL RL pipeline via vLLMOmniHttpServer. + +Uses verl's rollout server with BAGEL's single-stage deploy config +(``bagel_deploy_config.yaml`` by default) and BagelPipelineWithLogProb. 
+ +Usage: + pytest tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py -v -s +""" + +import json +import os +import tempfile +from pathlib import Path +from uuid import uuid4 + +import numpy as np +import pytest +import ray +import torch +from omegaconf import OmegaConf +from safetensors.torch import save_file +from verl.workers.rollout.replica import RolloutMode + +from verl_omni.workers.rollout.replica import DiffusionOutput +from verl_omni.workers.rollout.vllm_rollout.vllm_omni_async_server import vLLMOmniHttpServer + +MODEL_PATH = Path(os.path.expanduser("~/models/tiny-random/bagel")) +DEFAULT_DEPLOY_CONFIG = Path(__file__).resolve().parents[4] / "examples/flowgrpo_trainer/bagel_deploy_config.yaml" +DEPLOY_CONFIG = Path(os.environ.get("BAGEL_DEPLOY_CONFIG", DEFAULT_DEPLOY_CONFIG)) + +DEFAULT_PROMPT = ( + "a beautiful sunset over the ocean with vibrant orange and purple clouds reflecting on the calm water surface" +) + + +# --------------------------------------------------------------------- +# 👇 Test Helper Functions & Fixtures 👇 +# --------------------------------------------------------------------- + + +def _tokenize_prompt(text: str) -> list[int]: + """Tokenize a text prompt into token IDs for BAGEL.""" + from transformers import AutoTokenizer + from verl.utils.tokenizer import normalize_token_ids + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + token_ids = normalize_token_ids(tokenizer.encode(text)) + return token_ids + + +@pytest.fixture(scope="module") +def init_server(): + """Create and launch a vLLMOmniHttpServer Ray actor with BAGEL.""" + if not DEPLOY_CONFIG.exists(): + pytest.skip(f"BAGEL deploy config not found: {DEPLOY_CONFIG}") + + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + } + }, + ignore_reinit_error=True, + ) + + rollout_cfg = OmegaConf.create( + { + "_target_": 
"verl_omni.workers.config.diffusion.DiffusionRolloutConfig", + "name": "vllm_omni", + "mode": "async", + "tensor_model_parallel_size": 1, + "data_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "gpu_memory_utilization": 0.9, + "max_num_batched_tokens": 32768, + "max_num_seqs": 1, + "max_model_len": 32768, + "dtype": "bfloat16", + "load_format": "auto", + "enforce_eager": True, + "enable_chunked_prefill": False, + "enable_prefix_caching": False, + "enable_sleep_mode": False, + "free_cache_engine": True, + "disable_log_stats": True, + "n": 1, + "pipeline": { + "_target_": "verl_omni.workers.config.diffusion.rollout.DiffusionPipelineConfig", + "num_inference_steps": 10, + }, + "engine_kwargs": { + "vllm_omni": { + "deploy_config": str(DEPLOY_CONFIG), + } + }, + } + ) + + model_cfg = OmegaConf.create( + { + "_target_": "verl_omni.workers.config.diffusion.DiffusionModelConfig", + "path": MODEL_PATH, + "architecture": "OmniBagelForConditionalGeneration", + "trust_remote_code": True, + "load_tokenizer": False, + } + ) + + ServerCls = ray.remote(vLLMOmniHttpServer) + server = ServerCls.options( + runtime_env={ + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", + "NCCL_CUMEM_ENABLE": "0", + } + }, + max_concurrency=10, + ).remote( + config=rollout_cfg, + model_config=model_cfg, + rollout_mode=RolloutMode.STANDALONE, + workers=[], + replica_rank=0, + node_rank=0, + gpus_per_node=2, + nnodes=1, + cuda_visible_devices="0,1", + ) + + ray.get(server.launch_server.remote()) + + yield server + + ray.shutdown() + + +# --------------------------------------------------------------------- +# 👇 Tests 👇 +# --------------------------------------------------------------------- + + +def test_generate(init_server): + """Concurrent BAGEL generations with SDE log_probs return valid DiffusionOutput. 
+ + Single combined check covering: image shape & pixel range, SDE log_probs + + RL artifacts (``all_latents`` / ``all_timesteps``), and concurrent + dispatch through the rollout server. + """ + server = init_server + + prompts = [ + "a beautiful sunset over the ocean with vibrant orange and purple clouds " + "reflecting on the calm water surface near a rocky coastline", + "a fluffy orange cat sitting on a wooden windowsill looking outside at " + "a garden full of colorful flowers on a bright sunny afternoon", + "a majestic mountain landscape covered with fresh white snow under a " + "clear blue sky with pine trees in the foreground and a frozen lake", + "a futuristic city at night with neon lights glowing on tall glass " + "skyscrapers and flying vehicles soaring between the buildings", + ] + + refs = [ + server.generate.remote( + prompt_ids=_tokenize_prompt(prompt), + sampling_params={ + "num_inference_steps": 10, + "noise_level": 0.7, + "sde_type": "sde", + "logprobs": True, + }, + request_id=f"concurrent_{i}_{uuid4().hex[:8]}", + ) + for i, prompt in enumerate(prompts) + ] + + results = ray.get(refs, timeout=600) + + for i, output in enumerate(results): + assert isinstance(output, DiffusionOutput), f"req {i}: expected DiffusionOutput" + assert len(output.diffusion_output) == 3, f"req {i}: expected 3 channels (CHW)" + h, w = len(output.diffusion_output[0]), len(output.diffusion_output[0][0]) + assert h > 0 and w > 0, f"req {i}: empty image {h}x{w}" + assert 0.0 <= output.diffusion_output[0][0][0] <= 1.0, f"req {i}: pixel out of [0,1]" + assert output.stop_reason in ("completed", "aborted", None) + + assert output.log_probs is not None, f"req {i}: log_probs missing under logprobs=True" + extra = output.extra_fields + assert extra.get("all_latents") is not None, f"req {i}: all_latents missing" + assert extra.get("all_timesteps") is not None, f"req {i}: all_timesteps missing" + + print(f"All {len(results)} concurrent SDE+logprobs requests returned valid 
DiffusionOutput") + + +# --------------------------------------------------------------------- +# 👇 LoRA helpers 👇 +# --------------------------------------------------------------------- + +# Tiny BAGEL: hidden_size=64, 2 Q heads, 2 KV heads, head_dim=32 +# QKV packed dim = (2+2+2)*32 = 192 +_LORA_DIM = 64 +_LORA_QKV_DIM = 192 +_LORA_MODULE = "bagel.language_model.model.layers.0.self_attn.qkv_proj" +_LORA_RANK = 4 + + +def _make_synthetic_lora(adapter_dir: Path): + """Create a synthetic rank-4 LoRA adapter on disk.""" + adapter_dir.mkdir(parents=True, exist_ok=True) + gen = torch.Generator().manual_seed(42) + lora_a = torch.randn((_LORA_RANK, _LORA_DIM), dtype=torch.float32, generator=gen) * 0.1 + lora_b = torch.randn((_LORA_QKV_DIM, _LORA_RANK), dtype=torch.float32, generator=gen) * 0.5 + save_file( + { + f"base_model.model.{_LORA_MODULE}.lora_A.weight": lora_a, + f"base_model.model.{_LORA_MODULE}.lora_B.weight": lora_b, + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + (adapter_dir / "adapter_config.json").write_text( + json.dumps({"r": _LORA_RANK, "lora_alpha": _LORA_RANK, "target_modules": [_LORA_MODULE]}), + encoding="utf-8", + ) + return str(adapter_dir) + + +def test_generate_with_lora(init_server): + """LoRA adapter changes output and deactivation restores baseline.""" + from vllm_omni.lora.request import LoRARequest + + server = init_server + + with tempfile.TemporaryDirectory() as tmp_dir: + lora_path = _make_synthetic_lora(Path(tmp_dir) / "bagel_lora") + lora_request = LoRARequest(lora_name="test_lora", lora_int_id=42, lora_path=lora_path) + + # 1) Baseline (no LoRA) + baseline = ray.get( + server.generate.remote( + prompt_ids=_tokenize_prompt(DEFAULT_PROMPT), + sampling_params={"num_inference_steps": 10}, + request_id=f"lora_base_{uuid4().hex[:8]}", + ), + timeout=300, + ) + + # 2) With LoRA + with_lora = ray.get( + server.generate.remote( + prompt_ids=_tokenize_prompt(DEFAULT_PROMPT), + sampling_params={"num_inference_steps": 10}, + 
request_id=f"lora_on_{uuid4().hex[:8]}", + lora_request=lora_request, + lora_scale=1.0, + ), + timeout=300, + ) + + # 3) Deactivated (no LoRA again) + restored = ray.get( + server.generate.remote( + prompt_ids=_tokenize_prompt(DEFAULT_PROMPT), + sampling_params={"num_inference_steps": 10}, + request_id=f"lora_off_{uuid4().hex[:8]}", + ), + timeout=300, + ) + + assert isinstance(baseline, DiffusionOutput) + assert isinstance(with_lora, DiffusionOutput) + assert isinstance(restored, DiffusionOutput) + + base_arr = np.array(baseline.diffusion_output) + lora_arr = np.array(with_lora.diffusion_output) + + diff_lora = np.abs(base_arr - lora_arr).mean() + + print(f"LoRA diff from baseline: {diff_lora:.4f}") + + # LoRA should visibly change output + assert diff_lora > 0.001, f"LoRA had no effect: diff={diff_lora}" + # Output is not corrupted + assert diff_lora < 80, f"LoRA output looks corrupted: diff={diff_lora}" diff --git a/verl_omni/agent_loop/diffusion_agent_loop.py b/verl_omni/agent_loop/diffusion_agent_loop.py index b4538f17..28179d79 100644 --- a/verl_omni/agent_loop/diffusion_agent_loop.py +++ b/verl_omni/agent_loop/diffusion_agent_loop.py @@ -223,15 +223,17 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionA extra_fields["raw_prompt"] = kwargs["raw_prompt"] + # ``return_attention_mask=True`` is required by token-aware adapters (e.g. BAGEL). 
prompt_output = self.tokenizer.pad( {"input_ids": output.prompt_ids}, padding="max_length", max_length=self.rollout_config.prompt_length, return_tensors="pt", - return_attention_mask=False, + return_attention_mask=True, ) if prompt_output["input_ids"].dim() == 1: prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0) + prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0) response_diffusion_output = output.response_diffusion_output.unsqueeze(0) @@ -240,6 +242,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionA response_logprobs = output.response_logprobs.unsqueeze(0) prompt_ids = prompt_output["input_ids"] + extra_fields["attention_mask"] = prompt_output["attention_mask"] await self._compute_score( output, diff --git a/verl_omni/pipelines/__init__.py b/verl_omni/pipelines/__init__.py index d0755463..26857d8c 100644 --- a/verl_omni/pipelines/__init__.py +++ b/verl_omni/pipelines/__init__.py @@ -12,9 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import _patch # noqa: F401 — apply Ulysses mask fix +from . import ( + _patch, # noqa: F401 — apply Ulysses mask fix + bagel_flow_grpo, + qwen_image_flow_grpo, +) +from .bagel_flow_grpo import * # noqa: F401, F403 from .qwen_image_flow_grpo import * # noqa: F401, F403 from .qwen_image_mix_grpo import * # noqa: F401, F403 __all__ = list(qwen_image_flow_grpo.__all__) __all__ += list(qwen_image_mix_grpo.__all__) +__all__ += list(bagel_flow_grpo.__all__) \ No newline at end of file diff --git a/verl_omni/pipelines/bagel_flow_grpo/__init__.py b/verl_omni/pipelines/bagel_flow_grpo/__init__.py new file mode 100644 index 00000000..f901c94d --- /dev/null +++ b/verl_omni/pipelines/bagel_flow_grpo/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2026 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .diffusers_training_adapter import BagelDiffusion +from .vllm_omni_rollout_adapter import BagelPipelineWithLogProb + +__all__ = ["BagelDiffusion", "BagelPipelineWithLogProb"] diff --git a/verl_omni/pipelines/bagel_flow_grpo/bagel_model.py b/verl_omni/pipelines/bagel_flow_grpo/bagel_model.py new file mode 100644 index 00000000..ac4ba966 --- /dev/null +++ b/verl_omni/pipelines/bagel_flow_grpo/bagel_model.py @@ -0,0 +1,986 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BagelForTraining – FSDP-compatible BAGEL MoT module for flow-matching training. 
+ +Ported from vllm-omni/BAGEL with the following correctness-critical details: + * MoT (Mixture-of-Transformers): dual pathways for text vs generation tokens + * start_of_image / end_of_image boundary tokens are required + * All latent tokens share ONE RoPE position (spatial via 2-D sincos embed) + * QK-norm + RoPE in float32; cast to bfloat16 only for SDPA + * Attention mask: text-context is causal & cannot see image region + +Dependencies: torch, numpy, safetensors, einops, transformers (AutoTokenizer) +NO dependency on vllm or vllm-omni. +""" + +from __future__ import annotations + +import json +import math +import os +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor + +# =================================================================== +# Config +# =================================================================== + + +@dataclass +class BagelTrainingConfig: + hidden_size: int = 3584 + intermediate_size: int = 18944 + num_hidden_layers: int = 28 + num_attention_heads: int = 28 + num_key_value_heads: int = 4 + vocab_size: int = 152064 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1_000_000.0 + max_position_embeddings: int = 32768 + # Bagel-specific + latent_patch_size: int = 2 + max_latent_size: int = 32 + latent_channel: int = 16 + vae_downsample: int = 8 + start_of_image_id: int = 151652 # <|vision_start|> + end_of_image_id: int = 151653 # <|vision_end|> + + @property + def head_dim(self) -> int: + return self.hidden_size // self.num_attention_heads + + @property + def patch_latent_dim(self) -> int: + return self.latent_patch_size**2 * self.latent_channel + + def save_pretrained(self, save_directory: str): + """Save config as JSON (compatible with diffusers checkpoint manager).""" + from dataclasses import asdict + + output_path = os.path.join(save_directory, 
"config.json") + os.makedirs(save_directory, exist_ok=True) + with open(output_path, "w") as f: + json.dump(asdict(self), f, indent=4, sort_keys=True) + + @classmethod + def from_model_path(cls, model_path: str) -> BagelTrainingConfig: + cfg_path = os.path.join(model_path, "config.json") + with open(cfg_path) as f: + root_cfg = json.load(f) + llm = root_cfg.get("llm_config", {}) + vae = root_cfg.get("vae_config", {}) + return cls( + hidden_size=llm.get("hidden_size", 3584), + intermediate_size=llm.get("intermediate_size", 18944), + num_hidden_layers=llm.get("num_hidden_layers", 28), + num_attention_heads=llm.get("num_attention_heads", 28), + num_key_value_heads=llm.get("num_key_value_heads", 4), + vocab_size=llm.get("vocab_size", 152064), + rms_norm_eps=llm.get("rms_norm_eps", 1e-6), + rope_theta=llm.get("rope_theta", 1_000_000.0), + max_position_embeddings=llm.get("max_position_embeddings", 32768), + latent_patch_size=root_cfg.get("latent_patch_size", 2), + max_latent_size=root_cfg.get("max_latent_size", 32), + latent_channel=vae.get("z_channels", 16), + vae_downsample=vae.get("downsample", 8), + ) + + +# =================================================================== +# VAE AutoEncoder (from FLUX / BAGEL, Apache-2.0) +# =================================================================== + + +@dataclass +class AutoEncoderParams: + resolution: int = 256 + in_channels: int = 3 + downsample: int = 8 + ch: int = 128 + out_ch: int = 3 + ch_mult: list[int] | tuple[int, ...] 
= (1, 2, 4, 4) + num_res_blocks: int = 2 + z_channels: int = 16 + scale_factor: float = 0.3611 + shift_factor: float = 0.1159 + + +def _swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def forward(self, x: Tensor) -> Tensor: + h_ = self.norm(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = F.scaled_dot_product_attention(q, k, v) + h_ = rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + return x + self.proj_out(h_) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: Tensor) -> Tensor: + h = self.norm1(x) + h = _swish(h) + h = self.conv1(h) + h = self.norm2(h) + h = _swish(h) + h = 
self.conv2(h) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + h + + +class _Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor) -> Tensor: + x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0) + return self.conv(x) + + +class _Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + return self.conv(x) + + +class Encoder(nn.Module): + def __init__( + self, resolution: int, in_channels: int, ch: int, ch_mult: list[int], num_res_blocks: int, z_channels: int + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + block_in = ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = _Downsample(block_in) + self.down.append(down) + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, 
padding=1) + + def forward(self, x: Tensor) -> Tensor: + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + h = self.norm_out(h) + h = _swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = _Upsample(block_in) + self.up.insert(0, up) + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + h = self.conv_in(z) + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + for i_level in 
reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + h = self.norm_out(h) + h = _swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=list(params.ch_mult), + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=list(params.ch_mult), + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + return self.scale_factor * (z - self.shift_factor) + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +def load_ae(path: str) -> tuple[AutoEncoder, AutoEncoderParams]: + """Load VAE autoencoder from a safetensors checkpoint.""" + params = AutoEncoderParams() + ae = AutoEncoder(params) + if path is not None: + from safetensors.torch import load_file + + sd = load_file(path) + 
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + if missing: + print(f"VAE load: {len(missing)} missing keys") + if unexpected: + print(f"VAE load: {len(unexpected)} unexpected keys") + return ae, params + + +# =================================================================== +# Tokenizer & data utilities (replaces BAGEL/data/data_utils.py) +# =================================================================== + + +def load_tokenizer(model_path: str): + """Load tokenizer with special tokens for BAGEL using transformers.AutoTokenizer.""" + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + all_special = set() + for v in tokenizer.special_tokens_map.values(): + if isinstance(v, str): + all_special.add(v) + elif isinstance(v, list): + all_special.update(v) + + new_tokens = [] + for t in ["<|im_start|>", "<|im_end|>", "<|vision_start|>", "<|vision_end|>"]: + if t not in all_special and t not in tokenizer.get_vocab(): + new_tokens.append(t) + if new_tokens: + tokenizer.add_tokens(new_tokens) + + new_token_ids = { + "bos_token_id": tokenizer.convert_tokens_to_ids("<|im_start|>"), + "eos_token_id": tokenizer.convert_tokens_to_ids("<|im_end|>"), + "start_of_image": tokenizer.convert_tokens_to_ids("<|vision_start|>"), + "end_of_image": tokenizer.convert_tokens_to_ids("<|vision_end|>"), + } + return tokenizer, new_token_ids + + +def get_flattened_position_ids(img_h: int, img_w: int, patch_size: int, max_num_patches_per_side: int) -> torch.Tensor: + """Compute flattened 2-D position IDs for latent patches (extrapolate mode).""" + num_patches_h = img_h // patch_size + num_patches_w = img_w // patch_size + coords_h = torch.arange(0, num_patches_h) + coords_w = torch.arange(0, num_patches_w) + pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten() + return pos_ids + + +# =================================================================== +# Transformer 
building blocks +# =================================================================== + + +class RMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + input_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + return self.weight * x.to(input_dtype) + + +class BagelMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +# =================================================================== +# RoPE helpers +# =================================================================== + + +def _rotate_half(x: Tensor) -> Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_emb(q, k, cos, sin): + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin + return q_embed, k_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, head_dim: int, max_position_embeddings: int = 32768, theta: float = 1_000_000.0): + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, position_ids: Tensor): + freqs = torch.einsum("bi,j->bij", position_ids.float(), self.inv_freq.to(position_ids.device)) + emb = torch.cat([freqs, freqs], dim=-1) + return emb.cos(), emb.sin() + + +# 
=================================================================== +# MoT Attention & Layer +# =================================================================== + + +class BagelMoTAttention(nn.Module): + """MoT attention with separate standard and generation projections.""" + + def __init__(self, config: BagelTrainingConfig): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=True) + self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=True) + self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: Tensor, + cos: Tensor, + sin: Tensor, + text_mask: Tensor, + latent_mask: Tensor, + L_ctx: int = 0, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + B, L, _ = hidden_states.shape + text_idx = text_mask.nonzero(as_tuple=True) + latent_idx = latent_mask.nonzero(as_tuple=True) + + q = hidden_states.new_zeros(B, L, self.num_heads * self.head_dim) + k = hidden_states.new_zeros(B, L, self.num_kv_heads * 
self.head_dim) + v = hidden_states.new_zeros(B, L, self.num_kv_heads * self.head_dim) + + text_hs = hidden_states[text_idx] + q[text_idx] = self.q_proj(text_hs) + k[text_idx] = self.k_proj(text_hs) + v[text_idx] = self.v_proj(text_hs) + + latent_hs = hidden_states[latent_idx] + q[latent_idx] = self.q_proj_moe_gen(latent_hs) + k[latent_idx] = self.k_proj_moe_gen(latent_hs) + v[latent_idx] = self.v_proj_moe_gen(latent_hs) + + q = q.view(B, L, self.num_heads, self.head_dim) + k = k.view(B, L, self.num_kv_heads, self.head_dim) + v = v.view(B, L, self.num_kv_heads, self.head_dim) + + q = q.to(torch.float32) + k = k.to(torch.float32) + q_normed = q.new_zeros(q.shape) + k_normed = k.new_zeros(k.shape) + q_normed[text_idx] = self.q_norm(q[text_idx]) + k_normed[text_idx] = self.k_norm(k[text_idx]) + q_normed[latent_idx] = self.q_norm_moe_gen(q[latent_idx]) + k_normed[latent_idx] = self.k_norm_moe_gen(k[latent_idx]) + + cos = cos.unsqueeze(2) + sin = sin.unsqueeze(2) + q_normed, k_normed = _apply_rotary_emb(q_normed, k_normed, cos, sin) + + q_normed = q_normed.to(torch.bfloat16) + k_normed = k_normed.to(torch.bfloat16) + v = v.to(torch.bfloat16) + + if self.num_kv_heads < self.num_heads: + rep = self.num_heads // self.num_kv_heads + k_normed = k_normed.unsqueeze(3).expand(-1, -1, -1, rep, -1).reshape(B, L, self.num_heads, self.head_dim) + v = v.unsqueeze(3).expand(-1, -1, -1, rep, -1).reshape(B, L, self.num_heads, self.head_dim) + + # Split attention. The original (vllm-omni) BAGEL implementation packs + # ``[text, soi, latent, eoi]`` for each request into a single 1-D + # ``flash_attn_varlen`` segment with no padding, so all key positions + # are valid by construction. Here we run *padded* batched SDPA: when + # prompts within a micro-batch have different lengths, samples are + # right-padded with ``token_id=0`` up to ``max_text_len``. 
Image + # queries attending to those padding positions inject batch-grouping + # noise into the velocity prediction (the same sample produces + # different ``log_prob`` depending on which other prompts share its + # micro-batch), which under the FlowGRPO importance-sampling ratio + # shows up as a non-trivial ``ratio_std`` even at PPO step 1 with + # ``ppo_epochs=1``. Mask the padding key columns to recover the + # ``log_prob`` invariance that the rollout side already has. + # + # Note: text-causal block (rows ``:L_ctx``) does not need a key mask + # because real-text rows only attend to earlier real-text columns + # (padding is right-aligned). Padding *rows* still produce contaminated + # outputs, but those positions are never read out as latent velocity. + q_normed = q_normed.transpose(1, 2) # (B, H, L, D) + k_normed = k_normed.transpose(1, 2) + v = v.transpose(1, 2) + + if L_ctx > 0: + text_out = F.scaled_dot_product_attention( + q_normed[:, :, :L_ctx], + k_normed[:, :, :L_ctx], + v[:, :, :L_ctx], + is_causal=True, + ) + if key_padding_mask is not None and not key_padding_mask.all(): + # ``key_padding_mask`` is True at valid (non-padding) keys. + # SDPA wants an additive/bool mask broadcastable to + # ``(B, H, L_q, L_k)``. We use ``(B, 1, 1, L)``. 
+ img_attn_mask = key_padding_mask.view(B, 1, 1, L) + img_out = F.scaled_dot_product_attention( + q_normed[:, :, L_ctx:], + k_normed, + v, + attn_mask=img_attn_mask, + is_causal=False, + ) + else: + img_out = F.scaled_dot_product_attention( + q_normed[:, :, L_ctx:], + k_normed, + v, + is_causal=False, + ) + attn_out = torch.cat([text_out, img_out], dim=2) + else: + attn_out = F.scaled_dot_product_attention( + q_normed, + k_normed, + v, + is_causal=False, + ) + + attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, -1) + + out = hidden_states.new_zeros(B, L, self.hidden_size) + out[text_idx] = self.o_proj(attn_out[text_idx].to(self.o_proj.weight.dtype)) + out[latent_idx] = self.o_proj_moe_gen(attn_out[latent_idx].to(self.o_proj_moe_gen.weight.dtype)) + return out + + +class BagelMoTLayer(nn.Module): + def __init__(self, config: BagelTrainingConfig): + super().__init__() + self.self_attn = BagelMoTAttention(config) + self.mlp = BagelMLP(config.hidden_size, config.intermediate_size) + self.mlp_moe_gen = BagelMLP(config.hidden_size, config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: Tensor, + cos: Tensor, + sin: Tensor, + text_mask: Tensor, + latent_mask: Tensor, + L_ctx: int = 0, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + text_idx = text_mask.nonzero(as_tuple=True) + latent_idx = latent_mask.nonzero(as_tuple=True) + + normed = hidden_states.new_zeros(hidden_states.shape) + normed[text_idx] = self.input_layernorm(hidden_states[text_idx]) + normed[latent_idx] = self.input_layernorm_moe_gen(hidden_states[latent_idx]) + + attn_out = self.self_attn( + normed, + cos, + sin, + text_mask, + 
latent_mask, + L_ctx, + key_padding_mask=key_padding_mask, + ) + hidden_states = hidden_states + attn_out + + residual = hidden_states + mlp_out = hidden_states.new_zeros(hidden_states.shape) + mlp_out[text_idx] = self.mlp(self.post_attention_layernorm(hidden_states[text_idx])) + mlp_out[latent_idx] = self.mlp_moe_gen(self.post_attention_layernorm_moe_gen(hidden_states[latent_idx])) + hidden_states = residual + mlp_out + return hidden_states + + +# =================================================================== +# Position embedding helpers +# =================================================================== + + +def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray: + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + return np.concatenate([np.sin(out), np.cos(out)], axis=1) + + +def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int) -> np.ndarray: + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0).reshape(2, 1, grid_size, grid_size) + emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + return np.concatenate([emb_h, emb_w], axis=1) + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size: int, freq_dim: int = 256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(freq_dim, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.freq_dim = freq_dim + + def forward(self, t: Tensor) -> Tensor: + half = self.freq_dim // 2 + freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32, device=t.device) / half) + args = t[:, None].float() * freqs[None] + emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + emb 
= emb.to(self.mlp[0].weight.dtype) + return self.mlp(emb) + + +class PositionEmbedding(nn.Module): + def __init__(self, max_num_patch_per_side: int, hidden_size: int): + super().__init__() + pos_embed = _get_2d_sincos_pos_embed(hidden_size, max_num_patch_per_side) + self.pos_embed = nn.Parameter(torch.from_numpy(pos_embed).float(), requires_grad=False) + + def forward(self, position_ids: Tensor) -> Tensor: + return self.pos_embed[position_ids] + + +# =================================================================== +# Main module: BagelForTraining +# =================================================================== + + +class BagelForTraining(nn.Module): + """Standalone Bagel MoT module for FlowGRPO FSDP training. + + Forward signature: + hidden_states: (B, L_latent, patch_latent_dim) — noisy latent patches + timestep: (B,) — diffusion timestep scalars + text_token_ids: (B, L_text) — tokenized prompt IDs (with bos/eos) + latent_pos_ids: (B, L_latent) — 2-D position indices for latent patches + """ + + def __init__(self, config: BagelTrainingConfig): + super().__init__() + self.config = config + self.gradient_checkpointing = False + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([BagelMoTLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = RotaryEmbedding(config.head_dim, theta=config.rope_theta) + + self.time_embedder = TimestepEmbedder(config.hidden_size) + self.vae2llm = nn.Linear(config.patch_latent_dim, config.hidden_size) + self.llm2vae = nn.Linear(config.hidden_size, config.patch_latent_dim) + self.latent_pos_embed = PositionEmbedding(config.max_latent_size, config.hidden_size) + + def enable_gradient_checkpointing(self, *args, **kwargs): + self.gradient_checkpointing = True + + def disable_gradient_checkpointing(self): + 
self.gradient_checkpointing = False + + def forward( + self, + hidden_states: Tensor, + timestep: Tensor, + text_token_ids: Optional[Tensor], + latent_pos_ids: Tensor, + **kwargs, + ) -> tuple[Tensor]: + """Forward pass. + + When text_token_ids is None the sequence is [soi, latent, eoi] only + (no text context). This is used for the CFG unconditional pass. + """ + text_attention_mask = kwargs.pop("text_attention_mask", None) + if text_token_ids is not None and text_attention_mask is not None: + text_attention_mask = text_attention_mask.to(device=text_token_ids.device, dtype=torch.bool) + text_lengths = text_attention_mask.sum(dim=-1) + if text_lengths.numel() > 0: + text_length = int(text_lengths.max().item()) + if text_length > 0: + text_token_ids = text_token_ids[:, :text_length] + text_attention_mask = text_attention_mask[:, :text_length] + else: + text_token_ids = None + text_attention_mask = None + + B = hidden_states.shape[0] + L_latent = hidden_states.shape[1] + dev = hidden_states.device + + # 1. Embed text context + if text_token_ids is not None: + text_embeds = self.embed_tokens(text_token_ids) + L_ctx = text_embeds.shape[1] + else: + L_ctx = 0 + text_attention_mask = None + + # 2. SOI / EOI boundary tokens + soi_ids = torch.full((B, 1), self.config.start_of_image_id, dtype=torch.long, device=dev) + eoi_ids = torch.full((B, 1), self.config.end_of_image_id, dtype=torch.long, device=dev) + soi_emb = self.embed_tokens(soi_ids) + eoi_emb = self.embed_tokens(eoi_ids) + + # 3. Latent projection + t_emb = self.time_embedder(timestep) + pos_emb = self.latent_pos_embed(latent_pos_ids) + latent_embeds = self.vae2llm(hidden_states) + t_emb.unsqueeze(1) + pos_emb + latent_embeds = latent_embeds.to(soi_emb.dtype) + + # 4. 
Sequence: [text?, soi, latent_0..N, eoi] + L_total = L_ctx + 1 + L_latent + 1 + if L_ctx > 0: + sequence = torch.cat([text_embeds, soi_emb, latent_embeds, eoi_emb], dim=1) + else: + sequence = torch.cat([soi_emb, latent_embeds, eoi_emb], dim=1) + + # 5. MoT routing masks + # text pathway: text_ctx + soi + eoi + # gen pathway: latent tokens only + text_mask = torch.zeros(B, L_total, dtype=torch.bool, device=dev) + text_mask[:, : L_ctx + 1] = True # text + soi + text_mask[:, -1] = True # eoi + latent_mask = ~text_mask + + # 6. RoPE positions + if L_ctx > 0: + ctx_pos = torch.arange(L_ctx, device=dev) + img_pos = ctx_pos.new_full((1 + L_latent + 1,), L_ctx) + position_ids = torch.cat([ctx_pos, img_pos]).unsqueeze(0).expand(B, -1) + else: + position_ids = torch.zeros(1, L_total, dtype=torch.long, device=dev).expand(B, -1) + cos, sin = self.rotary_emb(position_ids) + + # 6b. Build per-sequence key padding mask. Only the text segment can + # contain right-padded ``token_id=0`` slots when prompts in a + # micro-batch have different lengths; ``soi``/latent/``eoi`` are always + # valid keys. We pass ``None`` (and SDPA stays on flash backend) when + # there is no padding to mask, which matches the rollout configuration. + if ( + L_ctx > 0 + and text_attention_mask is not None + and not bool(text_attention_mask.all()) + ): + key_padding_mask = text_attention_mask.new_ones(B, L_total, dtype=torch.bool) + key_padding_mask[:, :L_ctx] = text_attention_mask + else: + key_padding_mask = None + + # 7. 
Transformer layers (split attention: text causal + image full) + for layer in self.layers: + if self.gradient_checkpointing and self.training: + from torch.utils.checkpoint import checkpoint + + def custom_forward( + seq, + cos_, + sin_, + text_mask_, + latent_mask_, + key_padding_mask_, + layer=layer, + ): + return layer( + seq, + cos_, + sin_, + text_mask_, + latent_mask_, + L_ctx, + key_padding_mask=key_padding_mask_, + ) + + sequence = checkpoint( + custom_forward, + sequence, + cos, + sin, + text_mask, + latent_mask, + key_padding_mask, + use_reentrant=False, + ) + else: + sequence = layer( + sequence, + cos, + sin, + text_mask, + latent_mask, + L_ctx, + key_padding_mask=key_padding_mask, + ) + + # 8. Final norm with MoT routing + normed = sequence.new_zeros(sequence.shape) + t_idx = text_mask.nonzero(as_tuple=True) + l_idx = latent_mask.nonzero(as_tuple=True) + normed[t_idx] = self.norm(sequence[t_idx]) + normed[l_idx] = self.norm_moe_gen(sequence[l_idx]) + + # 9. Extract latent output + latent_output = normed[:, L_ctx + 1 : L_ctx + 1 + L_latent, :] + velocity = self.llm2vae(latent_output) + + return (velocity,) + + # ------------------------------------------------------------------ + # PEFT / LoRA compatibility + # ------------------------------------------------------------------ + + def add_adapter(self, adapter_config, adapter_name: str = "default"): + """Add a PEFT LoRA adapter (matches diffusers.ModelMixin API).""" + from peft import inject_adapter_in_model + + inject_adapter_in_model(adapter_config, self, adapter_name) + + def disable_adapters(self): + for module in self.modules(): + if module is self: + continue + disable_adapters = getattr(module, "disable_adapters", None) + if callable(disable_adapters): + disable_adapters() + + def enable_adapters(self): + for module in self.modules(): + if module is self: + continue + enable_adapters = getattr(module, "enable_adapters", None) + if callable(enable_adapters): + enable_adapters() + + @contextmanager 
+ def disable_adapter(self): + try: + self.disable_adapters() + yield + finally: + self.enable_adapters() + + # ------------------------------------------------------------------ + # Checkpoint loading + # ------------------------------------------------------------------ + + @classmethod + def from_pretrained(cls, model_path: str, torch_dtype=torch.bfloat16) -> BagelForTraining: + config = BagelTrainingConfig.from_model_path(model_path) + ckpt_path = os.path.join(model_path, "ema.safetensors") + from safetensors.torch import load_file + + state_dict = load_file(ckpt_path) + + if "latent_pos_embed.pos_embed" in state_dict: + actual_len = state_dict["latent_pos_embed.pos_embed"].shape[0] + grid = int(actual_len**0.5) + if grid * grid == actual_len and grid != config.max_latent_size: + config.max_latent_size = grid + + model = cls(config) + mapped = _map_checkpoint_to_training(state_dict, config) + missing, unexpected = model.load_state_dict(mapped, strict=False) + if missing: + import logging + + logging.getLogger(__name__).warning(f"Missing keys when loading BagelForTraining: {len(missing)} keys") + + model = model.to(torch_dtype) + return model + + +def _map_checkpoint_to_training(state_dict: dict[str, Tensor], config: BagelTrainingConfig) -> dict: + """Map ema.safetensors keys to BagelForTraining parameter names.""" + mapped: dict[str, Tensor] = {} + for src_key, tensor in state_dict.items(): + dst_key: str | None = None + if src_key.startswith("language_model.model."): + dst_key = src_key[len("language_model.model.") :] + elif src_key.startswith("language_model."): + continue + elif src_key.startswith(("time_embedder.", "vae2llm.", "llm2vae.", "latent_pos_embed.")): + dst_key = src_key + if dst_key is not None: + mapped[dst_key] = tensor + return mapped diff --git a/verl_omni/pipelines/bagel_flow_grpo/diffusers_training_adapter.py b/verl_omni/pipelines/bagel_flow_grpo/diffusers_training_adapter.py new file mode 100644 index 00000000..6014abe4 --- /dev/null +++ 
b/verl_omni/pipelines/bagel_flow_grpo/diffusers_training_adapter.py @@ -0,0 +1,309 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""BAGEL (MoT) training-side adapter for FlowGRPO. + +Registers as ``OmniBagelForConditionalGeneration`` so the FSDP engine +can load and train the model via the DiffusionModelBase registry. + +Key differences from standard diffusion models (e.g. Qwen-Image): + * BAGEL is a *Mixture-of-Transformers* (MoT) model that processes text token + IDs and noisy latent patches in a single forward pass (no separate + text encoder). + * ``prompt_embeds`` are not used. Instead the raw prompt token IDs + (available as ``micro_batch["prompts"]``) are passed directly to the + model as ``text_token_ids``. + * CFG must match what the rollout pipeline applied, otherwise the + importance-sampling ratio is biased. We implement BAGEL's 3-branch + CFG with global renormalization here (cfg_img branch == gen branch + for text2img, i.e. text-only 2-branch in practice), exactly mirroring + ``_combine_cfg`` in vllm-omni's ``bagel_transformer``. 
+""" + +from __future__ import annotations + +import logging +from typing import Optional + +import torch +from tensordict import TensorDict +from verl.utils.device import get_device_name + +from verl_omni.pipelines.model_base import DiffusionModelBase +from verl_omni.pipelines.schedulers import FlowMatchSDEDiscreteScheduler +from verl_omni.workers.config import DiffusionModelConfig + +from .bagel_model import BagelForTraining, get_flattened_position_ids +from .vllm_omni_rollout_adapter import BAGEL_FLOWGRPO_CFG_DEFAULTS + +logger = logging.getLogger(__name__) + +TIMESTEP_SHIFT = 3.0 # must match BagelPipeline.forward() hardcoded value + + +@DiffusionModelBase.register("OmniBagelForConditionalGeneration", algorithm="flow_grpo") +class BagelDiffusion(DiffusionModelBase): + """DiffusionModelBase wrapper for ``BagelForTraining`` (MoT).""" + + @classmethod + def build_module(cls, model_config: DiffusionModelConfig, torch_dtype: torch.dtype): + logger.info("Loading BagelForTraining from %s", model_config.local_path) + return BagelForTraining.from_pretrained(model_config.local_path, torch_dtype=torch_dtype) + + @classmethod + def build_scheduler(cls, model_config: DiffusionModelConfig): + # Build on GPU so scheduler buffers are comparable with cuda timesteps in FSDP forward. + scheduler = FlowMatchSDEDiscreteScheduler() + cls.set_timesteps(scheduler, model_config, get_device_name()) + return scheduler + + @classmethod + def set_timesteps(cls, scheduler: FlowMatchSDEDiscreteScheduler, model_config: DiffusionModelConfig, device: str): + num_inference_steps = model_config.pipeline.num_inference_steps + # Use torch.float32 on ``device`` to be bit-exact with BAGEL rollout's + # ``torch.linspace`` schedule; otherwise ``index_for_timestep`` may miss. 
+ t = torch.linspace(1, 0, num_inference_steps, dtype=torch.float32, device=device) + t_shifted = TIMESTEP_SHIFT * t / (1 + (TIMESTEP_SHIFT - 1) * t) + sigmas = t_shifted[:-1].tolist() + + scheduler.set_shift(1.0) # identity — sigmas already shifted + # Pass ``timesteps=sigmas`` to skip diffusers' default ``sigmas * 1000`` + # conversion; BAGEL rollout records raw sigma values as timesteps. + scheduler.set_timesteps(sigmas=sigmas, timesteps=sigmas, device=device) + scheduler.set_begin_index(0) + + @classmethod + def _get_latent_pos_ids(cls, model_config: DiffusionModelConfig, module, device) -> torch.Tensor: + """Compute latent position IDs from model config / image dimensions.""" + config = module.config + img_h = model_config.pipeline.height // (config.latent_patch_size * config.vae_downsample) + img_w = model_config.pipeline.width // (config.latent_patch_size * config.vae_downsample) + # Clamp to max_latent_size + img_h = min(img_h, config.max_latent_size) + img_w = min(img_w, config.max_latent_size) + latent_ds = config.latent_patch_size * config.vae_downsample + H_px = img_h * latent_ds + W_px = img_w * latent_ds + pos_ids = get_flattened_position_ids(H_px, W_px, latent_ds, config.max_latent_size) + return pos_ids.to(device) + + @classmethod + def prepare_model_inputs( + cls, + module, + model_config: DiffusionModelConfig, + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_prompt_embeds_mask: torch.Tensor, + micro_batch: TensorDict, + step: int, + ) -> tuple[dict, dict]: + B = latents.shape[0] + device = latents.device + + hidden_states = latents[:, step] + timestep = timesteps[:, step] + + # Extract text token IDs from prompt data + prompts = micro_batch["prompts"] # (B, L_prompt) padded + attention_mask = micro_batch["attention_mask"] # (B, L_prompt) + + # Build per-sample text_token_ids (remove padding) + text_token_ids_list = [] + for i in 
range(B): + mask = attention_mask[i].bool() + ids = prompts[i][mask] + text_token_ids_list.append(ids) + + # Pad to same length within batch + max_text_len = max(ids.shape[0] for ids in text_token_ids_list) + text_token_ids = torch.zeros(B, max_text_len, dtype=torch.long, device=device) + text_attention_mask = torch.zeros(B, max_text_len, dtype=torch.bool, device=device) + for i, ids in enumerate(text_token_ids_list): + text_token_ids[i, : ids.shape[0]] = ids + text_attention_mask[i, : ids.shape[0]] = True + + # Compute latent position IDs + latent_pos_ids = cls._get_latent_pos_ids(model_config, module, device) + latent_pos_ids = latent_pos_ids.unsqueeze(0).expand(B, -1) + + model_inputs = { + "hidden_states": hidden_states, + "timestep": timestep, + "text_token_ids": text_token_ids, + "text_attention_mask": text_attention_mask, + "latent_pos_ids": latent_pos_ids, + } + + # For BAGEL, unconditional pass uses text_token_ids=None + negative_model_inputs = { + "hidden_states": hidden_states, + "timestep": timestep, + "text_token_ids": None, + "latent_pos_ids": latent_pos_ids, + } + + return model_inputs, negative_model_inputs + + @staticmethod + def _get_cfg_params(model_config: DiffusionModelConfig) -> dict: + """Resolve CFG params for training, preferring values from + ``model_config.pipeline`` and falling back to flow_grpo BAGEL + defaults (same defaults the rollout adapter forces). 
+ + Override examples (Hydra/OmegaConf, both rollout *and* model side + must be set together if you change them): + +actor_rollout_ref.model.pipeline.cfg_text_scale=4.0 + +actor_rollout_ref.model.pipeline.cfg_img_scale=1.0 + +actor_rollout_ref.model.pipeline.cfg_renorm_type=global + +actor_rollout_ref.model.pipeline.cfg_renorm_min=0.0 + +actor_rollout_ref.model.pipeline.cfg_interval=[0,1.0] + """ + p = model_config.pipeline + cfg_interval = getattr(p, "cfg_interval", BAGEL_FLOWGRPO_CFG_DEFAULTS["cfg_interval"]) + if isinstance(cfg_interval, (list, tuple)) and len(cfg_interval) == 2: + interval_low, interval_high = float(cfg_interval[0]), float(cfg_interval[1]) + else: + interval_low, interval_high = 0.0, 1.0 + return { + "cfg_text_scale": float(getattr(p, "cfg_text_scale", BAGEL_FLOWGRPO_CFG_DEFAULTS["cfg_text_scale"])), + "cfg_img_scale": float(getattr(p, "cfg_img_scale", BAGEL_FLOWGRPO_CFG_DEFAULTS["cfg_img_scale"])), + "cfg_renorm_type": str(getattr(p, "cfg_renorm_type", BAGEL_FLOWGRPO_CFG_DEFAULTS["cfg_renorm_type"])), + "cfg_renorm_min": float(getattr(p, "cfg_renorm_min", BAGEL_FLOWGRPO_CFG_DEFAULTS["cfg_renorm_min"])), + "cfg_interval_low": interval_low, + "cfg_interval_high": interval_high, + } + + @staticmethod + def _combine_cfg( + v_t: torch.Tensor, + cfg_text_v_t: torch.Tensor, + cfg_img_v_t: Optional[torch.Tensor], + cfg_text_scale: float, + cfg_img_scale: float, + cfg_renorm_type: str, + cfg_renorm_min: float, + ) -> torch.Tensor: + """Byte-identical port of + ``vllm_omni.diffusion.models.bagel.bagel_transformer.BagelTransformer + ._combine_cfg`` so that training-time velocity matches what the + rollout actually used to generate the recorded trajectory. + + For text2img there is no input image, so the rollout's cfg_img + branch is fed the same conditioning as the gen branch; callers + therefore pass ``cfg_img_v_t = v_t`` (or ``None`` to skip it + entirely when ``cfg_img_scale == 1.0``). 
+ """ + if cfg_renorm_type == "text_channel": + v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) + norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) + norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True) + scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + v_t_text = v_t_text_ * scale + if cfg_img_scale > 1.0 and cfg_img_v_t is not None: + return cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t) + return v_t_text + + v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t) + if cfg_img_scale > 1.0 and cfg_img_v_t is not None: + v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t) + else: + v_t_ = v_t_text_ + + if cfg_renorm_type == "global": + # vLLM-Omni/BAGEL rollout handles one image per request, so its + # "global" renorm is global over latent tokens/channels for each + # sample. Training is batched; keep samples independent instead + # of mixing the whole micro-batch into one scalar norm. + norm_dims = tuple(range(1, v_t.ndim)) + norm_v_t = torch.linalg.vector_norm(v_t, dim=norm_dims, keepdim=True) + norm_v_t_ = torch.linalg.vector_norm(v_t_, dim=norm_dims, keepdim=True) + elif cfg_renorm_type == "channel": + norm_v_t = torch.norm(v_t, dim=-1, keepdim=True) + norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True) + else: + raise NotImplementedError(f"cfg_renorm_type={cfg_renorm_type!r} is not supported") + + scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0) + return v_t_ * scale + + @classmethod + def forward_and_sample_previous_step( + cls, + module, + scheduler: FlowMatchSDEDiscreteScheduler, + model_config: DiffusionModelConfig, + model_inputs: dict[str, torch.Tensor], + negative_model_inputs: Optional[dict[str, torch.Tensor]], + scheduler_inputs: Optional[TensorDict | dict[str, torch.Tensor]], + step: int, + ): + assert scheduler_inputs is not None + latents = scheduler_inputs["all_latents"] + timesteps = scheduler_inputs["all_timesteps"] + + # Gen branch 
(text-conditional). + noise_pred = module(**model_inputs)[0] + + # --------------------------------------------------------------- # + # Bug 2 fix: # + # Apply BAGEL CFG matching the rollout so the importance-sampling # + # ratio in compute_diffusion_loss_flow_grpo is unbiased. The # + # previous implementation used a simple 2-branch CFG with per-token # + # renormalization gated by ``true_cfg_scale > 1.0`` (which is the # + # default 1.0, so CFG was *off*). The rollout, however, always # + # ran with cfg_text_scale=4.0 + global renorm; the resulting # + # mismatch silently biased the policy gradient. # + # --------------------------------------------------------------- # + cfg = cls._get_cfg_params(model_config) + # sigma at this denoising step (same for the entire batch in BAGEL) + sigma_now = float(timesteps[0, step].item()) + in_cfg_interval = sigma_now > cfg["cfg_interval_low"] and sigma_now <= cfg["cfg_interval_high"] + apply_cfg = in_cfg_interval and cfg["cfg_text_scale"] > 1.0 + + if apply_cfg: + assert negative_model_inputs is not None, ( + "BAGEL CFG requires negative_model_inputs (text-unconditional branch)." + ) + # cfg_text branch: text_token_ids=None -> empty text context. + cfg_text_pred = module(**negative_model_inputs)[0] + # For text2img, no input image was supplied to drop, so the + # cfg_img branch is identical to the gen branch and we can + # reuse ``noise_pred`` instead of running a third forward. 
+ cfg_img_pred = noise_pred if cfg["cfg_img_scale"] > 1.0 else None + + noise_pred = cls._combine_cfg( + v_t=noise_pred, + cfg_text_v_t=cfg_text_pred, + cfg_img_v_t=cfg_img_pred, + cfg_text_scale=cfg["cfg_text_scale"], + cfg_img_scale=cfg["cfg_img_scale"], + cfg_renorm_type=cfg["cfg_renorm_type"], + cfg_renorm_min=cfg["cfg_renorm_min"], + ) + + _, log_prob, prev_sample_mean, std_dev_t, sqrt_dt = scheduler.sample_previous_step( + sample=latents[:, step].float(), + model_output=noise_pred.float(), + timestep=timesteps[:, step], + noise_level=model_config.algo.noise_level, + prev_sample=latents[:, step + 1].float(), + sde_type=model_config.algo.sde_type, + return_logprobs=True, + return_sqrt_dt=True, + ) + return log_prob, prev_sample_mean, std_dev_t, sqrt_dt diff --git a/verl_omni/pipelines/bagel_flow_grpo/vllm_omni_rollout_adapter.py b/verl_omni/pipelines/bagel_flow_grpo/vllm_omni_rollout_adapter.py new file mode 100644 index 00000000..71bdef88 --- /dev/null +++ b/verl_omni/pipelines/bagel_flow_grpo/vllm_omni_rollout_adapter.py @@ -0,0 +1,430 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Custom vllm-omni pipeline for BAGEL RL rollouts with verl-omni. + +Extends :class:`BagelPipeline` to: +* Replace the scheduler with an SDE scheduler for stochastic denoising + with log-probability recording. +* Always enable trajectory recording. +* Read SDE kwargs from ``sampling_params.extra_args``. 
+* Return RL artifacts in ``DiffusionOutput.custom_output``. +""" + +from __future__ import annotations + +import hashlib +import logging +import random +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np +import torch +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline +from vllm_omni.diffusion.request import OmniDiffusionRequest + +from verl_omni.pipelines.model_base import VllmOmniPipelineBase +from verl_omni.pipelines.schedulers import FlowMatchSDEDiscreteScheduler + +logger = logging.getLogger(__name__) + + +# CFG defaults aligned with the original flow_grpo BAGEL training script +# (flow_grpo/scripts/train_bagel.py: inference_hyper) so that rollout (the +# behavior policy) and training-time log-prob recomputation use exactly the +# same CFG. Any mismatch here biases the policy gradient because the +# trajectory was generated by a CFG-amplified velocity, while training would +# otherwise score it under a different (or non-)CFG velocity. +# +# Original flow_grpo Bagel uses: +# cfg_text_scale = config.sample.guidance_scale (= 4.0) +# cfg_img_scale = 1.0 (disables img branch) +# cfg_interval = [0, 1.0] (always-on) +# cfg_renorm_type = "global" +# cfg_renorm_min = 0.0 +# +# NB: vllm-omni's BagelPipeline defaults are different (cfg_img_scale=1.5, +# cfg_interval=(0.4, 1.0)). We deliberately override them here. 
+BAGEL_FLOWGRPO_CFG_DEFAULTS: dict[str, Any] = { + "cfg_text_scale": 4.0, + "cfg_img_scale": 1.0, + "cfg_interval": (0.0, 1.0), + "cfg_renorm_type": "global", + "cfg_renorm_min": 0.0, +} + + +_CHAT_MARKERS = ( + "<|vision_start|>", + "<|vision_end|>", + "<|image_pad|>", + "<|video_pad|>", +) + + +def _to_token_list(token_ids: Any) -> list[int] | None: + if token_ids is None: + return None + if isinstance(token_ids, torch.Tensor): + token_ids = token_ids.detach().cpu().tolist() + if token_ids and isinstance(token_ids[0], list): + token_ids = token_ids[0] + return [int(token_id) for token_id in token_ids] + + +def _extract_prompt_text(decoded: str) -> str: + if "<|im_start|>" in decoded: + user_chunks = [] + for segment in decoded.split("<|im_start|>"): + if not segment.startswith("user"): + continue + content = segment[len("user") :].lstrip("\n") + content = content.split("<|im_end|>", 1)[0] + user_chunks.append(content) + if user_chunks: + decoded = user_chunks[-1] + + for marker in _CHAT_MARKERS: + decoded = decoded.replace(marker, "") + return decoded.replace("<|im_start|>", "").replace("<|im_end|>", "").strip() + + +def _to_cpu_tensor(v): + """Convert to a single CPU tensor, stacking a list of tensors if needed.""" + if isinstance(v, torch.Tensor): + return v.detach().cpu() + if isinstance(v, list): + tensors = [x.detach().cpu() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in v] + return torch.stack(tensors) if tensors else None + return v + + +@dataclass +class _AdapterStepOutput: + """Adapter output matching what bagel_transformer.generate_image expects.""" + + prev_sample: torch.Tensor + log_prob: torch.Tensor | None + + +class _BagelSchedulerAdapter: + """Wraps the diffusers-based FlowMatchSDEDiscreteScheduler to match + BAGEL's calling convention: ``step(v_t, sigma, x_t, dt, **kwargs)``. 
+ + BAGEL's transformer calls ``scheduler.step(model_output, timesteps[i], + sample, dts[i], **scheduler_kwargs)`` with 4 positional args, while the + diffusers scheduler takes ``step(model_output, timestep, sample, **kwargs)`` + and computes dt internally. This adapter bridges the gap. + + Stateful SDE-windowing + ---------------------- + Beyond the calling-convention adapter, this class also implements the + "SDE window" behavior from flow_grpo's original BAGEL rollout + (flow_grpo/flow_grpo/bagel/modeling/bagel/bagel.py::generate_image): + noise is injected only on a contiguous window of denoising steps + ``[window_begin, window_begin + window_size)``, while steps outside the + window are run deterministically (ODE, ``noise_level=0``). Log-prob + recording is also gated to the window, otherwise outside-window steps + would produce ``-inf`` / ``NaN`` log-probs (``std_dev_t == 0``). + + Without this windowing, ``noise_level`` would be injected at every step + (e.g. 15 steps with ``noise_level=1.2``), the recorded latents would be + far off the deterministic ODE manifold, rewards would flatten, and the + PPO policy gradient would degenerate. See the bug analysis in the + chat history for the failure mode. + """ + + def __init__(self, inner: FlowMatchSDEDiscreteScheduler): + self._inner = inner + # Per-rollout state, reset via ``begin_forward``. + self._sde_window: Optional[tuple[int, int]] = None + self._base_noise_level: float = 0.0 + self._base_return_logprobs: bool = True + self._step_counter: int = 0 + + def __getattr__(self, name): + return getattr(self._inner, name) + + def begin_forward( + self, + sde_window: Optional[tuple[int, int]], + noise_level: float, + return_logprobs: bool, + ) -> None: + """Reset adapter state before each rollout ``forward`` call. + + Args: + sde_window: ``(begin, end_exclusive)`` step range where SDE + noise is injected and log-probs are recorded. ``None`` + disables windowing (legacy behavior: noise at every step). 
+ noise_level: SDE noise level to apply inside the window. + return_logprobs: whether log-probs are requested at all + (overridden to ``False`` outside the window even when + ``True`` here). + """ + self._sde_window = sde_window + self._base_noise_level = float(noise_level) + self._base_return_logprobs = bool(return_logprobs) + self._step_counter = 0 + + def step( + self, + model_output: torch.Tensor, + sigma: float | torch.Tensor, + sample: torch.Tensor, + dt: float | torch.Tensor, # noqa: ARG002 — inner derives dt from timestep schedule + **kwargs, + ) -> _AdapterStepOutput: + i = self._step_counter + if self._sde_window is not None: + begin, end = self._sde_window + in_window = begin <= i < end + # Inside window -> use configured noise_level & record log_prob. + # Outside window -> deterministic ODE; skip log_prob (otherwise + # std_dev_t = 0 -> log(0) = -inf in sample_previous_step). + cur_noise_level = self._base_noise_level if in_window else 0.0 + cur_return_logprobs = self._base_return_logprobs and in_window + kwargs = { + **kwargs, + "noise_level": cur_noise_level, + "return_logprobs": cur_return_logprobs, + } + # else: pass caller kwargs through unchanged (legacy behavior). + + out = self._inner.step( + model_output=model_output, + timestep=sigma, + sample=sample, + return_dict=False, + **kwargs, + ) + self._step_counter += 1 + # step() with return_dict=False returns + # (prev_sample, log_prob, prev_sample_mean, std_dev_t) + prev_sample, log_prob = out[0], out[1] + # BAGEL packs latents as ``(num_tokens, channels)`` per request (no + # explicit batch dim, since rollout handles one sample at a time). + # The inner scheduler's ``mean(dim=tuple(range(1, ndim)))`` therefore + # preserves the spatial-token dim, yielding ``(num_tokens,)`` instead + # of a scalar. At training time, latents *do* have a batch dim + # (``(B, num_tokens, channels)``), so the same reduction yields + # ``(B,)``. 
Without the extra reduction here, ``old_log_probs`` ends + # up with shape ``(B, num_steps, num_tokens)`` while training's + # ``log_prob`` is ``(B,)`` — broadcasting then fails inside + # ``compute_diffusion_loss_flow_grpo``. Match the original flow_grpo + # behavior (``log_prob.mean()`` in ``bagel.py::_sde_step_with_logprob``) + # by fully reducing to a per-step scalar. + if log_prob is not None: + log_prob = log_prob.mean() + return _AdapterStepOutput(prev_sample=prev_sample, log_prob=log_prob) + + +def _pick_sde_window( + window_size: Optional[int], + window_range: Optional[Any], + seed: Optional[int], + request_id: Optional[str], +) -> Optional[tuple[int, int]]: + """Pick a contiguous SDE window ``[begin, begin + window_size)`` randomly + inside ``window_range`` (inclusive). Returns ``None`` when windowing is + disabled (``window_size`` is ``None`` or 0). + + Reproducibility: + * ``seed != None`` -> seed Python RNG with ``seed``. + * ``request_id`` provided -> seed RNG with sha256(request_id) so + concurrent requests inside one process still get different windows. + * Both ``None`` -> use default RNG (non-deterministic). + + This mirrors flow_grpo's ``random.randint(window_range[0], + window_range[1] - window_size)`` call but is seeded per-request rather + than per-process so that different rollouts inside one replica explore + different parts of the trajectory. + """ + if window_size is None or int(window_size) <= 0: + return None + if window_range is None: + return (0, int(window_size)) + + low = int(window_range[0]) + high = int(window_range[1]) + high_inclusive = high - int(window_size) + if high_inclusive < low: + # Window doesn't fit; clamp to the lowest valid begin. 
+ return (low, low + int(window_size)) + + if seed is not None: + rng = random.Random(int(seed)) + elif request_id is not None: + h = hashlib.sha256(str(request_id).encode()).digest() + rng = random.Random(int.from_bytes(h[:8], "big")) + else: + rng = random.Random() + begin = rng.randint(low, high_inclusive) + return (begin, begin + int(window_size)) + + +@VllmOmniPipelineBase.register("OmniBagelForConditionalGeneration", algorithm="flow_grpo") +class BagelPipelineWithLogProb(BagelPipeline): + """BAGEL pipeline variant for RL rollouts with verl-omni.""" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + inner = FlowMatchSDEDiscreteScheduler() + self.scheduler = _BagelSchedulerAdapter(inner) + logger.info("BagelPipelineWithLogProb: SDE scheduler enabled for RL rollouts") + + def _decode_token_prompt(self, token_ids: Any) -> str | None: + token_list = _to_token_list(token_ids) + if not token_list: + return None + decoded = self.tokenizer.decode(token_list, skip_special_tokens=False) + return _extract_prompt_text(decoded) + + def _ensure_bagel_prompt_text(self, req: OmniDiffusionRequest) -> None: + if not req.prompts or not isinstance(req.prompts[0], dict): + return + + custom_prompt = req.prompts[0] + if not custom_prompt.get("prompt"): + prompt = self._decode_token_prompt(custom_prompt.get("prompt_token_ids")) + if prompt is not None: + custom_prompt["prompt"] = prompt + + extra_args = req.sampling_params.extra_args + if "negative_prompt" not in extra_args: + negative_prompt = self._decode_token_prompt(custom_prompt.get("negative_prompt_ids")) + if negative_prompt is not None: + extra_args["negative_prompt"] = negative_prompt + + prompt_extra_args = custom_prompt.get("extra_args") + if isinstance(prompt_extra_args, dict): + multi_modal_data = prompt_extra_args.get("multi_modal_data") + if multi_modal_data is not None and "multi_modal_data" not in custom_prompt: + 
custom_prompt["multi_modal_data"] = multi_modal_data + + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + self._ensure_bagel_prompt_text(req) + + # Force trajectory recording on for RL + req.sampling_params.return_trajectory_latents = True + + extra_args = req.sampling_params.extra_args + + # ----------------------------------------------------------------- # + # Bug 2 fix (rollout side): # + # Force the BAGEL CFG kwargs that BagelPipeline.forward reads from # + # ``extra_args`` to flow_grpo's training defaults *unless* the # + # caller has explicitly set them. This keeps the rollout (behavior # + # policy) byte-identical with the CFG that the training adapter # + # applies when recomputing log-probs, so the importance-sampling # + # ratio is not silently biased by a CFG mismatch. # + # ----------------------------------------------------------------- # + for k, v in BAGEL_FLOWGRPO_CFG_DEFAULTS.items(): + extra_args.setdefault(k, v) + # OmegaConf delivers tuples as lists; BagelPipeline expects a tuple. + if isinstance(extra_args.get("cfg_interval"), list): + extra_args["cfg_interval"] = tuple(extra_args["cfg_interval"]) + + # ----------------------------------------------------------------- # + # Bug 1 fix: SDE windowing. # + # Pick a contiguous denoising-step window inside which we apply SDE # + # noise and record (latent, log_prob, sigma) triplets; the rest of # + # the trajectory runs deterministically (ODE). This mirrors # + # generate_image() in flow_grpo's bagel.py and is what the training # + # loss expects when iterating ``range(num_recorded_steps)``. 
# + # ----------------------------------------------------------------- # + logprobs = bool(extra_args.get("logprobs", True)) + noise_level = float(extra_args.get("noise_level", 0.0)) + sde_window_size = extra_args.get("sde_window_size", None) + sde_window_range = extra_args.get("sde_window_range", None) + if isinstance(sde_window_range, list): + sde_window_range = tuple(sde_window_range) + + sde_window: Optional[tuple[int, int]] = None + if sde_window_size and noise_level > 0.0: + sde_window = _pick_sde_window( + window_size=int(sde_window_size), + window_range=sde_window_range, + seed=req.sampling_params.seed, + request_id=getattr(req, "request_id", None), + ) + + # Scheduler kwargs passed to every step. ``_BagelSchedulerAdapter`` + # overrides ``noise_level`` and ``return_logprobs`` per-step based + # on whether the step is inside ``sde_window``; we still pass these + # defaults so the no-window legacy path works. + self.scheduler_kwargs = { + k: extra_args[k] for k in ("noise_level", "sde_type", "generator") if k in extra_args + } + self.scheduler_kwargs["return_logprobs"] = logprobs + + # Per-request scheduler setup: compute BAGEL's shifted sigmas so + # the inner SDE scheduler's sigma schedule matches what + # generate_image() computes internally. 
+ assert req.sampling_params.num_inference_steps is not None, "num_inference_steps must be set for RL rollouts" + num_timesteps = req.sampling_params.num_inference_steps + timestep_shift = 3.0 # must match BagelPipeline.forward() hardcoded value + + t = np.linspace(1, 0, num_timesteps) + t_shifted = timestep_shift * t / (1 + (timestep_shift - 1) * t) + sigmas = t_shifted[:-1].tolist() # drop terminal 0; set_timesteps appends it + + inner = self.scheduler._inner + inner.set_shift(1.0) # identity — sigmas already shifted + inner.set_timesteps(sigmas=sigmas) + inner.set_begin_index(0) + + # Reset the stateful adapter for this request (must happen *after* + # ``set_timesteps`` so that the inner step_index is None again). + self.scheduler.begin_forward( + sde_window=sde_window, + noise_level=noise_level, + return_logprobs=logprobs, + ) + + output = super().forward(req) + + # ----------------------------------------------------------------- # + # Slice the recorded trajectory to the SDE window so the training # + # adapter iterates only on steps where noise was actually injected. # + # ``trajectory_log_probs`` is already window-length (None entries # + # are dropped in generate_image), but ``trajectory_latents`` and # + # ``trajectory_timesteps`` record *all* steps unconditionally. # + # ----------------------------------------------------------------- # + traj_latents = output.trajectory_latents + traj_timesteps = output.trajectory_timesteps + traj_log_probs = output.trajectory_log_probs + + if sde_window is not None: + begin, end = sde_window + if traj_latents is not None: + # shape: (num_steps + 1, ...); keep x at sigma_begin .. sigma_end + traj_latents = traj_latents[begin : end + 1] + if traj_timesteps is not None: + # shape: (num_steps,); keep sigma_begin .. sigma_{end-1} + traj_timesteps = traj_timesteps[begin:end] + # traj_log_probs already has length (end - begin) thanks to + # _BagelSchedulerAdapter gating; no slicing needed. 
+ + custom = output.custom_output or {} + if traj_latents is not None: + custom["all_latents"] = _to_cpu_tensor(traj_latents) + if traj_timesteps is not None: + custom["all_timesteps"] = _to_cpu_tensor(traj_timesteps) + if traj_log_probs is not None: + custom["all_log_probs"] = _to_cpu_tensor(traj_log_probs) + output.custom_output = custom + + return output diff --git a/verl_omni/pipelines/model_base.py b/verl_omni/pipelines/model_base.py index 36455702..ffec351d 100644 --- a/verl_omni/pipelines/model_base.py +++ b/verl_omni/pipelines/model_base.py @@ -76,6 +76,16 @@ def get_class(cls, model_config: DiffusionModelConfig) -> type["DiffusionModelBa f"Set ``external_lib`` in DiffusionModelConfig to load your implementation." ) from None + @classmethod + def build_module(cls, model_config: DiffusionModelConfig, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]: + """Optional hook for custom model loading. + + Override this to load non-standard models (e.g. models not loadable + via ``diffusers.AutoModel``). Return ``None`` to fall back to the + default ``AutoModel.from_pretrained`` path in the FSDP engine. 
+ """ + return None + @classmethod @abstractmethod def build_scheduler(cls, model_config: DiffusionModelConfig) -> SchedulerMixin: diff --git a/verl_omni/pipelines/qwen_image_flow_grpo/vllm_omni_rollout_adapter.py b/verl_omni/pipelines/qwen_image_flow_grpo/vllm_omni_rollout_adapter.py index aeb73013..34fa2a7b 100644 --- a/verl_omni/pipelines/qwen_image_flow_grpo/vllm_omni_rollout_adapter.py +++ b/verl_omni/pipelines/qwen_image_flow_grpo/vllm_omni_rollout_adapter.py @@ -66,20 +66,20 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): def _get_qwen_prompt_embeds( self, - prompt_ids: torch.Tensor, + prompt_token_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, dtype: torch.dtype | None = None, ): dtype = dtype or self.text_encoder.dtype if attention_mask is None: - attention_mask = torch.ones_like(prompt_ids, dtype=torch.long) + attention_mask = torch.ones_like(prompt_token_ids, dtype=torch.long) - prompt_ids = prompt_ids.unsqueeze(0) if prompt_ids.ndim == 1 else prompt_ids + prompt_token_ids = prompt_token_ids.unsqueeze(0) if prompt_token_ids.ndim == 1 else prompt_token_ids attention_mask = attention_mask.unsqueeze(0) if attention_mask.ndim == 1 else attention_mask drop_idx = self.prompt_template_encode_start_idx encoder_hidden_states = self.text_encoder( - input_ids=prompt_ids.to(self.device), + input_ids=prompt_token_ids.to(self.device), attention_mask=attention_mask.to(self.device), output_hidden_states=True, ) @@ -101,7 +101,7 @@ def _get_qwen_prompt_embeds( def encode_prompt( self, - prompt_ids: torch.Tensor, + prompt_token_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, num_images_per_prompt: int = 1, prompt_embeds: torch.Tensor | None = None, @@ -111,13 +111,13 @@ def encode_prompt( """Encode text prompt token IDs into dense embeddings. Args: - prompt_ids (torch.Tensor): Token IDs of shape ``(B, L)`` or ``(L,)``. + prompt_token_ids (torch.Tensor): Token IDs of shape ``(B, L)`` or ``(L,)``. 
attention_mask (torch.Tensor, *optional*): Boolean mask of shape - ``(B, L)`` for *prompt_ids*; inferred as all-ones when ``None``. + ``(B, L)`` for *prompt_token_ids*; inferred as all-ones when ``None``. num_images_per_prompt (int): Number of images to generate per prompt; embeddings are repeated accordingly. prompt_embeds (torch.Tensor, *optional*): Pre-computed embeddings; - when provided *prompt_ids* is ignored. + when provided *prompt_token_ids* is ignored. prompt_embeds_mask (torch.Tensor, *optional*): Attention mask for pre-computed *prompt_embeds*. max_sequence_length (int): Maximum sequence length; embeddings are @@ -129,13 +129,15 @@ def encode_prompt( ``(B * num_images_per_prompt, L, D)`` and ``(B * num_images_per_prompt, L)`` respectively. """ - prompt_ids = prompt_ids.unsqueeze(0) if prompt_ids.ndim == 1 else prompt_ids + prompt_token_ids = prompt_token_ids.unsqueeze(0) if prompt_token_ids.ndim == 1 else prompt_token_ids attention_mask = ( attention_mask.unsqueeze(0) if attention_mask is not None and attention_mask.ndim == 1 else attention_mask ) if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt_ids, attention_mask=attention_mask) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt_token_ids, attention_mask=attention_mask + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -276,7 +278,7 @@ def diffuse( def forward( self, req: OmniDiffusionRequest, - prompt_ids: torch.Tensor | list[int] | None = None, + prompt_token_ids: torch.Tensor | list[int] | None = None, prompt_mask: torch.Tensor | None = None, negative_prompt_ids: torch.Tensor | list[int] | None = None, negative_prompt_mask: torch.Tensor | None = None, @@ -312,9 +314,9 @@ def forward( Args: req (OmniDiffusionRequest): Rollout request containing prompts and :class:`~vllm_omni.diffusion.data.OmniDiffusionSamplingParams`. 
- prompt_ids (torch.Tensor | list[int], *optional*): Token IDs for + prompt_token_ids (torch.Tensor | list[int], *optional*): Token IDs for the positive prompt. - prompt_mask (torch.Tensor, *optional*): Attention mask for *prompt_ids*. + prompt_mask (torch.Tensor, *optional*): Attention mask for *prompt_token_ids*. negative_prompt_ids (torch.Tensor | list[int], *optional*): Token IDs for the negative prompt used in True-CFG. negative_prompt_mask (torch.Tensor, *optional*): Attention mask for @@ -364,7 +366,7 @@ def forward( """ custom_prompt = req.prompts[0] if req.prompts else {} if isinstance(custom_prompt, dict): - prompt_ids = custom_prompt.get("prompt_ids", prompt_ids) + prompt_token_ids = custom_prompt.get("prompt_token_ids", prompt_token_ids) prompt_mask = custom_prompt.get("prompt_mask", prompt_mask) negative_prompt_ids = custom_prompt.get("negative_prompt_ids", negative_prompt_ids) negative_prompt_mask = custom_prompt.get("negative_prompt_mask", negative_prompt_mask) @@ -396,14 +398,14 @@ def forward( self._current_timestep = None self._interrupt = False - if prompt_ids is not None: - if isinstance(prompt_ids, list): - prompt_ids = torch.tensor(prompt_ids, device=self.device) - batch_size = prompt_ids.shape[0] if prompt_ids.ndim == 2 else 1 + if prompt_token_ids is not None: + if isinstance(prompt_token_ids, list): + prompt_token_ids = torch.tensor(prompt_token_ids, device=self.device) + batch_size = prompt_token_ids.shape[0] if prompt_token_ids.ndim == 2 else 1 elif prompt_embeds is not None: batch_size = prompt_embeds.shape[0] else: - # Both prompt_ids and prompt_embeds are None (e.g. during warmup/dummy run). + # Both prompt_token_ids and prompt_embeds are None (e.g. during warmup/dummy run). # Return a minimal dummy output to avoid crashing. 
return DiffusionOutput(output=None, custom_output={}) @@ -416,7 +418,7 @@ def forward( do_true_cfg = true_cfg_scale > 1 and has_neg_prompt prompt_embeds, prompt_embeds_mask = self.encode_prompt( - prompt_ids=prompt_ids, + prompt_token_ids=prompt_token_ids, attention_mask=prompt_mask, prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, @@ -425,7 +427,7 @@ def forward( ) if do_true_cfg: negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( - prompt_ids=negative_prompt_ids, + prompt_token_ids=negative_prompt_ids, attention_mask=negative_prompt_mask, prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=negative_prompt_embeds_mask, @@ -516,12 +518,16 @@ def forward( return DiffusionOutput( output=_maybe_to_cpu(image), custom_output={ - "all_latents": _maybe_to_cpu(all_latents), - "all_log_probs": _maybe_to_cpu(all_log_probs), - "all_timesteps": _maybe_to_cpu(all_timesteps), - "prompt_embeds": _maybe_to_cpu(prompt_embeds), - "prompt_embeds_mask": _maybe_to_cpu(prompt_embeds_mask), - "negative_prompt_embeds": _maybe_to_cpu(negative_prompt_embeds), - "negative_prompt_embeds_mask": _maybe_to_cpu(negative_prompt_embeds_mask), + "all_latents": _maybe_to_cpu(all_latents[0]), + "all_log_probs": _maybe_to_cpu(all_log_probs[0]) if all_log_probs is not None else None, + "all_timesteps": _maybe_to_cpu(all_timesteps[0]), + "prompt_embeds": _maybe_to_cpu(prompt_embeds[0]), + "prompt_embeds_mask": _maybe_to_cpu(prompt_embeds_mask[0]) if prompt_embeds_mask is not None else None, + "negative_prompt_embeds": _maybe_to_cpu(negative_prompt_embeds[0]) + if negative_prompt_embeds is not None + else None, + "negative_prompt_embeds_mask": _maybe_to_cpu(negative_prompt_embeds_mask[0]) + if negative_prompt_embeds_mask is not None + else None, }, ) diff --git a/verl_omni/trainer/diffusion/ray_diffusion_trainer.py b/verl_omni/trainer/diffusion/ray_diffusion_trainer.py index 954e3a18..6600d9e8 100644 --- a/verl_omni/trainer/diffusion/ray_diffusion_trainer.py 
+++ b/verl_omni/trainer/diffusion/ray_diffusion_trainer.py @@ -28,6 +28,7 @@ import torch from omegaconf import OmegaConf, open_dict from PIL import Image +from tensordict import TensorDict from torch.utils.data import Dataset, Sampler from torchdata.stateful_dataloader import StatefulDataLoader from tqdm import tqdm @@ -60,6 +61,47 @@ from verl_omni.workers.utils.padding import embeds_padding_2_no_padding +def _patch_reward_loop_workers_with_gpu() -> None: + """Let CLIP-based diffusion reward fns (e.g. PickScore) run on a GPU. + + Upstream ``RewardLoopManager`` creates Ray actors with no GPU, which is + fine for text/LLM reward but kills throughput for CLIP-class image rewards. + Set ``VERL_OMNI_REWARD_WORKER_NUM_GPUS`` to a positive number to claim that + many GPUs per reward worker (unset or ``0`` keeps the upstream behavior). + """ + from verl.experimental.reward_loop import reward_loop as _rl_mod + + if getattr(_rl_mod.RewardLoopManager, "_diffusion_gpu_patched", False): + return + + _orig = _rl_mod.RewardLoopManager._init_reward_loop_workers + + def _init_with_gpu(self): + num_gpus = float(os.environ.get("VERL_OMNI_REWARD_WORKER_NUM_GPUS", "0")) + if num_gpus <= 0: + return _orig(self) + self.reward_loop_workers = [] + n_workers = self.config.reward.num_workers + node_ids = [n["NodeID"] for n in ray.nodes() if n["Alive"] and n["Resources"].get("CPU", 0) > 0] + for i in range(n_workers): + node_id = node_ids[i % len(node_ids)] + self.reward_loop_workers.append( + self.reward_loop_workers_class.options( + name=f"reward_loop_worker_{i}", + num_gpus=num_gpus, + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + node_id=node_id, soft=True, + ), + ).remote(self.config, self.reward_router_address) + ) + + _rl_mod.RewardLoopManager._init_reward_loop_workers = _init_with_gpu + _rl_mod.RewardLoopManager._diffusion_gpu_patched = True + + +_patch_reward_loop_workers_with_gpu() + + def compute_advantage( data: DataProto, adv_estimator: str, @@ -106,6 +148,25 @@ def compute_advantage( return data +def
compute_logprob_alignment_metrics(batch: DataProto) -> dict[str, float]: + """Compare rollout-time log-probs with trainer-recomputed old log-probs. + + For Flow-GRPO the old-policy log-prob should match the log-prob recorded + during rollout before any actor update. A persistent offset here means the + rollout and training code paths are still not scoring the same trajectory. + """ + if "rollout_log_probs" not in batch.batch or "old_log_probs" not in batch.batch: + return {} + + diff = (batch.batch["old_log_probs"].float() - batch.batch["rollout_log_probs"].float()).detach() + return { + "debug/logprob_alignment/mean": diff.mean().item(), + "debug/logprob_alignment/mean_abs": diff.abs().mean().item(), + "debug/logprob_alignment/max_abs": diff.abs().max().item(), + "debug/logprob_alignment/std": diff.std().item(), + } + + class RayFlowGRPOTrainer: """Distributed Flow-GRPO trainer using Ray for scalable reinforcement learning. @@ -379,13 +440,46 @@ def _get_gen_batch(self, batch: DataProto) -> DataProto: return gen_batch - def _compute_reward_colocate(self, batch: DataProto) -> tuple[torch.Tensor, dict[str, Any]] | torch.Tensor: - """ - compute reward use colocate reward model + def _compute_reward_colocate(self, batch: DataProto) -> DataProto: + """Compute per-sample diffusion reward via the colocated reward loop. + + Bypasses ``RewardLoopManager.compute_rm_score`` (LLM-only: assumes + ``responses`` has a token axis and reads ``attention_mask``) and + assembles a ``[B, 1]`` ``rm_scores`` tensor directly. 
""" assert self.reward_loop_manager is not None, "RewardLoopManager is None" - batch_reward = self.reward_loop_manager.compute_rm_score(batch) - return batch_reward + manager = self.reward_loop_manager + + if manager.reward_model_manager is not None: + manager.reward_model_manager.wake_up() + + chunks = batch.chunk(len(manager.reward_loop_workers)) + outputs = ray.get( + [ + worker.compute_score_batch.remote(chunk) + for worker, chunk in zip(manager.reward_loop_workers, chunks, strict=True) + ] + ) + outputs_flat = [item for sublist in outputs for item in sublist] + + scores = [item["reward_score"] for item in outputs_flat] + rm_scores = torch.tensor(scores, dtype=torch.float32).unsqueeze(-1) + reward_batch = TensorDict({"rm_scores": rm_scores}, batch_size=len(batch)) + + reward_extra_infos = [output.get("reward_extra_info", {}) for output in outputs_flat] + reward_extra_keys = list(reward_extra_infos[0].keys()) if reward_extra_infos else [] + non_tensor_batch = { + key: np.array([info[key] for info in reward_extra_infos]) for key in reward_extra_keys + } + + if manager.reward_model_manager is not None: + manager.reward_model_manager.sleep() + + return DataProto( + batch=reward_batch, + non_tensor_batch=non_tensor_batch, + meta_info={"reward_extra_keys": reward_extra_keys}, + ) def _validate(self): data_source_lst = [] @@ -968,6 +1062,7 @@ def fit(self): batch = batch.union(old_log_prob) assert "old_log_probs" in batch.batch, f'"old_log_prob" not in {batch.batch.keys()=}' + metrics.update(compute_logprob_alignment_metrics(batch)) if self.use_reference_policy: # compute reference log_prob diff --git a/verl_omni/workers/engine/fsdp/diffusers_impl.py b/verl_omni/workers/engine/fsdp/diffusers_impl.py index d55b6538..21e8d279 100644 --- a/verl_omni/workers/engine/fsdp/diffusers_impl.py +++ b/verl_omni/workers/engine/fsdp/diffusers_impl.py @@ -185,8 +185,47 @@ def _init_device_mesh(self): self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 + def 
_build_module_from_registry(self, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]: + """Build the model via a registered ``DiffusionModelBase`` subclass. + + Returns ``None`` when the subclass does not provide a custom loader. + Custom loaders bypass ``diffusers.AutoModel``, so engine-level hooks + (attention processors, gradient checkpointing, LoRA, dtype upcast) + may be silently inactive on the returned module. + + TODO: drop this function once the model is integrated into a + first-class engine (``transformers`` / ``diffusers`` / ``veomni``). + """ + from verl_omni.pipelines.model_base import DiffusionModelBase + + model_cls = DiffusionModelBase.get_class(self.model_config) + module = model_cls.build_module(self.model_config, torch_dtype) + if module is None: + return None + + logger.warning( + "Built %s via DiffusionModelBase custom loader; engine-level hooks " + "(attention processors, gradient-checkpointing wrappers, LoRA, " + "dtype upcast) may be partially effective or silently inactive. 
" + "See the docstring of _build_module_from_registry.", + type(module).__name__, + ) + + module.to(torch_dtype) + if self.model_config.enable_gradient_checkpointing: + enable_checkpointing = getattr(module, "enable_gradient_checkpointing", None) + if callable(enable_checkpointing): + enable_checkpointing() + else: + logger.warning( + "Gradient checkpointing requested, but %s does not implement enable_gradient_checkpointing()", + type(module).__name__, + ) + if not hasattr(module, "can_generate"): + module.can_generate = lambda: False + return module + def _build_module(self): - from diffusers import AutoModel from verl.utils.torch_dtypes import PrecisionType torch_dtype = self.engine_config.model_dtype @@ -197,6 +236,13 @@ def _build_module(self): torch_dtype = PrecisionType.to_dtype(torch_dtype) + module = self._build_module_from_registry(torch_dtype) + if module is not None: + return module + + # Default path: load via diffusers AutoModel + from diffusers import AutoModel + init_context = get_init_weight_context_manager(use_meta_tensor=True, mesh=self.device_mesh) with init_context(), warnings.catch_warnings(): @@ -576,16 +622,16 @@ def prepare_model_inputs(self, micro_batch: TensorDict, step: int): """ latents = micro_batch["all_latents"] timesteps = micro_batch["all_timesteps"] - prompt_embeds = micro_batch["prompt_embeds"] - prompt_embeds_mask = micro_batch["prompt_embeds_mask"] - negative_prompt_embeds = micro_batch["negative_prompt_embeds"] - negative_prompt_embeds_mask = micro_batch["negative_prompt_embeds_mask"] + prompt_embeds = micro_batch.get("prompt_embeds", None) + prompt_embeds_mask = micro_batch.get("prompt_embeds_mask", None) + negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None) + negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None) sp_size = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1 - if prompt_embeds.is_nested: + if isinstance(prompt_embeds, torch.Tensor) and 
prompt_embeds.is_nested: prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask) - if sp_size > 1: + if isinstance(prompt_embeds, torch.Tensor) and sp_size > 1: prompt_embeds, prompt_embeds_mask = self._pad_embeds_for_sp(prompt_embeds, prompt_embeds_mask, sp_size) if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested: diff --git a/verl_omni/workers/rollout/vllm_rollout/utils.py b/verl_omni/workers/rollout/vllm_rollout/utils.py index 6d419359..6986b4c9 100644 --- a/verl_omni/workers/rollout/vllm_rollout/utils.py +++ b/verl_omni/workers/rollout/vllm_rollout/utils.py @@ -15,7 +15,7 @@ import os import torch -from verl.workers.rollout.vllm_rollout.utils import VLLM_LORA_INT_ID, VLLM_LORA_NAME, VLLM_LORA_PATH, set_death_signal +from verl.workers.rollout.vllm_rollout.utils import VLLM_LORA_INT_ID, VLLM_LORA_NAME, VLLM_LORA_PATH from vllm_omni.diffusion.worker.diffusion_worker import CustomPipelineWorkerExtension from verl_omni.utils.vllm_omni import OmniTensorLoRARequest, VLLMOmniHijack @@ -38,8 +38,10 @@ class vLLMOmniColocateWorkerExtension(CustomPipelineWorkerExtension): """ def __new__(cls, **kwargs): - set_death_signal() - + # Do NOT call verl's ``set_death_signal``: ``PR_SET_PDEATHSIG`` is + # thread-scoped, and vllm-omni spawns diffusion workers from a short-lived + # ``ThreadPoolExecutor`` thread, which would SIGKILL them on thread exit. + # ``DiffusionWorker`` already runs as ``daemon=True``. # 1. 
patch for Lora VLLMOmniHijack.hijack() diff --git a/verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.py b/verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.py index 3b2187bc..767ddf94 100644 --- a/verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.py +++ b/verl_omni/workers/rollout/vllm_rollout/vllm_omni_async_server.py @@ -80,6 +80,10 @@ def _get_override_generation_config(self) -> dict: def _get_engine_kwargs_key(self) -> str: return "vllm_omni" + def _preprocess_engine_kwargs(self, engine_kwargs: dict) -> None: + # No-op: ``deploy_config`` is a vllm-omni CLI flag and must reach the parser. + return + def _get_worker_extension_cls(self) -> str: return "verl_omni.workers.rollout.vllm_rollout.utils.vLLMOmniColocateWorkerExtension" @@ -97,6 +101,17 @@ async def run_server(self, args: argparse.Namespace): engine_args = OmniEngineArgs.from_cli_args(args) engine_args = asdict(engine_args) + # ``deploy_config`` lives on ``OrchestratorArgs``, not ``OmniEngineArgs``, + # so ``from_cli_args`` drops it; forward it manually. + deploy_config = getattr(args, "deploy_config", None) + if deploy_config is not None: + engine_args["deploy_config"] = deploy_config + + # Drop verl's injected ``compilation_config``: re-validation under + # pydantic-strict rejects ``CompilationConfig``'s default ``None`` fields. + # BAGEL's deploy YAML sets ``enforce_eager: true``, so this is a no-op. 
+ engine_args.pop("compilation_config", None) + import_external_libs(self.config.external_lib) pipeline_path = VllmOmniPipelineBase.get_pipeline_path( architecture=self.model_config.architecture, @@ -149,9 +164,12 @@ async def generate( video_data: Optional[list[Any]] = None, negative_prompt_ids: Optional[list[int]] = None, priority: int = 0, + lora_request: Optional[LoRARequest] = None, + lora_scale: float = 1.0, ) -> DiffusionOutput: """Generate sequence with token-in-image-out.""" prompt_ids = normalize_token_ids(prompt_ids) + default_params_list = self.engine.default_sampling_params_list multi_modal_data = {} if image_data is not None: @@ -159,9 +177,8 @@ async def generate( if video_data is not None: multi_modal_data["video"] = video_data - # Add lora request - lora_request = None - if self.lora_as_adapter: + # Add lora request (caller-supplied takes precedence over lora_as_adapter) + if lora_request is None and self.lora_as_adapter: # Make sure we also check that the lora is already loaded in the engine lora_loaded = VLLM_LORA_INT_ID in await self.engine.list_loras() if lora_loaded: @@ -169,11 +186,14 @@ async def generate( lora_name=VLLM_LORA_NAME, lora_int_id=VLLM_LORA_INT_ID, lora_path=VLLM_LORA_PATH ) - # Build OmniCustomPrompt with pre-tokenized IDs - custom_prompt: OmniCustomPrompt = {"prompt_ids": prompt_ids} + # Build OmniCustomPrompt with pre-tokenized IDs (downstream pipelines read "prompt_token_ids") + custom_prompt: OmniCustomPrompt = {"prompt_token_ids": prompt_ids} + if len(default_params_list) > 1: + custom_prompt["modalities"] = ["image"] if negative_prompt_ids is not None: custom_prompt["negative_prompt_ids"] = negative_prompt_ids if multi_modal_data: + custom_prompt["multi_modal_data"] = multi_modal_data custom_prompt["extra_args"] = {"multi_modal_data": multi_modal_data} # Build OmniDiffusionSamplingParams from the incoming dict @@ -187,13 +207,17 @@ async def generate( sampling_kwargs["extra_args"] = extra_args if lora_request is not None: 
sampling_kwargs["lora_request"] = lora_request + sampling_kwargs["lora_scale"] = lora_scale diffusion_sampling_params = OmniDiffusionSamplingParams(**sampling_kwargs) + # Build sampling params list: multi-stage models use defaults for non-diffusion stages + sampling_params_list = default_params_list[:-1] + [diffusion_sampling_params] + # Call AsyncOmni.generate() with the correct API generator = self.engine.generate( prompt=custom_prompt, request_id=request_id, - sampling_params_list=[diffusion_sampling_params], + sampling_params_list=sampling_params_list, ) # Get final response @@ -208,27 +232,17 @@ async def generate( mm_output = final_res.custom_output or {} if sampling_params.get("logprobs", False): - all_log_probs = mm_output.get("all_log_probs") - log_probs = all_log_probs[0] if all_log_probs is not None else None + log_probs = mm_output.get("all_log_probs") else: log_probs = None - all_latents = mm_output.get("all_latents") - all_timesteps = mm_output.get("all_timesteps") - prompt_embeds = mm_output.get("prompt_embeds") - prompt_embeds_mask = mm_output.get("prompt_embeds_mask") - negative_prompt_embeds = mm_output.get("negative_prompt_embeds") - negative_prompt_embeds_mask = mm_output.get("negative_prompt_embeds_mask") - extra_fields = { - "all_latents": all_latents[0] if all_latents is not None else None, - "all_timesteps": all_timesteps[0] if all_timesteps is not None else None, - "prompt_embeds": prompt_embeds[0] if prompt_embeds is not None else None, - "prompt_embeds_mask": prompt_embeds_mask[0] if prompt_embeds_mask is not None else None, - "negative_prompt_embeds": negative_prompt_embeds[0] if negative_prompt_embeds is not None else None, - "negative_prompt_embeds_mask": negative_prompt_embeds_mask[0] - if negative_prompt_embeds_mask is not None - else None, + "all_latents": mm_output.get("all_latents"), + "all_timesteps": mm_output.get("all_timesteps"), + "prompt_embeds": mm_output.get("prompt_embeds"), + "prompt_embeds_mask": 
mm_output.get("prompt_embeds_mask"), + "negative_prompt_embeds": mm_output.get("negative_prompt_embeds"), + "negative_prompt_embeds_mask": mm_output.get("negative_prompt_embeds_mask"), "global_steps": self.global_steps, } diff --git a/verl_omni/workers/utils/padding.py b/verl_omni/workers/utils/padding.py index 9f85b4db..5fb50163 100644 --- a/verl_omni/workers/utils/padding.py +++ b/verl_omni/workers/utils/padding.py @@ -47,7 +47,10 @@ def _to_nested(embeds: torch.Tensor, mask: torch.Tensor): torch.nested.as_nested_tensor(mask_list, layout=torch.jagged), ) - data["prompt_embeds"], data["prompt_embeds_mask"] = _to_nested(data["prompt_embeds"], data["prompt_embeds_mask"]) + if isinstance(data.get("prompt_embeds", None), torch.Tensor): + data["prompt_embeds"], data["prompt_embeds_mask"] = _to_nested( + data["prompt_embeds"], data["prompt_embeds_mask"] + ) if isinstance(data.get("negative_prompt_embeds", None), torch.Tensor): data["negative_prompt_embeds"], data["negative_prompt_embeds_mask"] = _to_nested(