diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7417fe74b55..b789db716cb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -58,7 +58,7 @@ repos: - id: check-naming-conventions name: Check naming conventions - entry: sh -c 'fail=0; if grep -rIn --exclude-dir=.git --exclude-dir=.github --exclude-dir=venv --exclude-dir=__pycache__ --exclude=.pre-commit-config.yaml "veRL" .; then echo "Please use verl instead of veRL"; fail=1; fi; if grep -rIn --exclude-dir=.git --exclude-dir=.github --exclude-dir=venv --exclude-dir=__pycache__ --exclude=ascend_sglang_best_practices.rst --exclude=.pre-commit-config.yaml -E "Sglang|sgLang|sglAng|sglaNg|sglanG" .; then echo "Please use SGLang or sglang"; fail=1; fi; exit $fail' + entry: sh -c 'fail=0; if grep -rIn --exclude-dir=.git --exclude-dir=.github --exclude-dir=venv --exclude-dir=.venv --exclude-dir=__pycache__ --exclude=.pre-commit-config.yaml "veRL" .; then echo "Please use verl instead of veRL"; fail=1; fi; if grep -rIn --exclude-dir=.git --exclude-dir=.github --exclude-dir=venv --exclude-dir=.venv --exclude-dir=__pycache__ --exclude=ascend_sglang_best_practices.rst --exclude=.pre-commit-config.yaml -E "Sglang|sgLang|sglAng|sglaNg|sglanG" .; then echo "Please use SGLang or sglang"; fail=1; fi; exit $fail' language: system pass_filenames: false diff --git a/examples/flowgrpo_trainer/bagel_stage_config.yaml b/examples/flowgrpo_trainer/bagel_stage_config.yaml new file mode 100644 index 00000000000..79ece0c8ea8 --- /dev/null +++ b/examples/flowgrpo_trainer/bagel_stage_config.yaml @@ -0,0 +1,25 @@ +# Single-stage Bagel config for FlowGRPO training with colocated workers. + +stage_args: + + - stage_id: 0 + stage_type: diffusion + runtime: + devices: "0" + engine_args: + model_stage: dit + max_num_seqs: 1 + enforce_eager: true + trust_remote_code: true + + final_output: true + final_output_type: image + is_comprehension: false + default_sampling_params: + seed: 52 + +runtime: + enabled: true + defaults: + window_size: -1 + max_inflight: 1 diff --git a/examples/flowgrpo_trainer/diffusers/bagel.py b/examples/flowgrpo_trainer/diffusers/bagel.py new file mode 100644 index 00000000000..4b97ebd3c47 --- /dev/null +++ b/examples/flowgrpo_trainer/diffusers/bagel.py @@ -0,0 +1,196 @@ +"""BAGEL (MoT) diffusion model implementation for FlowGRPO training. + +Registers as ``OmniBagelForConditionalGeneration`` so the FSDP engine +can load and train the model via the DiffusionModelBase registry. + +Key differences from standard diffusion models (e.g. QwenImage): + * BAGEL is a *Mixture-of-Thought* transformer that processes text token + IDs and noisy latent patches in a single forward pass (no separate + text encoder). + * ``prompt_embeds`` are not used. Instead, the raw prompt token IDs + (available as ``micro_batch["prompts"]``) are passed directly to the + model as ``text_token_ids``. + * CFG uses a 3-branch scheme during rollout, but for FSDP training + (computing log-probs of the rollout trajectory) only the conditional + forward is needed. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +import numpy as np +import torch +from tensordict import TensorDict + +from verl.models.diffusers_model import DiffusionModelBase +from verl.utils import tensordict_utils as tu +from verl.utils.device import get_device_name +from verl.workers.config import DiffusionModelConfig + +from ..scheduler import FlowMatchSDEDiscreteScheduler +from .bagel_model import BagelForTraining, get_flattened_position_ids + +logger = logging.getLogger(__name__) + +TIMESTEP_SHIFT = 3.0 # must match BagelPipeline.forward() hardcoded value + + +@DiffusionModelBase.register("OmniBagelForConditionalGeneration") +class BagelDiffusion(DiffusionModelBase): + """DiffusionModelBase wrapper for BagelForTraining (MoT).""" + + # ------------------------------------------------------------------ + # Custom model loading (BAGEL can't be loaded via diffusers.AutoModel) + # ------------------------------------------------------------------ + + @classmethod + def build_module(cls, model_config: DiffusionModelConfig, torch_dtype: torch.dtype): + logger.info("Loading BagelForTraining from %s", model_config.local_path) + module = BagelForTraining.from_pretrained( + model_config.local_path, torch_dtype=torch_dtype + ) + return module + + # ------------------------------------------------------------------ + # Scheduler + # ------------------------------------------------------------------ + + @classmethod + def build_scheduler(cls, model_config: DiffusionModelConfig): + scheduler = FlowMatchSDEDiscreteScheduler() + cls.set_timesteps(scheduler, model_config, get_device_name()) + return scheduler + + @classmethod + def set_timesteps(cls, scheduler: FlowMatchSDEDiscreteScheduler, model_config: DiffusionModelConfig, device: str): + num_inference_steps = model_config.num_inference_steps + t = np.linspace(1, 0, num_inference_steps) + t_shifted = TIMESTEP_SHIFT * t / (1 + (TIMESTEP_SHIFT - 1) * t) + sigmas = t_shifted[:-1].tolist() + + scheduler.set_shift(1.0) # identity — sigmas already shifted + scheduler.set_timesteps(sigmas=sigmas) + scheduler.set_begin_index(0) + + # ------------------------------------------------------------------ + # Prepare model inputs + # ------------------------------------------------------------------ + + @classmethod + def _get_latent_pos_ids(cls, model_config: DiffusionModelConfig, module, device) -> torch.Tensor: + """Compute latent position IDs from model config / image dimensions.""" + config = module.config + img_h = model_config.height // (config.latent_patch_size * config.vae_downsample) + img_w = model_config.width // (config.latent_patch_size * config.vae_downsample) + # Clamp to max_latent_size + img_h = min(img_h, config.max_latent_size) + img_w = min(img_w, config.max_latent_size) + latent_ds = config.latent_patch_size * config.vae_downsample + H_px = img_h * latent_ds + W_px = img_w * latent_ds + pos_ids = get_flattened_position_ids( + H_px, W_px, latent_ds, config.max_latent_size, + ) + return pos_ids.to(device) + + @classmethod + def prepare_model_inputs( + cls, + module, + model_config: DiffusionModelConfig, + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_prompt_embeds_mask: torch.Tensor, + micro_batch: TensorDict, + step: int, + ) -> tuple[dict, dict]: + B = latents.shape[0] + device = latents.device + + hidden_states = latents[:, step] + timestep = timesteps[:, step] + + # Extract text token IDs from prompt data + prompts = micro_batch["prompts"] # (B, L_prompt) padded + attention_mask = micro_batch["attention_mask"] # (B, L_prompt) + + # Build per-sample text_token_ids (remove padding) + text_token_ids_list = [] + for i in range(B): + mask = attention_mask[i].bool() + ids = prompts[i][mask] + text_token_ids_list.append(ids) + + # Pad to same length within batch + max_text_len = max(ids.shape[0] for ids in text_token_ids_list) + text_token_ids = torch.zeros(B, max_text_len, dtype=torch.long, device=device) + for i, ids in enumerate(text_token_ids_list): + text_token_ids[i, :ids.shape[0]] = ids + + # Compute latent position IDs + latent_pos_ids = cls._get_latent_pos_ids(model_config, module, device) + latent_pos_ids = latent_pos_ids.unsqueeze(0).expand(B, -1) + + model_inputs = { + "hidden_states": hidden_states, + "timestep": timestep, + "text_token_ids": text_token_ids, + "latent_pos_ids": latent_pos_ids, + } + + # For BAGEL, unconditional pass uses text_token_ids=None + negative_model_inputs = { + "hidden_states": hidden_states, + "timestep": timestep, + "text_token_ids": None, + "latent_pos_ids": latent_pos_ids, + } + + return model_inputs, negative_model_inputs + + # ------------------------------------------------------------------ + # Forward + scheduler step + # ------------------------------------------------------------------ + + @classmethod + def forward_and_sample_previous_step( + cls, + module, + scheduler: FlowMatchSDEDiscreteScheduler, + model_config: DiffusionModelConfig, + model_inputs: dict[str, torch.Tensor], + negative_model_inputs: Optional[dict[str, torch.Tensor]], + scheduler_inputs: Optional[TensorDict | dict[str, torch.Tensor]], + step: int, + ): + assert scheduler_inputs is not None + latents = scheduler_inputs["all_latents"] + timesteps = scheduler_inputs["all_timesteps"] + + noise_pred = module(**model_inputs)[0] + + # CFG during training (if configured) + true_cfg_scale = model_config.extra_configs.get("true_cfg_scale", 1.0) + if true_cfg_scale > 1.0: + assert negative_model_inputs is not None + neg_noise_pred = module(**negative_model_inputs)[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + _, log_prob, prev_sample_mean, std_dev_t = scheduler.sample_previous_step( + sample=latents[:, step].float(), + model_output=noise_pred.float(), + timestep=timesteps[:, step], + noise_level=model_config.extra_configs.get("noise_level", None), + prev_sample=latents[:, step + 1].float(), + sde_type=model_config.extra_configs.get("sde_type", None), + return_logprobs=True, + ) + return log_prob, prev_sample_mean, std_dev_t diff --git a/examples/flowgrpo_trainer/diffusers/bagel_model.py b/examples/flowgrpo_trainer/diffusers/bagel_model.py new file mode 100644 index 00000000000..e8aa3b63227 --- /dev/null +++ b/examples/flowgrpo_trainer/diffusers/bagel_model.py @@ -0,0 +1,817 @@ +"""BagelForTraining – FSDP-compatible BAGEL MoT module for flow-matching training. + +Ported from vllm-omni/BAGEL with the following correctness-critical details: + * MoT (Mixture-of-Thought): dual pathways for text vs generation tokens + * start_of_image / end_of_image boundary tokens are required + * All latent tokens share ONE RoPE position (spatial via 2-D sincos embed) + * QK-norm + RoPE in float32; cast to bfloat16 only for SDPA + * Attention mask: text-context is causal & cannot see image region + +Dependencies: torch, numpy, safetensors, einops, transformers (AutoTokenizer) +NO dependency on vllm or vllm-omni. +""" + +from __future__ import annotations + +import json +import math +import os +from dataclasses import dataclass +from typing import Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor + + +# =================================================================== +# Config +# =================================================================== + +@dataclass +class BagelTrainingConfig: + hidden_size: int = 3584 + intermediate_size: int = 18944 + num_hidden_layers: int = 28 + num_attention_heads: int = 28 + num_key_value_heads: int = 4 + vocab_size: int = 152064 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1_000_000.0 + max_position_embeddings: int = 32768 + # Bagel-specific + latent_patch_size: int = 2 + max_latent_size: int = 32 + latent_channel: int = 16 + vae_downsample: int = 8 + start_of_image_id: int = 151652 # <|vision_start|> + end_of_image_id: int = 151653 # <|vision_end|> + + @property + def head_dim(self) -> int: + return self.hidden_size // self.num_attention_heads + + @property + def patch_latent_dim(self) -> int: + return self.latent_patch_size ** 2 * self.latent_channel + + def save_pretrained(self, save_directory: str): + """Save config as JSON (compatible with diffusers checkpoint manager).""" + from dataclasses import asdict + output_path = os.path.join(save_directory, "config.json") + os.makedirs(save_directory, exist_ok=True) + with open(output_path, "w") as f: + json.dump(asdict(self), f, indent=4, sort_keys=True) + + @classmethod + def from_model_path(cls, model_path: str) -> "BagelTrainingConfig": + cfg_path = os.path.join(model_path, "config.json") + with open(cfg_path) as f: + root_cfg = json.load(f) + llm = root_cfg.get("llm_config", {}) + vae = root_cfg.get("vae_config", {}) + return cls( + hidden_size=llm.get("hidden_size", 3584), + intermediate_size=llm.get("intermediate_size", 18944), + num_hidden_layers=llm.get("num_hidden_layers", 28), + num_attention_heads=llm.get("num_attention_heads", 28), + num_key_value_heads=llm.get("num_key_value_heads", 4), + vocab_size=llm.get("vocab_size", 152064), + rms_norm_eps=llm.get("rms_norm_eps", 1e-6), + rope_theta=llm.get("rope_theta", 1_000_000.0), + max_position_embeddings=llm.get("max_position_embeddings", 32768), + latent_patch_size=root_cfg.get("latent_patch_size", 2), + max_latent_size=root_cfg.get("max_latent_size", 32), + latent_channel=vae.get("z_channels", 16), + vae_downsample=vae.get("downsample", 8), + ) + + +# =================================================================== +# VAE AutoEncoder (from FLUX / BAGEL, Apache-2.0) +# =================================================================== + +@dataclass +class AutoEncoderParams: + resolution: int = 256 + in_channels: int = 3 + downsample: int = 8 + ch: int = 128 + out_ch: int = 3 + ch_mult: list[int] | tuple[int, ...] = (1, 2, 4, 4) + num_res_blocks: int = 2 + z_channels: int = 16 + scale_factor: float = 0.3611 + shift_factor: float = 0.1159 + + +def _swish(x: Tensor) -> Tensor: + return x * torch.sigmoid(x) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) + self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) + + def forward(self, x: Tensor) -> Tensor: + h_ = self.norm(x) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() + k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() + v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() + h_ = F.scaled_dot_product_attention(q, k, v) + h_ = rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + return x + self.proj_out(h_) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x: Tensor) -> Tensor: + h = self.norm1(x) + h = _swish(h) + h = self.conv1(h) + h = self.norm2(h) + h = _swish(h) + h = self.conv2(h) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + h + + +class _Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x: Tensor) -> Tensor: + x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0) + return self.conv(x) + + +class _Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + return self.conv(x) + + +class Encoder(nn.Module): + def __init__(self, resolution: int, in_channels: int, ch: int, + ch_mult: list[int], num_res_blocks: int, z_channels: int): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.conv_in = nn.Conv2d(in_channels, ch, kernel_size=3, stride=1, padding=1) + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + block_in = ch + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = _Downsample(block_in) + self.down.append(down) + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x: Tensor) -> Tensor: + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + h = self.norm_out(h) + h = _swish(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, ch: int, out_ch: int, ch_mult: list[int], + num_res_blocks: int, in_channels: int, resolution: int, z_channels: int): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = _Upsample(block_in) + self.up.insert(0, up) + self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def forward(self, z: Tensor) -> Tensor: + h = self.conv_in(z) + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + h = self.norm_out(h) + h = _swish(h) + h = self.conv_out(h) + return h + + +class DiagonalGaussian(nn.Module): + def __init__(self, sample: bool = True, chunk_dim: int = 1): + super().__init__() + self.sample = sample + self.chunk_dim = chunk_dim + + def forward(self, z: Tensor) -> Tensor: + mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) + if self.sample: + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, in_channels=params.in_channels, + ch=params.ch, ch_mult=list(params.ch_mult), + num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, in_channels=params.in_channels, + ch=params.ch, out_ch=params.out_ch, ch_mult=list(params.ch_mult), + num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def encode(self, x: Tensor) -> Tensor: + z = self.reg(self.encoder(x)) + return self.scale_factor * (z - self.shift_factor) + + def decode(self, z: Tensor) -> Tensor: + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def forward(self, x: Tensor) -> Tensor: + return self.decode(self.encode(x)) + + +def load_ae(path: str) -> tuple[AutoEncoder, AutoEncoderParams]: + """Load VAE autoencoder from a safetensors checkpoint.""" + params = AutoEncoderParams() + ae = AutoEncoder(params) + if path is not None: + from safetensors.torch import load_file + sd = load_file(path) + missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) + if missing: + print(f"VAE load: {len(missing)} missing keys") + if unexpected: + print(f"VAE load: {len(unexpected)} unexpected keys") + return ae, params + + +# =================================================================== +# Tokenizer & data utilities (replaces BAGEL/data/data_utils.py) +# =================================================================== + +def load_tokenizer(model_path: str): + """Load tokenizer with special tokens for BAGEL using transformers.AutoTokenizer.""" + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + all_special = set() + for v in tokenizer.special_tokens_map.values(): + if isinstance(v, str): + all_special.add(v) + elif isinstance(v, list): + all_special.update(v) + + new_tokens = [] + for t in ['<|im_start|>', '<|im_end|>', '<|vision_start|>', '<|vision_end|>']: + if t not in all_special and t not in tokenizer.get_vocab(): + new_tokens.append(t) + if new_tokens: + tokenizer.add_tokens(new_tokens) + + new_token_ids = { + 'bos_token_id': tokenizer.convert_tokens_to_ids('<|im_start|>'), + 'eos_token_id': tokenizer.convert_tokens_to_ids('<|im_end|>'), + 'start_of_image': tokenizer.convert_tokens_to_ids('<|vision_start|>'), + 'end_of_image': tokenizer.convert_tokens_to_ids('<|vision_end|>'), + } + return tokenizer, new_token_ids + + +def get_flattened_position_ids(img_h: int, img_w: int, + patch_size: int, max_num_patches_per_side: int) -> torch.Tensor: + """Compute flattened 2-D position IDs for latent patches (extrapolate mode).""" + num_patches_h = img_h // patch_size + num_patches_w = img_w // patch_size + coords_h = torch.arange(0, num_patches_h) + coords_w = torch.arange(0, num_patches_w) + pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten() + return pos_ids + + +# =================================================================== +# Transformer building blocks +# =================================================================== + +class RMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + input_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.eps) + return self.weight * x.to(input_dtype) + + +class BagelMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + def forward(self, x: Tensor) -> Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +# =================================================================== +# RoPE helpers +# =================================================================== + +def _rotate_half(x: Tensor) -> Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rotary_emb(q, k, cos, sin): + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin + return q_embed, k_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, head_dim: int, max_position_embeddings: int = 32768, + theta: float = 1_000_000.0): + super().__init__() + inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, position_ids: Tensor): + freqs = torch.einsum("bi,j->bij", position_ids.float(), self.inv_freq.to(position_ids.device)) + emb = torch.cat([freqs, freqs], dim=-1) + return emb.cos(), emb.sin() + + +# =================================================================== +# MoT Attention & Layer +# =================================================================== + +class BagelMoTAttention(nn.Module): + """MoT attention with separate standard and generation projections.""" + + def __init__(self, config: BagelTrainingConfig): + super().__init__() + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=True) + self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=True) + self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm_moe_gen = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: Tensor, + cos: Tensor, + sin: Tensor, + text_mask: Tensor, + latent_mask: Tensor, + L_ctx: int = 0, + ) -> Tensor: + B, L, _ = hidden_states.shape + text_idx = text_mask.nonzero(as_tuple=True) + latent_idx = latent_mask.nonzero(as_tuple=True) + + q = hidden_states.new_zeros(B, L, self.num_heads * self.head_dim) + k = hidden_states.new_zeros(B, L, self.num_kv_heads * self.head_dim) + v = hidden_states.new_zeros(B, L, self.num_kv_heads * self.head_dim) + + text_hs = hidden_states[text_idx] + q[text_idx] = self.q_proj(text_hs) + k[text_idx] = self.k_proj(text_hs) + v[text_idx] = self.v_proj(text_hs) + + latent_hs = hidden_states[latent_idx] + q[latent_idx] = self.q_proj_moe_gen(latent_hs) + k[latent_idx] = self.k_proj_moe_gen(latent_hs) + v[latent_idx] = self.v_proj_moe_gen(latent_hs) + + q = q.view(B, L, self.num_heads, self.head_dim) + k = k.view(B, L, self.num_kv_heads, self.head_dim) + v = v.view(B, L, self.num_kv_heads, self.head_dim) + + q = q.to(torch.float32) + k = k.to(torch.float32) + q_normed = q.new_zeros(q.shape) + k_normed = k.new_zeros(k.shape) + q_normed[text_idx] = self.q_norm(q[text_idx]) + k_normed[text_idx] = self.k_norm(k[text_idx]) + q_normed[latent_idx] = self.q_norm_moe_gen(q[latent_idx]) + k_normed[latent_idx] = self.k_norm_moe_gen(k[latent_idx]) + + cos = cos.unsqueeze(2) + sin = sin.unsqueeze(2) + q_normed, k_normed = _apply_rotary_emb(q_normed, k_normed, cos, sin) + + q_normed = q_normed.to(torch.bfloat16) + k_normed = k_normed.to(torch.bfloat16) + v = v.to(torch.bfloat16) + + if self.num_kv_heads < self.num_heads: + rep = self.num_heads // self.num_kv_heads + k_normed = k_normed.unsqueeze(3).expand(-1, -1, -1, rep, -1).reshape(B, L, self.num_heads, self.head_dim) + v = v.unsqueeze(3).expand(-1, -1, -1, rep, -1).reshape(B, L, self.num_heads, self.head_dim) + + # Split attention: no boolean mask → SDPA uses flash backend + # Original vllm-omni uses flash_attn_varlen_func; matching that + # requires avoiding the "math" fallback that boolean masks trigger. + # Text tokens: causal self-attention (only see prior text) + # Image tokens (soi/latent/eoi): full attention to everything + q_normed = q_normed.transpose(1, 2) # (B, H, L, D) + k_normed = k_normed.transpose(1, 2) + v = v.transpose(1, 2) + + if L_ctx > 0: + # Text self-attention (causal, flash backend) + text_out = F.scaled_dot_product_attention( + q_normed[:, :, :L_ctx], + k_normed[:, :, :L_ctx], + v[:, :, :L_ctx], + is_causal=True, + ) + # Image attention to full sequence (no mask, flash backend) + img_out = F.scaled_dot_product_attention( + q_normed[:, :, L_ctx:], + k_normed, + v, + is_causal=False, + ) + attn_out = torch.cat([text_out, img_out], dim=2) + else: + attn_out = F.scaled_dot_product_attention( + q_normed, k_normed, v, is_causal=False, + ) + + attn_out = attn_out.transpose(1, 2).contiguous().view(B, L, -1) + + out = hidden_states.new_zeros(B, L, self.hidden_size) + out[text_idx] = self.o_proj(attn_out[text_idx]) + out[latent_idx] = self.o_proj_moe_gen(attn_out[latent_idx]) + return out + + +class BagelMoTLayer(nn.Module): + def __init__(self, config: BagelTrainingConfig): + super().__init__() + self.self_attn = BagelMoTAttention(config) + self.mlp = BagelMLP(config.hidden_size, config.intermediate_size) + self.mlp_moe_gen = BagelMLP(config.hidden_size, config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: Tensor, + cos: Tensor, + sin: Tensor, + text_mask: Tensor, + latent_mask: Tensor, + L_ctx: int = 0, + ) -> Tensor: + text_idx = text_mask.nonzero(as_tuple=True) + latent_idx = latent_mask.nonzero(as_tuple=True) + + normed = hidden_states.new_zeros(hidden_states.shape) + normed[text_idx] = self.input_layernorm(hidden_states[text_idx]) + normed[latent_idx] = self.input_layernorm_moe_gen(hidden_states[latent_idx]) + + attn_out = self.self_attn(normed, cos, sin, text_mask, latent_mask, L_ctx) + hidden_states = hidden_states + attn_out + + residual = hidden_states + mlp_out = hidden_states.new_zeros(hidden_states.shape) + mlp_out[text_idx] = self.mlp(self.post_attention_layernorm(hidden_states[text_idx])) + mlp_out[latent_idx] = self.mlp_moe_gen( + self.post_attention_layernorm_moe_gen(hidden_states[latent_idx]) + ) + hidden_states = residual + mlp_out + return hidden_states + + +# =================================================================== +# Position embedding helpers +# =================================================================== + +def _get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> np.ndarray: + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + return np.concatenate([np.sin(out), np.cos(out)], axis=1) + + +def _get_2d_sincos_pos_embed(embed_dim: int, grid_size: int) -> np.ndarray: + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) + grid = np.stack(grid, axis=0).reshape(2, 1, grid_size, grid_size) + emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) + emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) + return np.concatenate([emb_h, emb_w], axis=1) + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size: int, freq_dim: int = 256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(freq_dim, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.freq_dim = freq_dim + + def forward(self, t: Tensor) -> Tensor: + half = self.freq_dim // 2 + freqs = torch.exp( + -math.log(10000) * torch.arange(half, dtype=torch.float32, device=t.device) / half + ) + args = t[:, None].float() * freqs[None] + emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + emb = emb.to(self.mlp[0].weight.dtype) + return self.mlp(emb) + + +class PositionEmbedding(nn.Module): + def __init__(self, max_num_patch_per_side: int, hidden_size: int): + super().__init__() + pos_embed = _get_2d_sincos_pos_embed(hidden_size, max_num_patch_per_side) + self.pos_embed = nn.Parameter( + torch.from_numpy(pos_embed).float(), requires_grad=False + ) + + def forward(self, position_ids: Tensor) -> Tensor: + return self.pos_embed[position_ids] + + +# =================================================================== +# Main module: BagelForTraining +# =================================================================== + +class BagelForTraining(nn.Module): + """Standalone Bagel MoT module for FlowGRPO FSDP training. + + Forward signature: + hidden_states: (B, L_latent, patch_latent_dim) — noisy latent patches + timestep: (B,) — diffusion timestep scalars + text_token_ids: (B, L_text) — tokenized prompt IDs (with bos/eos) + latent_pos_ids: (B, L_latent) — 2-D position indices for latent patches + """ + + def __init__(self, config: BagelTrainingConfig): + super().__init__() + self.config = config + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.layers = nn.ModuleList([BagelMoTLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_moe_gen = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = RotaryEmbedding(config.head_dim, theta=config.rope_theta) + + self.time_embedder = TimestepEmbedder(config.hidden_size) + self.vae2llm = nn.Linear(config.patch_latent_dim, config.hidden_size) + self.llm2vae = nn.Linear(config.hidden_size, config.patch_latent_dim) + self.latent_pos_embed = PositionEmbedding(config.max_latent_size, config.hidden_size) + + def forward( + self, + hidden_states: Tensor, + timestep: Tensor, + text_token_ids: Optional[Tensor], + latent_pos_ids: Tensor, + **kwargs, + ) -> tuple[Tensor]: + """Forward pass. + + When text_token_ids is None the sequence is [soi, latent, eoi] only + (no text context). This is used for the CFG unconditional pass. + """ + B = hidden_states.shape[0] + L_latent = hidden_states.shape[1] + dev = hidden_states.device + + # 1. Embed text context + if text_token_ids is not None: + text_embeds = self.embed_tokens(text_token_ids) + L_ctx = text_embeds.shape[1] + else: + L_ctx = 0 + + # 2. SOI / EOI boundary tokens + soi_ids = torch.full((B, 1), self.config.start_of_image_id, dtype=torch.long, device=dev) + eoi_ids = torch.full((B, 1), self.config.end_of_image_id, dtype=torch.long, device=dev) + soi_emb = self.embed_tokens(soi_ids) + eoi_emb = self.embed_tokens(eoi_ids) + + # 3. Latent projection + t_emb = self.time_embedder(timestep) + pos_emb = self.latent_pos_embed(latent_pos_ids) + latent_embeds = self.vae2llm(hidden_states) + t_emb.unsqueeze(1) + pos_emb + latent_embeds = latent_embeds.to(soi_emb.dtype) + + # 4. Sequence: [text?, soi, latent_0..N, eoi] + L_total = L_ctx + 1 + L_latent + 1 + if L_ctx > 0: + sequence = torch.cat([text_embeds, soi_emb, latent_embeds, eoi_emb], dim=1) + else: + sequence = torch.cat([soi_emb, latent_embeds, eoi_emb], dim=1) + + # 5. MoT routing masks + # text pathway: text_ctx + soi + eoi + # gen pathway: latent tokens only + text_mask = torch.zeros(B, L_total, dtype=torch.bool, device=dev) + text_mask[:, : L_ctx + 1] = True # text + soi + text_mask[:, -1] = True # eoi + latent_mask = ~text_mask + + # 6. RoPE positions + if L_ctx > 0: + ctx_pos = torch.arange(L_ctx, device=dev) + img_pos = ctx_pos.new_full((1 + L_latent + 1,), L_ctx) + position_ids = torch.cat([ctx_pos, img_pos]).unsqueeze(0).expand(B, -1) + else: + position_ids = torch.zeros(1, L_total, dtype=torch.long, device=dev).expand(B, -1) + cos, sin = self.rotary_emb(position_ids) + + # 7. Transformer layers (split attention: text causal + image full) + for layer in self.layers: + sequence = layer(sequence, cos, sin, text_mask, latent_mask, L_ctx) + + # 8. Final norm with MoT routing + normed = sequence.new_zeros(sequence.shape) + t_idx = text_mask.nonzero(as_tuple=True) + l_idx = latent_mask.nonzero(as_tuple=True) + normed[t_idx] = self.norm(sequence[t_idx]) + normed[l_idx] = self.norm_moe_gen(sequence[l_idx]) + + # 9. Extract latent output + latent_output = normed[:, L_ctx + 1 : L_ctx + 1 + L_latent, :] + velocity = self.llm2vae(latent_output) + + return (velocity,) + + # ------------------------------------------------------------------ + # PEFT / LoRA compatibility + # ------------------------------------------------------------------ + + def add_adapter(self, adapter_config, adapter_name: str = "default"): + """Add a PEFT LoRA adapter (matches diffusers.ModelMixin API).""" + from peft import inject_adapter_in_model + + inject_adapter_in_model(adapter_config, self, adapter_name) + + # ------------------------------------------------------------------ + # Checkpoint loading + # ------------------------------------------------------------------ + + @classmethod + def from_pretrained(cls, model_path: str, torch_dtype=torch.bfloat16) -> "BagelForTraining": + config = BagelTrainingConfig.from_model_path(model_path) + ckpt_path = os.path.join(model_path, "ema.safetensors") + from safetensors.torch import load_file + state_dict = load_file(ckpt_path) + + if "latent_pos_embed.pos_embed" in state_dict: + actual_len = state_dict["latent_pos_embed.pos_embed"].shape[0] + grid = int(actual_len ** 0.5) + if grid * grid == actual_len and grid != config.max_latent_size: + config.max_latent_size = grid + + model = cls(config) + mapped = _map_checkpoint_to_training(state_dict, config) + missing, unexpected = model.load_state_dict(mapped, strict=False) + if missing: + import logging + logging.getLogger(__name__).warning( + f"Missing keys when loading BagelForTraining: {len(missing)} keys" + ) + + model = model.to(torch_dtype) + return model + + +def _map_checkpoint_to_training( + state_dict: dict[str, Tensor], config: BagelTrainingConfig +) -> dict: + """Map ema.safetensors keys to BagelForTraining parameter names.""" + mapped: dict[str, Tensor] = {} + for src_key, tensor in state_dict.items(): + dst_key: str | None = None + if src_key.startswith("language_model.model."): + dst_key = src_key[len("language_model.model."):] + elif src_key.startswith("language_model."): + continue + elif src_key.startswith(("time_embedder.", "vae2llm.", "llm2vae.", "latent_pos_embed.")): + dst_key = src_key + if dst_key is not None: + mapped[dst_key] = tensor + return mapped diff --git a/examples/flowgrpo_trainer/prepare_ocr_data.py b/examples/flowgrpo_trainer/prepare_ocr_data.py new file mode 100644 index 00000000000..60a8a1369b1 --- /dev/null +++ b/examples/flowgrpo_trainer/prepare_ocr_data.py @@ -0,0 +1,136 @@ +"""Generate a simple OCR training dataset for FlowGRPO text-rendering experiments. + +Produces train.parquet and test.parquet with prompts asking the model to +generate images containing specific text. The ground_truth field stores the +expected text so the reward function can compute an OCR accuracy score. + +Usage: + python examples/flowgrpo_trainer/prepare_ocr_data.py \ + --output-dir ~/data/ocr \ + --train-size 500 \ + --test-size 50 +""" + +import argparse +import random +import string +from pathlib import Path + +import pandas as pd + + +SYSTEM_PROMPT = ( + "Describe the image by detailing the color, shape, size, " + "texture, quantity, text, spatial relationships of the objects and background:" +) + +TEMPLATES_WORD = [ + "Create a picture with the word '{text}' written on it.", +] + +TEMPLATES_PHRASE = [ + "Create a picture with the sentence '{text}' written on it.", +] + +# Simple words/phrases of varying difficulty +WORD_POOLS = { + "single_word": [ + "Hello", "World", "Python", "Design", "Future", "Create", "Vision", + "Magic", "Light", "Dream", "Ocean", "Storm", "Cloud", "River", + "Apple", "Music", "Dance", "Focus", "Power", "Speed", "Brain", + "Space", "Earth", "Tower", "Crown", "Flame", "Sword", "Heart", + "Stone", "Pearl", "Tiger", "Eagle", "Brave", "Happy", "Lucky", + "Smart", "Fresh", "Quiet", "Vivid", "Solid", "Rapid", "Sharp", + "Pixel", "Unity", "Lunar", "Solar", "Coral", "Atlas", "Prism", + ], + "number": [ + "2024", "1234", "42", "100", "3.14", "007", "2048", "365", + "99", "512", "1024", "8080", "404", "200", "1337", "2025", + ], + "short_phrase": [ + "Hello World", "Good Morning", "Open Source", "Deep Learning", + "Keep Going", "Stay Calm", "Think Big", "Game Over", + "No Limits", "Be Bold", "Try Again", "Well Done", + "New York", "San Jose", "Big Data", "Red Moon", + "Ice Cold", "Sky High", "Top Down", "Fast Lane", + ], + "mixed_case": [ + "PyTorch", "GitHub", "OpenAI", "DevOps", "TypeScript", + "YouTube", "MacBook", "iPhone", "LinkedIn", "TikTok", + "ChatGPT", "WiFi", "JavaScript", "README", "HuggingFace", + ], +} + + +def random_word(rng: random.Random) -> str: + pool_name = rng.choice(list(WORD_POOLS.keys())) + return rng.choice(WORD_POOLS[pool_name]) + + +def random_alphanum(rng: random.Random, length: int = None) -> str: + if length is None: + length = rng.randint(3, 8) + chars = string.ascii_letters + string.digits + return "".join(rng.choice(chars) for _ in range(length)) + + +def generate_samples(n: int, seed: int = 42) -> list[dict]: + rng = random.Random(seed) + samples = [] + for _ in range(n): + if rng.random() < 0.8: + text = random_word(rng) + else: + text = random_alphanum(rng) + + is_phrase = " " in text + template = rng.choice(TEMPLATES_PHRASE if is_phrase else TEMPLATES_WORD) + prompt_text = template.format(text=text) + + sample = { + "data_source": "ocr", + "prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt_text}, + ], + "negative_prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": " "}, + ], + "reward_model": {"style": "rule", "ground_truth": text}, + } + samples.append(sample) + return samples + + +def main(): + parser = argparse.ArgumentParser(description="Generate OCR training data for FlowGRPO") + parser.add_argument("--output-dir", type=Path, default=Path.home() / "data" / "ocr") + parser.add_argument("--train-size", type=int, default=500) + parser.add_argument("--test-size", type=int, default=50) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + args.output_dir.mkdir(parents=True, exist_ok=True) + + train_samples = generate_samples(args.train_size, seed=args.seed) + test_samples = generate_samples(args.test_size, seed=args.seed + 1) + + train_df = pd.DataFrame(train_samples) + test_df = pd.DataFrame(test_samples) + + train_path = args.output_dir / "train.parquet" + test_path = args.output_dir / "test.parquet" + + train_df.to_parquet(train_path, index=False) + test_df.to_parquet(test_path, index=False) + + print(f"Train: {len(train_df)} samples -> {train_path}") + print(f"Test: {len(test_df)} samples -> {test_path}") + print(f"\nSample entry:") + print(f" prompt: {train_samples[0]['prompt']}") + print(f" ground_truth: {train_samples[0]['reward_model']['ground_truth']}") + + +if __name__ == "__main__": + main() diff --git a/examples/flowgrpo_trainer/reward_fn.py b/examples/flowgrpo_trainer/reward_fn.py index a9f6bcf4a4a..ab3b1d720f0 100644 --- a/examples/flowgrpo_trainer/reward_fn.py +++ b/examples/flowgrpo_trainer/reward_fn.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import json import os from typing import Optional @@ -27,14 +28,20 @@ async def chat_complete(router_address: str, chat_complete_request: dict): url = f"http://{router_address}/v1/chat/completions" try: - timeout = aiohttp.ClientTimeout(total=None) + timeout = aiohttp.ClientTimeout(total=120) session = aiohttp.ClientSession(timeout=timeout) async with session.post(url, json=chat_complete_request) as resp: output = await resp.text() + if not output or not output.strip(): + return None output = json.loads(output) return ChatCompletion(**output) + except (json.JSONDecodeError, aiohttp.ClientError, asyncio.TimeoutError) as e: + print(f"[reward_fn] chat_complete failed: {type(e).__name__}: {e}") + return None except Exception as e: - raise e + print(f"[reward_fn] chat_complete unexpected error: {type(e).__name__}: {e}") + return None finally: await session.close() @@ -127,7 +134,11 @@ async def compute_score_ocr( router_address=reward_router_address, chat_complete_request=chat_complete_request, ) + if result is None or not result.choices: + return 0.0 grm_response = result.choices[0].message.content + if not grm_response: + return 0.0 # compute OCR score text = grm_response diff --git a/examples/flowgrpo_trainer/run_bagel_flowgrpo.sh b/examples/flowgrpo_trainer/run_bagel_flowgrpo.sh new file mode 100644 index 00000000000..e9c586ad1c7 --- /dev/null +++ b/examples/flowgrpo_trainer/run_bagel_flowgrpo.sh @@ -0,0 +1,101 @@ +# Bagel LoRA RL, vllm_omni rollout (FlowGRPO) +# +# Prerequisites: +# 1. A Bagel model (e.g. BAGEL-7B-MoT) at $BAGEL_MODEL_PATH +# 2. A stage config YAML at $BAGEL_STAGE_CONFIG for vllm-omni +# 3. DiffusionModelBase registered as "OmniBagelForConditionalGeneration" +# at examples/flowgrpo_trainer/diffusers/bagel.py +# 4. A reward VLM model at $REWARD_MODEL_PATH +# 5. OCR training data at $OCR_TRAIN_PATH / $OCR_TEST_PATH +# (generate via: python examples/flowgrpo_trainer/prepare_ocr_data.py) +# +# Usage: +# export BAGEL_MODEL_PATH=/path/to/BAGEL-7B-MoT +# export REWARD_MODEL_PATH=/path/to/Qwen3-VL-8B-Instruct +# bash examples/flowgrpo_trainer/run_bagel_flowgrpo.sh +# +# # Override any param via CLI: +# bash examples/flowgrpo_trainer/run_bagel_flowgrpo.sh trainer.n_gpus_per_node=8 + +set -x + +# --------------- Paths (override via environment) --------------- +BAGEL_MODEL_PATH=${BAGEL_MODEL_PATH:-$HOME/models/BAGEL-7B-MoT} +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BAGEL_STAGE_CONFIG=${BAGEL_STAGE_CONFIG:-$SCRIPT_DIR/bagel_stage_config.yaml} + +REWARD_MODEL_PATH=${REWARD_MODEL_PATH:-$HOME/models/Qwen3-VL-8B-Instruct} + +ocr_train_path=${OCR_TRAIN_PATH:-$HOME/data/ocr/train.parquet} +ocr_test_path=${OCR_TEST_PATH:-$HOME/data/ocr/test.parquet} + +ENGINE=vllm_omni +REWARD_ENGINE=vllm + +reward_path=examples/flowgrpo_trainer/reward_fn.py + +python3 -m verl.trainer.main_flowgrpo \ + algorithm.adv_estimator=flow_grpo \ + data.train_files=$ocr_train_path \ + data.val_files=$ocr_test_path \ + data.train_batch_size=16 \ + data.max_prompt_length=256 \ + data.trust_remote_code=True \ + actor_rollout_ref.model.path=$BAGEL_MODEL_PATH \ + actor_rollout_ref.model.tokenizer_path=$BAGEL_MODEL_PATH \ + +actor_rollout_ref.model.architecture=OmniBagelForConditionalGeneration \ + actor_rollout_ref.model.trust_remote_code=True \ + actor_rollout_ref.model.external_lib="examples.flowgrpo_trainer.diffusers.bagel" \ + actor_rollout_ref.model.height=512 \ + actor_rollout_ref.model.width=512 \ + actor_rollout_ref.model.num_inference_steps=15 \ + actor_rollout_ref.model.lora_rank=64 \ + actor_rollout_ref.model.lora_alpha=128 \ + actor_rollout_ref.model.target_modules="['q_proj_moe_gen','k_proj_moe_gen','v_proj_moe_gen','o_proj_moe_gen']" \ + actor_rollout_ref.actor.optim.lr=1e-3 \ + actor_rollout_ref.actor.optim.weight_decay=0.0001 \ + actor_rollout_ref.actor.ppo_mini_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ + actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ + actor_rollout_ref.rollout.name=$ENGINE \ + actor_rollout_ref.rollout.n=8 \ + actor_rollout_ref.rollout.agent.default_agent_loop=diffusion_single_turn_agent \ + actor_rollout_ref.rollout.agent.num_workers=2 \ + actor_rollout_ref.rollout.load_format=auto \ + actor_rollout_ref.rollout.layered_summon=True \ + actor_rollout_ref.rollout.num_inference_steps=15 \ + actor_rollout_ref.rollout.val_kwargs.num_inference_steps=15 \ + +actor_rollout_ref.rollout.extra_configs.noise_level=1.2 \ + +actor_rollout_ref.rollout.extra_configs.sde_type="sde" \ + +actor_rollout_ref.rollout.extra_configs.sde_window_size=2 \ + +actor_rollout_ref.rollout.extra_configs.sde_window_range="[0,5]" \ + +actor_rollout_ref.rollout.extra_configs.max_sequence_length=256 \ + +actor_rollout_ref.rollout.val_kwargs.extra_configs.noise_level=0.0 \ + +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.custom_pipeline=examples.flowgrpo_trainer.vllm_omni.pipeline_bagel.BagelPipelineWithLogProb \ + +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.stage_configs_path=$BAGEL_STAGE_CONFIG \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + reward.num_workers=1 \ + reward.reward_manager.name=visual \ + reward.reward_model.enable=True \ + reward.reward_model.model_path=$REWARD_MODEL_PATH \ + reward.reward_model.rollout.name=$REWARD_ENGINE \ + reward.reward_model.rollout.tensor_model_parallel_size=4 \ + reward.custom_reward_function.path=$reward_path \ + reward.custom_reward_function.name=compute_score_ocr \ + trainer.use_legacy_worker_impl=disable \ + trainer.logger='["console", "wandb"]' \ + trainer.project_name=flow_grpo \ + trainer.experiment_name=bagel_ocr_lora \ + trainer.log_val_generations=4 \ + trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.total_training_steps=50 $@ diff --git a/examples/flowgrpo_trainer/scheduler/scheduling_flow_match_sde_discrete.py b/examples/flowgrpo_trainer/scheduler/scheduling_flow_match_sde_discrete.py index 3f0764e52ae..cbd4f68db75 100644 --- a/examples/flowgrpo_trainer/scheduler/scheduling_flow_match_sde_discrete.py +++ b/examples/flowgrpo_trainer/scheduler/scheduling_flow_match_sde_discrete.py @@ -51,6 +51,14 @@ class FlowMatchSDEDiscreteScheduler(FlowMatchEulerDiscreteScheduler): and diffusers v0.37 branch. """ + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + # Use nearest-neighbor matching to avoid float32/float64 precision mismatch + # between timesteps stored during rollout and scheduler's internal timesteps. + diffs = (schedule_timesteps - timestep).abs() + return diffs.argmin().item() + def step( self, model_output: torch.FloatTensor, @@ -169,11 +177,14 @@ def sample_previous_step( sigma = self.sigmas[sigma_idx] sigma_prev = self.sigmas[sigma_idx + 1] else: - sigma_idx = torch.tensor([self.index_for_timestep(t) for t in timestep]) + sigma_idx = torch.tensor([self.index_for_timestep(t.cpu()) for t in timestep]) sigma = self.sigmas[sigma_idx].view(-1, *([1] * (len(sample.shape) - 1))) sigma_prev = self.sigmas[sigma_idx + 1].view(-1, *([1] * (len(sample.shape) - 1))) - sigma_max = self.sigmas[1] + # Move scheduler tensors to the same device as sample + sigma = sigma.to(device=sample.device, dtype=sample.dtype) + sigma_prev = sigma_prev.to(device=sample.device, dtype=sample.dtype) + sigma_max = self.sigmas[1].to(device=sample.device, dtype=sample.dtype) dt = sigma_prev - sigma if sde_type == "sde": diff --git a/examples/flowgrpo_trainer/test_bagel_train.py b/examples/flowgrpo_trainer/test_bagel_train.py new file mode 100644 index 00000000000..895b760f437 --- /dev/null +++ b/examples/flowgrpo_trainer/test_bagel_train.py @@ -0,0 +1,167 @@ +"""Generate images with BagelForTraining using Classifier-Free Guidance (CFG). + +Matches the original vllm-omni inference pipeline EXACTLY: + * 3-branch CFG: cfg_text_scale=4.0, cfg_img_scale=1.5 + For text2img (no image input), cfg_img branch == gen branch, + so effective formula is: v = 5.5*v_cond - 4.5*v_uncond + * cfg_interval = [0.4, 1.0], cfg_renorm_type = "global", cfg_renorm_min = 0.0 + * Timestep: linspace(1, 0, num_steps) → num_steps-1 actual denoising steps + * timestep_shift = 3.0 + * Default seed = 52 (bagel_single_stage.yaml) + +Dependencies: torch, PIL, transformers (via bagel_model) +NO dependency on vllm or vllm-omni. +""" +import os +import sys +import torch +from PIL import Image + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "diffusers")) + +from bagel_model import ( + BagelForTraining, + AutoEncoder, + load_ae, + load_tokenizer, + get_flattened_position_ids, +) + +MODEL_PATH = "/proj-tango-pvc/users/zhipeng.wang/workspace/models/BAGEL-7B-MoT" +OUTPUT_DIR = "/proj-tango-pvc/users/zhipeng.wang/workspace/verl/examples/flowgrpo_trainer/outputs" +DEVICE = "cuda" +DTYPE = torch.bfloat16 + +NUM_STEPS = 50 +TIMESTEP_SHIFT = 3.0 +CFG_TEXT_SCALE = 4.0 +CFG_IMG_SCALE = 1.5 +CFG_INTERVAL = (0.4, 1.0) +CFG_RENORM_MIN = 0.0 +CFG_RENORM_TYPE = "global" +IMG_H, IMG_W = None, None # auto-detect from model.config.max_latent_size + + +def unpatchify(patches, h, w, ps, c): + B = patches.shape[0] + return torch.einsum("bhwpqc->bchpwq", + patches.view(B, h, w, ps, ps, c)).contiguous().view(B, c, h * ps, w * ps) + + +@torch.no_grad() +def generate_images_cfg(model, vae, text_token_ids, latent_pos_ids, img_h, img_w, seed=52): + """Full denoising with 3-branch CFG + renorm (matches vllm-omni pipeline).""" + model.eval() + B = 1 if text_token_ids is None else text_token_ids.shape[0] + L_latent = latent_pos_ids.shape[1] + + # Match vllm-omni: generate noise on CPU in float32, then move to CUDA. + # CPU and CUDA RNG produce different numbers even with the same seed. + torch.manual_seed(seed) + x_t = torch.randn(B, L_latent, model.config.patch_latent_dim).to(DEVICE) + + # vllm-omni: linspace(1, 0, num_timesteps) with num_timesteps=50 → 49 actual steps + timesteps = torch.linspace(1, 0, NUM_STEPS, device=DEVICE) + timesteps = TIMESTEP_SHIFT * timesteps / (1 + (TIMESTEP_SHIFT - 1) * timesteps) + dts = timesteps[:-1] - timesteps[1:] + timesteps = timesteps[:-1] + + # vllm-omni wraps the entire denoising loop in autocast; + # x_t stays float32 while model forward runs in bfloat16. + with torch.autocast(device_type="cuda", dtype=DTYPE): + for i, t in enumerate(timesteps): + ts = torch.full((B,), t.item(), device=DEVICE, dtype=DTYPE) + + v_cond = model( + hidden_states=x_t, timestep=ts, + text_token_ids=text_token_ids, latent_pos_ids=latent_pos_ids, + )[0] + + in_cfg_window = t.item() > CFG_INTERVAL[0] and t.item() <= CFG_INTERVAL[1] + cfg_text_scale = CFG_TEXT_SCALE if in_cfg_window else 1.0 + cfg_img_scale = CFG_IMG_SCALE if in_cfg_window else 1.0 + use_cfg = cfg_text_scale > 1.0 + + if use_cfg: + v_uncond = model( + hidden_states=x_t, timestep=ts, + text_token_ids=None, latent_pos_ids=latent_pos_ids, + )[0] + + v_text = v_uncond + cfg_text_scale * (v_cond - v_uncond) + cfg_img_v_t = v_cond + v_guided = cfg_img_v_t + cfg_img_scale * (v_text - cfg_img_v_t) + + if CFG_RENORM_TYPE == "global": + norm_cond = torch.norm(v_cond.float()) + norm_guided = torch.norm(v_guided.float()) + elif CFG_RENORM_TYPE == "channel": + norm_cond = torch.norm(v_cond.float(), dim=-1, keepdim=True) + norm_guided = torch.norm(v_guided.float(), dim=-1, keepdim=True) + else: + raise ValueError(f"Unsupported renorm type: {CFG_RENORM_TYPE}") + scale = (norm_cond / (norm_guided + 1e-8)).clamp(min=CFG_RENORM_MIN, max=1.0) + v_t = v_guided * scale + else: + v_t = v_cond + + x_t = x_t - v_t * dts[i] + + latent = unpatchify(x_t, img_h, img_w, model.config.latent_patch_size, model.config.latent_channel) + pixels = vae.decode(latent.float()) + pixels = ((pixels * 0.5 + 0.5).clamp(0, 1) * 255).to(torch.uint8) + return pixels + + +def main(): + os.makedirs(OUTPUT_DIR, exist_ok=True) + + print("Loading BAGEL model...") + model = BagelForTraining.from_pretrained(MODEL_PATH, torch_dtype=DTYPE).to(DEVICE) + print(f" {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B params") + + print("Loading VAE...") + ae_path = os.path.join(MODEL_PATH, "ae.safetensors") + vae, _ = load_ae(ae_path) + vae = vae.to(DEVICE).eval() + + print("Loading tokenizer...") + tokenizer, new_token_ids = load_tokenizer(MODEL_PATH) + + img_h = IMG_H if IMG_H is not None else model.config.max_latent_size + img_w = IMG_W if IMG_W is not None else model.config.max_latent_size + latent_ds = model.config.latent_patch_size * model.config.vae_downsample + H_px = img_h * latent_ds + W_px = img_w * latent_ds + vae_pos_ids = get_flattened_position_ids( + H_px, W_px, latent_ds, model.config.max_latent_size, + ).to(DEVICE) + + prompts = [ + "A cute cat", + ] + + print(f"\nGenerating {len(prompts)} images at {H_px}x{W_px}px (latent {img_h}x{img_w})") + print(f" CFG: text_scale={CFG_TEXT_SCALE}, img_scale={CFG_IMG_SCALE}, " + f"interval={CFG_INTERVAL}, renorm={CFG_RENORM_TYPE}") + print(f" Steps: {NUM_STEPS} (linspace→{NUM_STEPS-1} actual), shift={TIMESTEP_SHIFT}\n") + + for idx, prompt in enumerate(prompts): + text_ids = tokenizer.encode(prompt) + full_ids = [new_token_ids['bos_token_id']] + text_ids + [new_token_ids['eos_token_id']] + text_token_ids = torch.tensor([full_ids], dtype=torch.long, device=DEVICE) + pos_ids = vae_pos_ids.unsqueeze(0) + + print(f" [{idx}] '{prompt}'") + pixels = generate_images_cfg(model, vae, text_token_ids, pos_ids, img_h, img_w, seed=52 + idx) + img = Image.fromarray(pixels[0].permute(1, 2, 0).cpu().numpy()) + + path = os.path.join(OUTPUT_DIR, f"sample{idx}_cfg.png") + img.save(path) + print(f" Saved: {path}") + + print("\nDone!") + + +if __name__ == "__main__": + main() diff --git a/examples/flowgrpo_trainer/vllm_omni/pipeline_bagel.py b/examples/flowgrpo_trainer/vllm_omni/pipeline_bagel.py new file mode 100644 index 00000000000..a4f49764f37 --- /dev/null +++ b/examples/flowgrpo_trainer/vllm_omni/pipeline_bagel.py @@ -0,0 +1,149 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Custom vllm-omni pipeline for BAGEL RL rollouts with VeRL. + +Extends :class:`BagelPipeline` to: +* Replace the scheduler with an SDE scheduler for stochastic denoising + with log-probability recording. +* Always enable trajectory recording. +* Read SDE kwargs from ``sampling_params.extra_args``. +* Return RL artifacts in ``DiffusionOutput.custom_output``. + +Loaded via ``custom_pipeline_args``: + +.. code-block:: python + + custom_pipeline_args={ + "pipeline_class": "examples.flowgrpo_trainer.vllm_omni.pipeline_bagel.BagelPipelineWithLogProb" + } +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import numpy as np +import torch +from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.models.bagel.pipeline_bagel import BagelPipeline +from vllm_omni.diffusion.request import OmniDiffusionRequest + +from ..scheduler import FlowMatchSDEDiscreteScheduler + +logger = logging.getLogger(__name__) + + +def _to_cpu_tensor(v): + """Convert to a single CPU tensor, stacking a list of tensors if needed.""" + if isinstance(v, torch.Tensor): + return v.detach().cpu() + if isinstance(v, list): + tensors = [x.detach().cpu() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in v] + return torch.stack(tensors) if tensors else None + return v + + +@dataclass +class _AdapterStepOutput: + """Adapter output matching what bagel_transformer.generate_image expects.""" + + prev_sample: torch.Tensor + log_prob: torch.Tensor | None + + +class _BagelSchedulerAdapter: + """Wraps the diffusers-based FlowMatchSDEDiscreteScheduler to match + BAGEL's calling convention: ``step(v_t, sigma, x_t, dt, **kwargs)``. + + BAGEL's transformer calls ``scheduler.step(model_output, timesteps[i], + sample, dts[i], **scheduler_kwargs)`` with 4 positional args, while the + diffusers scheduler takes ``step(model_output, timestep, sample, **kwargs)`` + and computes dt internally. This adapter bridges the gap. + """ + + def __init__(self, inner: FlowMatchSDEDiscreteScheduler): + self._inner = inner + + def __getattr__(self, name): + return getattr(self._inner, name) + + def step( + self, + model_output: torch.Tensor, + sigma: float | torch.Tensor, + sample: torch.Tensor, + dt: float | torch.Tensor, # noqa: ARG002 — not used, inner computes from timestep schedule + **kwargs, + ) -> _AdapterStepOutput: + out = self._inner.step( + model_output=model_output, + timestep=sigma, + sample=sample, + return_dict=False, + **kwargs, + ) + # step() with return_dict=False returns (prev_sample, log_prob, prev_sample_mean, std_dev_t) + prev_sample, log_prob = out[0], out[1] + return _AdapterStepOutput(prev_sample=prev_sample, log_prob=log_prob) + + +class BagelPipelineWithLogProb(BagelPipeline): + """BAGEL pipeline variant for RL rollouts with VeRL.""" + + def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): + super().__init__(od_config=od_config, prefix=prefix) + inner = FlowMatchSDEDiscreteScheduler() + self.scheduler = _BagelSchedulerAdapter(inner) + logger.info("BagelPipelineWithLogProb: SDE scheduler enabled for RL rollouts") + + def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: + # Force trajectory recording on for RL + req.sampling_params.return_trajectory_latents = True + + # Read SDE scheduler kwargs from extra_args + extra_args = req.sampling_params.extra_args + logprobs = extra_args.get("logprobs", True) + self.scheduler_kwargs = {k: extra_args[k] for k in ("noise_level", "sde_type", "generator") if k in extra_args} + self.scheduler_kwargs["return_logprobs"] = logprobs + + # Per-request scheduler setup: compute BAGEL's shifted sigmas so + # the inner SDE scheduler's sigma schedule matches what + # generate_image() computes internally. + assert req.sampling_params.num_inference_steps is not None, "num_inference_steps must be set for RL rollouts" + num_timesteps = req.sampling_params.num_inference_steps + timestep_shift = 3.0 # must match BagelPipeline.forward() hardcoded value + + t = np.linspace(1, 0, num_timesteps) + t_shifted = timestep_shift * t / (1 + (timestep_shift - 1) * t) + sigmas = t_shifted[:-1].tolist() # drop terminal 0; set_timesteps appends it + + inner = self.scheduler._inner + inner.set_shift(1.0) # identity — sigmas already shifted + inner.set_timesteps(sigmas=sigmas) + inner.set_begin_index(0) + + output = super().forward(req) + + # Enrich custom_output with RL-specific fields (must be tensors for batch stacking) + custom = output.custom_output or {} + if output.trajectory_latents is not None: + custom["all_latents"] = _to_cpu_tensor(output.trajectory_latents) + if output.trajectory_timesteps is not None: + custom["all_timesteps"] = _to_cpu_tensor(output.trajectory_timesteps) + if output.trajectory_log_probs is not None: + custom["all_log_probs"] = _to_cpu_tensor(output.trajectory_log_probs) + output.custom_output = custom + + return output diff --git a/examples/flowgrpo_trainer/vllm_omni/pipeline_qwenimage.py b/examples/flowgrpo_trainer/vllm_omni/pipeline_qwenimage.py index 8a4b9bf7d11..b20d8aff38e 100644 --- a/examples/flowgrpo_trainer/vllm_omni/pipeline_qwenimage.py +++ b/examples/flowgrpo_trainer/vllm_omni/pipeline_qwenimage.py @@ -49,20 +49,20 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): def _get_qwen_prompt_embeds( self, - prompt_ids: torch.Tensor, + prompt_token_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, dtype: torch.dtype | None = None, ): dtype = dtype or self.text_encoder.dtype if attention_mask is None: - attention_mask = torch.ones_like(prompt_ids, dtype=torch.long) + attention_mask = torch.ones_like(prompt_token_ids, dtype=torch.long) - prompt_ids = prompt_ids.unsqueeze(0) if prompt_ids.ndim == 1 else prompt_ids + prompt_token_ids = prompt_token_ids.unsqueeze(0) if prompt_token_ids.ndim == 1 else prompt_token_ids attention_mask = attention_mask.unsqueeze(0) if attention_mask.ndim == 1 else attention_mask drop_idx = self.prompt_template_encode_start_idx encoder_hidden_states = self.text_encoder( - input_ids=prompt_ids.to(self.device), + input_ids=prompt_token_ids.to(self.device), attention_mask=attention_mask.to(self.device), output_hidden_states=True, ) @@ -84,21 +84,23 @@ def _get_qwen_prompt_embeds( def encode_prompt( self, - prompt_ids: torch.Tensor, + prompt_token_ids: torch.Tensor, attention_mask: torch.Tensor | None = None, num_images_per_prompt: int = 1, prompt_embeds: torch.Tensor | None = None, prompt_embeds_mask: torch.Tensor | None = None, max_sequence_length: int = 1024, ): - prompt_ids = prompt_ids.unsqueeze(0) if prompt_ids.ndim == 1 else prompt_ids + prompt_token_ids = prompt_token_ids.unsqueeze(0) if prompt_token_ids.ndim == 1 else prompt_token_ids attention_mask = ( attention_mask.unsqueeze(0) if attention_mask is not None and attention_mask.ndim == 1 else attention_mask ) - batch_size = prompt_ids.shape[0] if prompt_embeds is None else prompt_embeds.shape[0] + batch_size = prompt_token_ids.shape[0] if prompt_embeds is None else prompt_embeds.shape[0] if prompt_embeds is None: - prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt_ids, attention_mask=attention_mask) + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds( + prompt_token_ids, attention_mask=attention_mask + ) prompt_embeds = prompt_embeds[:, :max_sequence_length] prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] @@ -215,7 +217,7 @@ def diffuse( def forward( self, req: OmniDiffusionRequest, - prompt_ids: torch.Tensor | list[int] | None = None, + prompt_token_ids: torch.Tensor | list[int] | None = None, prompt_mask: torch.Tensor | None = None, negative_prompt_ids: torch.Tensor | list[int] | None = None, negative_prompt_mask: torch.Tensor | None = None, @@ -245,7 +247,7 @@ def forward( # Extract prompt data from OmniCustomPrompt in req.prompts[0] custom_prompt = req.prompts[0] if req.prompts else {} if isinstance(custom_prompt, dict): - prompt_ids = custom_prompt.get("prompt_ids", prompt_ids) + prompt_token_ids = custom_prompt.get("prompt_token_ids", prompt_token_ids) prompt_mask = custom_prompt.get("prompt_mask", prompt_mask) negative_prompt_ids = custom_prompt.get("negative_prompt_ids", negative_prompt_ids) negative_prompt_mask = custom_prompt.get("negative_prompt_mask", negative_prompt_mask) @@ -276,14 +278,14 @@ def forward( self._current_timestep = None self._interrupt = False - if prompt_ids is not None: - if isinstance(prompt_ids, list): - prompt_ids = torch.tensor(prompt_ids, device=self.device) - batch_size = prompt_ids.shape[0] if prompt_ids.ndim == 2 else 1 + if prompt_token_ids is not None: + if isinstance(prompt_token_ids, list): + prompt_token_ids = torch.tensor(prompt_token_ids, device=self.device) + batch_size = prompt_token_ids.shape[0] if prompt_token_ids.ndim == 2 else 1 elif prompt_embeds is not None: batch_size = prompt_embeds.shape[0] else: - # Both prompt_ids and prompt_embeds are None (e.g. during warmup/dummy run). + # Both prompt_token_ids and prompt_embeds are None (e.g. during warmup/dummy run). # Return a minimal dummy output to avoid crashing. return DiffusionOutput(output=None, custom_output={}) @@ -296,7 +298,7 @@ def forward( do_true_cfg = true_cfg_scale > 1 and has_neg_prompt prompt_embeds, prompt_embeds_mask = self.encode_prompt( - prompt_ids=prompt_ids, + prompt_token_ids=prompt_token_ids, attention_mask=prompt_mask, prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, @@ -305,7 +307,7 @@ def forward( ) if do_true_cfg: negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( - prompt_ids=negative_prompt_ids, + prompt_token_ids=negative_prompt_ids, attention_mask=negative_prompt_mask, prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=negative_prompt_embeds_mask, @@ -398,12 +400,16 @@ def forward( return DiffusionOutput( output=_maybe_to_cpu(image), custom_output={ - "all_latents": _maybe_to_cpu(all_latents), - "all_log_probs": _maybe_to_cpu(all_log_probs), - "all_timesteps": _maybe_to_cpu(all_timesteps), - "prompt_embeds": _maybe_to_cpu(prompt_embeds), - "prompt_embeds_mask": _maybe_to_cpu(prompt_embeds_mask), - "negative_prompt_embeds": _maybe_to_cpu(negative_prompt_embeds), - "negative_prompt_embeds_mask": _maybe_to_cpu(negative_prompt_embeds_mask), + "all_latents": _maybe_to_cpu(all_latents[0]), + "all_log_probs": _maybe_to_cpu(all_log_probs[0]) if all_log_probs is not None else None, + "all_timesteps": _maybe_to_cpu(all_timesteps[0]), + "prompt_embeds": _maybe_to_cpu(prompt_embeds[0]), + "prompt_embeds_mask": _maybe_to_cpu(prompt_embeds_mask[0]) if prompt_embeds_mask is not None else None, + "negative_prompt_embeds": _maybe_to_cpu(negative_prompt_embeds[0]) + if negative_prompt_embeds is not None + else None, + "negative_prompt_embeds_mask": _maybe_to_cpu(negative_prompt_embeds_mask[0]) + if negative_prompt_embeds_mask is not None + else None, }, ) diff --git a/tests/workers/rollout/rollout_vllm/test_vllm_abort.py b/tests/workers/rollout/rollout_vllm/test_vllm_abort.py index cad7cbb5e83..28f84dedc7b 100644 --- a/tests/workers/rollout/rollout_vllm/test_vllm_abort.py +++ b/tests/workers/rollout/rollout_vllm/test_vllm_abort.py @@ -120,13 +120,13 @@ def test_vllm_abort(): "Write about the French Revolution and its consequences.", ] - all_prompt_ids = [] + all_prompt_token_ids = [] for prompt in prompts[:NUM_PROMPTS]: messages = [{"role": "user", "content": prompt}] - prompt_ids = normalize_token_ids( + prompt_token_ids = normalize_token_ids( tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) ) - all_prompt_ids.append(prompt_ids) + all_prompt_token_ids.append(prompt_token_ids) print(f"Prepared {NUM_PROMPTS} prompts") # ==================== Start Generations and Abort ==================== @@ -141,11 +141,11 @@ def test_vllm_abort(): # Start all generations concurrently print(f"\n Starting {NUM_PROMPTS} generations...") generate_refs = [] - for i, prompt_ids in enumerate(all_prompt_ids): + for i, prompt_token_ids in enumerate(all_prompt_token_ids): request_id = f"abort_test_{i}_{uuid4().hex[:8]}" ref = server_handle.generate.remote( request_id=request_id, - prompt_ids=prompt_ids, + prompt_token_ids=prompt_token_ids, sampling_params=sampling_params, image_data=None, ) diff --git a/tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py b/tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py new file mode 100644 index 00000000000..797383e1044 --- /dev/null +++ b/tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py @@ -0,0 +1,342 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +E2E test for BAGEL RL pipeline via vLLMOmniHttpServer. + +Uses verl's rollout server with BAGEL's multi-stage pipeline +(thinker on GPU 0, DiT on GPU 1) and BagelPipelineWithLogProb. + +Usage: + pytest tests/workers/rollout/rollout_vllm/test_vllm_omni_bagel_generate.py -v -s +""" + +import json +import os +import tempfile +from pathlib import Path +from uuid import uuid4 + +import numpy as np +import pytest +import ray +import torch +from omegaconf import OmegaConf +from safetensors.torch import save_file + +from verl.workers.rollout.replica import DiffusionOutput, RolloutMode +from verl.workers.rollout.vllm_rollout.vllm_omni_async_server import vLLMOmniHttpServer + +MODEL_PATH = Path(os.path.expanduser("~/models/tiny-random/bagel")) +STAGE_CONFIG = os.environ.get("BAGEL_STAGE_CONFIG", "") # TODO: Point to the location on the CI server by default + +DEFAULT_PROMPT = ( + "a beautiful sunset over the ocean with vibrant orange and purple clouds reflecting on the calm water surface" +) + + +# --------------------------------------------------------------------- +# 👇 Test Helper Functions & Fixtures 👇 +# --------------------------------------------------------------------- + + +def _tokenize_prompt(text: str) -> list[int]: + """Tokenize a text prompt into token IDs for BAGEL.""" + from transformers import AutoTokenizer + + from verl.utils.tokenizer import normalize_token_ids + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + token_ids = normalize_token_ids(tokenizer.encode(text)) + return token_ids + + +@pytest.fixture(scope="module") +def init_server(): + """Create and launch a vLLMOmniHttpServer Ray actor with BAGEL.""" + if not STAGE_CONFIG: + pytest.skip("BAGEL_STAGE_CONFIG env var not set") + + ray.init( + runtime_env={ + "env_vars": { + "TOKENIZERS_PARALLELISM": "true", + "NCCL_DEBUG": "WARN", + "VLLM_LOGGING_LEVEL": "INFO", + } + }, + ignore_reinit_error=True, + ) + + rollout_cfg = OmegaConf.create( + { + "_target_": "verl.workers.config.DiffusionRolloutConfig", + "name": "vllm_omni", + "mode": "async", + "tensor_model_parallel_size": 1, + "data_parallel_size": 1, + "pipeline_model_parallel_size": 1, + "gpu_memory_utilization": 0.9, + "max_num_batched_tokens": 32768, + "max_num_seqs": 1, + "max_model_len": 32768, + "dtype": "bfloat16", + "load_format": "auto", + "enforce_eager": True, + "enable_chunked_prefill": False, + "enable_prefix_caching": False, + "enable_sleep_mode": False, + "free_cache_engine": True, + "disable_log_stats": True, + "n": 1, + "num_inference_steps": 10, + "engine_kwargs": { + "vllm_omni": { + "custom_pipeline": "examples.flowgrpo_trainer.vllm_omni.pipeline_bagel.BagelPipelineWithLogProb", + "stage_configs_path": STAGE_CONFIG, + } + }, + } + ) + + model_cfg = OmegaConf.create( + { + "_target_": "verl.workers.config.DiffusionModelConfig", + "path": MODEL_PATH, + "architecture": "OmniBagelForConditionalGeneration", + "trust_remote_code": True, + "load_tokenizer": False, + } + ) + + ServerCls = ray.remote(vLLMOmniHttpServer) + server = ServerCls.options( + runtime_env={ + "env_vars": { + "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", + "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", + "NCCL_CUMEM_ENABLE": "0", + } + }, + max_concurrency=10, + ).remote( + config=rollout_cfg, + model_config=model_cfg, + rollout_mode=RolloutMode.STANDALONE, + workers=[], + replica_rank=0, + node_rank=0, + gpus_per_node=2, + nnodes=1, + cuda_visible_devices="0,1", + ) + + ray.get(server.launch_server.remote()) + + yield server + + ray.shutdown() + + +# --------------------------------------------------------------------- +# 👇 Tests 👇 +# --------------------------------------------------------------------- + + +def test_generate(init_server): + """generate() returns a valid DiffusionOutput with CHW image in [0, 1].""" + server = init_server + + request_id = f"test_{uuid4().hex[:8]}" + output = ray.get( + server.generate.remote( + prompt_token_ids=_tokenize_prompt(DEFAULT_PROMPT), + sampling_params={ + "num_inference_steps": 10, + }, + request_id=request_id, + ), + timeout=300, + ) + + assert isinstance(output, DiffusionOutput) + assert len(output.diffusion_output) == 3, f"Expected 3 channels (CHW), got {len(output.diffusion_output)}" + h, w = len(output.diffusion_output[0]), len(output.diffusion_output[0][0]) + assert h > 0 and w > 0 + assert output.stop_reason in ("completed", "aborted", None) + + # spot-check pixel range + assert 0.0 <= output.diffusion_output[0][0][0] <= 1.0 + + print(f"image: C=3 H={h} W={w} stop_reason={output.stop_reason}") + + +def test_generate_with_logprobs(init_server): + """generate() with SDE scheduler returns non-empty log_probs and RL artifacts.""" + server = init_server + + request_id = f"test_lp_{uuid4().hex[:8]}" + output = ray.get( + server.generate.remote( + prompt_token_ids=_tokenize_prompt(DEFAULT_PROMPT), + sampling_params={ + "num_inference_steps": 10, + "noise_level": 0.7, + "sde_type": "sde", + "logprobs": True, + }, + request_id=request_id, + ), + timeout=300, + ) + + assert isinstance(output, DiffusionOutput) + assert len(output.diffusion_output) == 3 + + lp = output.log_probs + assert lp is not None, "log_probs should be present when logprobs=True" + print(f"log_probs: shape={getattr(lp, 'shape', len(lp))}") + + extra = output.extra_fields + assert extra.get("all_latents") is not None, "all_latents should be present" + assert extra.get("all_timesteps") is not None, "all_timesteps should be present" + print(f"all_latents: shape={getattr(extra['all_latents'], 'shape', len(extra['all_latents']))}") + print(f"all_timesteps: shape={getattr(extra['all_timesteps'], 'shape', len(extra['all_timesteps']))}") + + +def test_generate_concurrent(init_server): + """Multiple concurrent generate() calls all return valid DiffusionOutput.""" + server = init_server + n_requests = 4 + + prompts = [ + "a beautiful sunset over the ocean with vibrant orange and purple clouds " + "reflecting on the calm water surface near a rocky coastline", + "a fluffy orange cat sitting on a wooden windowsill looking outside at " + "a garden full of colorful flowers on a bright sunny afternoon", + "a majestic mountain landscape covered with fresh white snow under a " + "clear blue sky with pine trees in the foreground and a frozen lake", + "a futuristic city at night with neon lights glowing on tall glass " + "skyscrapers and flying vehicles soaring between the buildings", + ] + + refs = [] + for i in range(n_requests): + rid = f"concurrent_{i}_{uuid4().hex[:8]}" + ref = server.generate.remote( + prompt_token_ids=_tokenize_prompt(prompts[i]), + sampling_params={"num_inference_steps": 10}, + request_id=rid, + ) + refs.append(ref) + + results = ray.get(refs, timeout=600) + + for i, res in enumerate(results): + assert isinstance(res, DiffusionOutput), f"Request {i}: expected DiffusionOutput" + assert len(res.diffusion_output) == 3, f"Request {i}: expected 3 channels" + assert res.stop_reason in ("completed", "aborted", None) + + print(f"All {n_requests} concurrent requests returned valid DiffusionOutput") + + +# --------------------------------------------------------------------- +# 👇 LoRA helpers 👇 +# --------------------------------------------------------------------- + +# Tiny BAGEL: hidden_size=64, 2 Q heads, 2 KV heads, head_dim=32 +# QKV packed dim = (2+2+2)*32 = 192 +_LORA_DIM = 64 +_LORA_QKV_DIM = 192 +_LORA_MODULE = "bagel.language_model.model.layers.0.self_attn.qkv_proj" +_LORA_RANK = 4 + + +def _make_synthetic_lora(adapter_dir: Path): + """Create a synthetic rank-4 LoRA adapter on disk.""" + adapter_dir.mkdir(parents=True, exist_ok=True) + gen = torch.Generator().manual_seed(42) + lora_a = torch.randn((_LORA_RANK, _LORA_DIM), dtype=torch.float32, generator=gen) * 0.1 + lora_b = torch.randn((_LORA_QKV_DIM, _LORA_RANK), dtype=torch.float32, generator=gen) * 0.5 + save_file( + { + f"base_model.model.{_LORA_MODULE}.lora_A.weight": lora_a, + f"base_model.model.{_LORA_MODULE}.lora_B.weight": lora_b, + }, + str(adapter_dir / "adapter_model.safetensors"), + ) + (adapter_dir / "adapter_config.json").write_text( + json.dumps({"r": _LORA_RANK, "lora_alpha": _LORA_RANK, "target_modules": [_LORA_MODULE]}), + encoding="utf-8", + ) + return str(adapter_dir) + + +def test_generate_with_lora(init_server): + """LoRA adapter changes output and deactivation restores baseline.""" + from vllm_omni.lora.request import LoRARequest + + server = init_server + + with tempfile.TemporaryDirectory() as tmp_dir: + lora_path = _make_synthetic_lora(Path(tmp_dir) / "bagel_lora") + lora_request = LoRARequest(lora_name="test_lora", lora_int_id=42, lora_path=lora_path) + + # 1) Baseline (no LoRA) + baseline = ray.get( + server.generate.remote( + prompt_token_ids=_tokenize_prompt(DEFAULT_PROMPT), + sampling_params={"num_inference_steps": 10}, + request_id=f"lora_base_{uuid4().hex[:8]}", + ), + timeout=300, + ) + + # 2) With LoRA + with_lora = ray.get( + server.generate.remote( + prompt_token_ids=_tokenize_prompt(DEFAULT_PROMPT), + sampling_params={"num_inference_steps": 10}, + request_id=f"lora_on_{uuid4().hex[:8]}", + lora_request=lora_request, + lora_scale=1.0, + ), + timeout=300, + ) + + # 3) Deactivated (no LoRA again) + restored = ray.get( + server.generate.remote( + prompt_token_ids=_tokenize_prompt(DEFAULT_PROMPT), + sampling_params={"num_inference_steps": 10}, + request_id=f"lora_off_{uuid4().hex[:8]}", + ), + timeout=300, + ) + + assert isinstance(baseline, DiffusionOutput) + assert isinstance(with_lora, DiffusionOutput) + assert isinstance(restored, DiffusionOutput) + + base_arr = np.array(baseline.diffusion_output) + lora_arr = np.array(with_lora.diffusion_output) + + diff_lora = np.abs(base_arr - lora_arr).mean() + + print(f"LoRA diff from baseline: {diff_lora:.4f}") + + # LoRA should visibly change output + assert diff_lora > 0.001, f"LoRA had no effect: diff={diff_lora}" + # Output is not corrupted + assert diff_lora < 80, f"LoRA output looks corrupted: diff={diff_lora}" diff --git a/tests/workers/rollout/rollout_vllm/test_vllm_omni_generate.py b/tests/workers/rollout/rollout_vllm/test_vllm_omni_generate.py index ea3cd6075b1..1ed4103d33e 100644 --- a/tests/workers/rollout/rollout_vllm/test_vllm_omni_generate.py +++ b/tests/workers/rollout/rollout_vllm/test_vllm_omni_generate.py @@ -57,7 +57,7 @@ def _tokenize_prompt(text: str) -> list[int]: return token_ids -@pytest.fixture +@pytest.fixture(scope="module") def init_server(): """Create and launch a vLLMOmniHttpServer Ray actor with Qwen/Qwen-Image.""" model_path = MODEL_PATH @@ -149,7 +149,7 @@ def init_server(): def test_generate(init_server): """generate() returns a valid DiffusionOutput with CHW image in [0, 1].""" server = init_server - prompt_ids = _tokenize_prompt( + prompt_token_ids = _tokenize_prompt( "a beautiful sunset over the ocean with vibrant orange and purple clouds " "reflecting on the calm water surface near a rocky coastline" ) @@ -157,7 +157,7 @@ def test_generate(init_server): request_id = f"test_{uuid4().hex[:8]}" output = ray.get( server.generate.remote( - prompt_ids=prompt_ids, + prompt_token_ids=prompt_token_ids, sampling_params={ "num_inference_steps": 10, "true_cfg_scale": 4.0, @@ -184,7 +184,7 @@ def test_generate(init_server): def test_generate_with_logprobs(init_server): """generate() with logprobs=True returns non-empty log_probs (tensor or sequence).""" server = init_server - prompt_ids = _tokenize_prompt( + prompt_token_ids = _tokenize_prompt( "a futuristic city at night with neon lights glowing on tall glass " "skyscrapers and flying vehicles soaring between the buildings" ) @@ -192,7 +192,7 @@ def test_generate_with_logprobs(init_server): request_id = f"test_lp_{uuid4().hex[:8]}" output = ray.get( server.generate.remote( - prompt_ids=prompt_ids, + prompt_token_ids=prompt_token_ids, sampling_params={ "num_inference_steps": 10, "true_cfg_scale": 4.0, @@ -241,7 +241,7 @@ def test_generate_concurrent(init_server): for i in range(n_requests): rid = f"concurrent_{i}_{uuid4().hex[:8]}" ref = server.generate.remote( - prompt_ids=_tokenize_prompt(prompts[i]), + prompt_token_ids=_tokenize_prompt(prompts[i]), sampling_params={ "num_inference_steps": 10, "true_cfg_scale": 4.0, diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 6b9503ac181..f33b1615777 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -165,7 +165,7 @@ async def generate( try: output: TokenOutput | DiffusionOutput = await server.generate.remote( request_id=uuid4().hex, # use new request_id for each turn - prompt_ids=prompt_ids, + prompt_token_ids=prompt_ids, sampling_params=sampling_params, image_data=image_data, video_data=video_data, diff --git a/verl/models/diffusers_model/base.py b/verl/models/diffusers_model/base.py index 33a7bcfcc44..63ab9857034 100644 --- a/verl/models/diffusers_model/base.py +++ b/verl/models/diffusers_model/base.py @@ -85,6 +85,16 @@ def get_class(cls, model_config: DiffusionModelConfig) -> type["DiffusionModelBa f"Set ``external_lib`` in DiffusionModelConfig to load your implementation." ) from None + @classmethod + def build_module(cls, model_config: DiffusionModelConfig, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]: + """Optional hook for custom model loading. + + Override this to load non-standard models (e.g. models not loadable + via ``diffusers.AutoModel``). Return ``None`` to fall back to the + default ``AutoModel.from_pretrained`` path in the FSDP engine. + """ + return None + @classmethod @abstractmethod def build_scheduler(cls, model_config: DiffusionModelConfig) -> SchedulerMixin: diff --git a/verl/trainer/config/diffusion_trainer.yaml b/verl/trainer/config/diffusion_trainer.yaml index 656a8172ecd..d480c6c5498 100644 --- a/verl/trainer/config/diffusion_trainer.yaml +++ b/verl/trainer/config/diffusion_trainer.yaml @@ -177,6 +177,89 @@ trainer: # mode: "auto", "enable", or "disable" use_legacy_worker_impl: disable +# profiler configs +global_profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # Profiling tool: choose between nsys, npu, torch, torch_memory, precision_debugger + tool: null + + # profile steps + steps: null + + # Whether to combine continuous steps into one database. + profile_continuous_steps: False + + # Path to save profiling contents + save_path: "outputs/profile" + + # Specific tool configs + global_tool_config: + + # nsys config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # controller Nvidia Nsight Systems Options. + controller_nsight_options: + + trace: "cuda,nvtx,cublas,ucx" + + cuda-memory-usage: "true" + + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. + worker_nsight_options: + + trace: "cuda,nvtx,cublas,ucx" + + cuda-memory-usage: "true" + + cuda-graph-trace: "graph" + + capture-range: "cudaProfilerApi" + + capture-range-end: null + + kill: none + + # enable memory visualization for debugging memory usage + torch_memory: + + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + + trace_alloc_max_entries: 100_000 + + stack_depth: 32 + + context: "all" + + stacks: "all" + + kw_args: {} + + # msprobe precision debugger + precision_debugger: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig + + config_path: null + + steps: null + + stages: null + + strict: False + # configs related to ray ray_kwargs: diff --git a/verl/trainer/diffusion/ray_diffusion_trainer.py b/verl/trainer/diffusion/ray_diffusion_trainer.py index fe0be9d15c1..dd505060bc1 100644 --- a/verl/trainer/diffusion/ray_diffusion_trainer.py +++ b/verl/trainer/diffusion/ray_diffusion_trainer.py @@ -588,6 +588,16 @@ def init_workers(self): wg_kwargs = {} # Setting up kwargs for RayWorkerGroup if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config, "global_profiler.steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config, "global_profiler.steps") + if OmegaConf.select(self.config, "global_profiler.tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) wg_kwargs["device_name"] = self.device_name for resource_pool, class_dict in self.resource_pool_to_cls.items(): @@ -666,6 +676,20 @@ def init_workers(self): # sleep all replicas to load checkpoint self.checkpoint_manager.sleep_replicas() + def _start_profiling(self, do_profile: bool) -> None: + """Start profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + if self.use_reference_policy: + self.ref_policy_wg.start_profile(profile_step=self.global_steps) + + def _stop_profiling(self, do_profile: bool) -> None: + """Stop profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + def _save_checkpoint(self): from verl.utils.fs import local_mkdir_safe @@ -879,6 +903,7 @@ def fit(self): self.global_steps += 1 last_val_metrics = None self.max_steps_duration = 0 + prev_step_profile = False for epoch in range(current_epoch, self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: @@ -903,7 +928,21 @@ def fit(self): ) is_last_step = self.global_steps >= self.total_training_steps + + profile_steps = ( + self.config.global_profiler.steps + if OmegaConf.select(self.config, "global_profiler.steps") is not None + else None + ) + curr_step_profile = self.global_steps in profile_steps if profile_steps else False + profile_continuous = OmegaConf.select(self.config, "global_profiler.profile_continuous_steps") or False + with marked_timer("step", timing_raw): + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile if profile_continuous else curr_step_profile + ) + # generate a batch with marked_timer("gen", timing_raw, color="red"): gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) @@ -1016,6 +1055,15 @@ def fit(self): with marked_timer("update_weights", timing_raw, color="red"): self.checkpoint_manager.update_weights(self.global_steps) + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in profile_steps if profile_steps else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile if profile_continuous else curr_step_profile + ) + prev_step_profile = curr_step_profile + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) diff --git a/verl/workers/engine/fsdp/diffusers_impl.py b/verl/workers/engine/fsdp/diffusers_impl.py index 47dcebeefb8..a1eb5fad8f0 100644 --- a/verl/workers/engine/fsdp/diffusers_impl.py +++ b/verl/workers/engine/fsdp/diffusers_impl.py @@ -163,8 +163,6 @@ def _init_device_mesh(self): raise NotImplementedError("Ulysses sequence parallel for Diffusers backend is not supported currently.") def _build_module(self): - from diffusers import AutoModel - from verl.utils.torch_dtypes import PrecisionType torch_dtype = self.engine_config.model_dtype @@ -175,6 +173,21 @@ def _build_module(self): torch_dtype = PrecisionType.to_dtype(torch_dtype) + # Allow registered DiffusionModelBase subclass to provide custom loading + from verl.models.diffusers_model import DiffusionModelBase + + model_cls = DiffusionModelBase.get_class(self.model_config) + module = model_cls.build_module(self.model_config, torch_dtype) + + if module is not None: + module.to(torch_dtype) + if not hasattr(module, "can_generate"): + module.can_generate = lambda: False + return module + + # Default path: load via diffusers AutoModel + from diffusers import AutoModel + init_context = get_init_weight_context_manager(use_meta_tensor=True, mesh=self.device_mesh) with init_context(), warnings.catch_warnings(): @@ -529,12 +542,12 @@ def prepare_model_inputs(self, micro_batch: TensorDict, step: int): """ latents = micro_batch["all_latents"] timesteps = micro_batch["all_timesteps"] - prompt_embeds = micro_batch["prompt_embeds"] - prompt_embeds_mask = micro_batch["prompt_embeds_mask"] - negative_prompt_embeds = micro_batch["negative_prompt_embeds"] - negative_prompt_embeds_mask = micro_batch["negative_prompt_embeds_mask"] + prompt_embeds = micro_batch.get("prompt_embeds", None) + prompt_embeds_mask = micro_batch.get("prompt_embeds_mask", None) + negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None) + negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None) - if prompt_embeds.is_nested: + if isinstance(prompt_embeds, torch.Tensor) and prompt_embeds.is_nested: prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask) if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested: diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 23c8594f080..bf5af2d9f37 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -438,7 +438,7 @@ def on_run_headless_done(future: asyncio.Future): async def generate( self, - prompt_ids: list[int], + prompt_token_ids: list[int], sampling_params: dict[str, Any], request_id: str, image_data: Optional[list[Any]] = None, @@ -446,14 +446,14 @@ async def generate( priority: int = 0, ) -> TokenOutput: """Generate sequence with token-in-token-out.""" - prompt_ids = normalize_token_ids(prompt_ids) + prompt_token_ids = normalize_token_ids(prompt_token_ids) # Calculate the maximum possible new tokens based on available context space # This serves as a safety upper bound - max_possible_tokens = self.config.max_model_len - len(prompt_ids) + max_possible_tokens = self.config.max_model_len - len(prompt_token_ids) if max_possible_tokens < 0: raise ValueError( - f"Prompt length ({len(prompt_ids)}) exceeds the model's maximum context length " + f"Prompt length ({len(prompt_token_ids)}) exceeds the model's maximum context length " f"({self.config.max_model_len})." ) @@ -468,7 +468,8 @@ async def generate( # Cap max_tokens by response_length to ensure tensor alignment, # and by remaining budget to prevent OOM in multi-turn rollouts. max_tokens = min( - self.config.response_length, self.config.prompt_length + self.config.response_length - len(prompt_ids) + self.config.response_length, + self.config.prompt_length + self.config.response_length - len(prompt_token_ids), ) # Clamp max_tokens to the valid range [0, max_possible_tokens] @@ -480,14 +481,14 @@ async def generate( sampling_params["logprobs"] = 0 if sampling_params.pop("logprobs", False) else None sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0)) sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params) - prompt_ids = qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor) + prompt_token_ids = qwen2_5_vl_dedup_image_tokens(prompt_token_ids, self.model_config.processor) multi_modal_data = {} if image_data is not None: multi_modal_data["image"] = image_data if video_data is not None: multi_modal_data["video"] = video_data - prompt = TokensPrompt(prompt_token_ids=prompt_ids, multi_modal_data=multi_modal_data) + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, multi_modal_data=multi_modal_data) # Add lora request lora_request = None diff --git a/verl/workers/rollout/vllm_rollout/vllm_omni_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_omni_async_server.py index 6d21f66e710..11db0b95e00 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_omni_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_omni_async_server.py @@ -81,6 +81,7 @@ def _get_engine_kwargs_key(self) -> str: def _preprocess_engine_kwargs(self, engine_kwargs: dict) -> None: # custom_pipeline is passed directly to run_server; not supported via CLI yet engine_kwargs.pop("custom_pipeline", None) + engine_kwargs.pop("stage_configs_path", None) def _get_worker_extension_cls(self) -> str: return "verl.workers.rollout.vllm_rollout.utils.vLLMOmniColocateWorkerExtension" @@ -100,10 +101,14 @@ async def run_server(self, args: argparse.Namespace): engine_args = asdict(engine_args) # TODO (mike): read custom_pipeline from CLI - custom_pipeline = self.config.engine_kwargs.get("vllm_omni", {}).get("custom_pipeline", None) + vllm_omni_kwargs = self.config.engine_kwargs.get("vllm_omni", {}) + custom_pipeline = vllm_omni_kwargs.get("custom_pipeline", None) if custom_pipeline is not None: engine_args["enable_dummy_pipeline"] = True engine_args["custom_pipeline_args"] = {"pipeline_class": custom_pipeline} + stage_configs_path = vllm_omni_kwargs.get("stage_configs_path") + if stage_configs_path is not None: + engine_args["stage_configs_path"] = stage_configs_path # TODO (mike): support parsing engine config from CLI engine_client = AsyncOmni(**engine_args) @@ -127,16 +132,19 @@ def _get_wake_up_tags(self) -> list[str]: async def generate( self, - prompt_ids: list[int], + prompt_token_ids: list[int], sampling_params: dict[str, Any], request_id: str, image_data: Optional[list[Any]] = None, video_data: Optional[list[Any]] = None, negative_prompt_ids: Optional[list[int]] = None, priority: int = 0, + lora_request: Optional[LoRARequest] = None, + lora_scale: float = 1.0, ) -> DiffusionOutput: """Generate sequence with token-in-image-out.""" - prompt_ids = normalize_token_ids(prompt_ids) + prompt_token_ids = normalize_token_ids(prompt_token_ids) + default_params_list = self.engine.default_sampling_params_list multi_modal_data = {} if image_data is not None: @@ -144,9 +152,8 @@ async def generate( if video_data is not None: multi_modal_data["video"] = video_data - # Add lora request - lora_request = None - if self.lora_as_adapter: + # Add lora request (caller-supplied takes precedence over lora_as_adapter) + if lora_request is None and self.lora_as_adapter: # Make sure we also check that the lora is already loaded in the engine lora_loaded = VLLM_LORA_INT_ID in await self.engine.list_loras() if lora_loaded: @@ -155,7 +162,9 @@ async def generate( ) # Build OmniCustomPrompt with pre-tokenized IDs - custom_prompt: OmniCustomPrompt = {"prompt_ids": prompt_ids} + custom_prompt: OmniCustomPrompt = {"prompt_token_ids": prompt_token_ids} + if len(default_params_list) > 1: + custom_prompt["modalities"] = ["image"] if negative_prompt_ids is not None: custom_prompt["negative_prompt_ids"] = negative_prompt_ids if multi_modal_data: @@ -172,13 +181,17 @@ async def generate( sampling_kwargs["extra_args"] = extra_args if lora_request is not None: sampling_kwargs["lora_request"] = lora_request + sampling_kwargs["lora_scale"] = lora_scale diffusion_sampling_params = OmniDiffusionSamplingParams(**sampling_kwargs) + # Build sampling params list: multi-stage models use defaults for non-diffusion stages + sampling_params_list = default_params_list[:-1] + [diffusion_sampling_params] + # Call AsyncOmni.generate() with the correct API generator = self.engine.generate( prompt=custom_prompt, request_id=request_id, - sampling_params_list=[diffusion_sampling_params], + sampling_params_list=sampling_params_list, ) # Get final response @@ -193,27 +206,17 @@ async def generate( mm_output = final_res.custom_output or {} if sampling_params.get("logprobs", False): - all_log_probs = mm_output.get("all_log_probs") - log_probs = all_log_probs[0] if all_log_probs is not None else None + log_probs = mm_output.get("all_log_probs") else: log_probs = None - all_latents = mm_output.get("all_latents") - all_timesteps = mm_output.get("all_timesteps") - prompt_embeds = mm_output.get("prompt_embeds") - prompt_embeds_mask = mm_output.get("prompt_embeds_mask") - negative_prompt_embeds = mm_output.get("negative_prompt_embeds") - negative_prompt_embeds_mask = mm_output.get("negative_prompt_embeds_mask") - extra_fields = { - "all_latents": all_latents[0] if all_latents is not None else None, - "all_timesteps": all_timesteps[0] if all_timesteps is not None else None, - "prompt_embeds": prompt_embeds[0] if prompt_embeds is not None else None, - "prompt_embeds_mask": prompt_embeds_mask[0] if prompt_embeds_mask is not None else None, - "negative_prompt_embeds": negative_prompt_embeds[0] if negative_prompt_embeds is not None else None, - "negative_prompt_embeds_mask": negative_prompt_embeds_mask[0] - if negative_prompt_embeds_mask is not None - else None, + "all_latents": mm_output.get("all_latents"), + "all_timesteps": mm_output.get("all_timesteps"), + "prompt_embeds": mm_output.get("prompt_embeds"), + "prompt_embeds_mask": mm_output.get("prompt_embeds_mask"), + "negative_prompt_embeds": mm_output.get("negative_prompt_embeds"), + "negative_prompt_embeds_mask": mm_output.get("negative_prompt_embeds_mask"), "global_steps": self.global_steps, } diff --git a/verl/workers/utils/padding.py b/verl/workers/utils/padding.py index 73ad94fec59..ca62bf980e9 100644 --- a/verl/workers/utils/padding.py +++ b/verl/workers/utils/padding.py @@ -173,7 +173,8 @@ def _to_nested(embeds: torch.Tensor, mask: torch.Tensor): torch.nested.as_nested_tensor(mask_list, layout=torch.jagged), ) - data["prompt_embeds"], data["prompt_embeds_mask"] = _to_nested(data["prompt_embeds"], data["prompt_embeds_mask"]) + if isinstance(data.get("prompt_embeds", None), torch.Tensor): + data["prompt_embeds"], data["prompt_embeds_mask"] = _to_nested(data["prompt_embeds"], data["prompt_embeds_mask"]) if isinstance(data.get("negative_prompt_embeds", None), torch.Tensor): data["negative_prompt_embeds"], data["negative_prompt_embeds_mask"] = _to_nested(