Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions examples/flowgrpo_trainer/diffusers/bagel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
"""BAGEL (MoT) diffusion model implementation for FlowGRPO training.

Registers as ``OmniBagelForConditionalGeneration`` so the FSDP engine
can load and train the model via the DiffusionModelBase registry.

Key differences from standard diffusion models (e.g. QwenImage):
* BAGEL is a *Mixture-of-Transformer-Experts* (MoT) model that processes text token
IDs and noisy latent patches in a single forward pass (no separate
text encoder).
* ``prompt_embeds`` are not used. Instead, the raw prompt token IDs
(available as ``micro_batch["prompts"]``) are passed directly to the
model as ``text_token_ids``.
* CFG uses a 3-branch scheme during rollout, but for FSDP training
(computing log-probs of the rollout trajectory) only the conditional
forward is needed.
"""

from __future__ import annotations

import logging
from typing import Optional

import numpy as np
import torch
from tensordict import TensorDict

from verl.models.diffusers_model import DiffusionModelBase
from verl.utils import tensordict_utils as tu
from verl.utils.device import get_device_name
from verl.workers.config import DiffusionModelConfig

from ..scheduler import FlowMatchSDEDiscreteScheduler
from .bagel_model import BagelForTraining, get_flattened_position_ids

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shift factor applied to the linear time grid in ``BagelDiffusion.set_timesteps``
# via ``shift * t / (1 + (shift - 1) * t)``.
TIMESTEP_SHIFT = 3.0 # must match BagelPipeline.forward() hardcoded value


@DiffusionModelBase.register("OmniBagelForConditionalGeneration")
class BagelDiffusion(DiffusionModelBase):
    """DiffusionModelBase wrapper for BagelForTraining (MoT).

    All entry points are classmethods; the class itself holds no state. It is
    registered under ``"OmniBagelForConditionalGeneration"`` and provides
    module loading, scheduler construction, per-step model-input preparation
    and the forward + scheduler step used to obtain log-probs.
    """

    # ------------------------------------------------------------------
    # Custom model loading (BAGEL can't be loaded via diffusers.AutoModel)
    # ------------------------------------------------------------------

    @classmethod
    def build_module(cls, model_config: DiffusionModelConfig, torch_dtype: torch.dtype):
        """Load ``BagelForTraining`` from ``model_config.local_path``.

        Args:
            model_config: Config providing ``local_path``.
            torch_dtype: Dtype forwarded to ``BagelForTraining.from_pretrained``.

        Returns:
            The loaded module.
        """
        logger.info("Loading BagelForTraining from %s", model_config.local_path)
        return BagelForTraining.from_pretrained(model_config.local_path, torch_dtype=torch_dtype)

    # ------------------------------------------------------------------
    # Scheduler
    # ------------------------------------------------------------------

    @classmethod
    def build_scheduler(cls, model_config: DiffusionModelConfig):
        """Create a ``FlowMatchSDEDiscreteScheduler`` and configure its timesteps."""
        scheduler = FlowMatchSDEDiscreteScheduler()
        cls.set_timesteps(scheduler, model_config, get_device_name())
        return scheduler

    @classmethod
    def set_timesteps(cls, scheduler: FlowMatchSDEDiscreteScheduler, model_config: DiffusionModelConfig, device: str):
        """Install the shifted sigma schedule on ``scheduler``.

        A linear grid from 1 to 0 with ``model_config.num_inference_steps``
        points is warped by ``TIMESTEP_SHIFT``; the final grid point is dropped
        before the values are handed to the scheduler as ``sigmas``.

        Args:
            scheduler: Scheduler to configure in place.
            model_config: Config providing ``num_inference_steps``.
            device: Accepted for interface compatibility; not used here.
        """
        grid = np.linspace(1, 0, model_config.num_inference_steps)
        shifted = TIMESTEP_SHIFT * grid / (1 + (TIMESTEP_SHIFT - 1) * grid)
        sigmas = shifted[:-1].tolist()

        scheduler.set_shift(1.0)  # identity — sigmas already shifted
        scheduler.set_timesteps(sigmas=sigmas)
        scheduler.set_begin_index(0)

    # ------------------------------------------------------------------
    # Prepare model inputs
    # ------------------------------------------------------------------

    @classmethod
    def _get_latent_pos_ids(cls, model_config: DiffusionModelConfig, module, device) -> torch.Tensor:
        """Compute latent position IDs from model config / image dimensions.

        The image height/width from ``model_config`` are converted to a latent
        grid size (clamped to ``module.config.max_latent_size``) and passed to
        ``get_flattened_position_ids``; the result is moved to ``device``.
        """
        config = module.config
        # Pixels covered by one latent patch along each axis.
        latent_ds = config.latent_patch_size * config.vae_downsample
        # Latent grid size, clamped to max_latent_size.
        grid_h = min(model_config.height // latent_ds, config.max_latent_size)
        grid_w = min(model_config.width // latent_ds, config.max_latent_size)
        pos_ids = get_flattened_position_ids(
            grid_h * latent_ds,
            grid_w * latent_ds,
            latent_ds,
            config.max_latent_size,
        )
        return pos_ids.to(device)

    @classmethod
    def prepare_model_inputs(
        cls,
        module,
        model_config: DiffusionModelConfig,
        latents: torch.Tensor,
        timesteps: torch.Tensor,
        prompt_embeds: torch.Tensor,
        prompt_embeds_mask: torch.Tensor,
        negative_prompt_embeds: torch.Tensor,
        negative_prompt_embeds_mask: torch.Tensor,
        micro_batch: TensorDict,
        step: int,
    ) -> tuple[dict, dict]:
        """Build conditional and unconditional keyword inputs for one step.

        Args:
            module: Model whose ``config`` drives the latent position IDs.
            model_config: Config providing image ``height`` / ``width``.
            latents: Trajectory latents; ``latents[:, step]`` is the model input.
            timesteps: Trajectory timesteps; ``timesteps[:, step]`` is used.
            prompt_embeds, prompt_embeds_mask, negative_prompt_embeds,
                negative_prompt_embeds_mask: Not used by this implementation.
            micro_batch: Must contain ``"prompts"`` and ``"attention_mask"``.
            step: Index of the denoising step along dimension 1.

        Returns:
            ``(model_inputs, negative_model_inputs)``. Both share the hidden
            states, timestep and latent position IDs; the negative variant has
            ``text_token_ids=None``.
        """
        batch_size = latents.shape[0]
        device = latents.device

        hidden_states = latents[:, step]
        timestep = timesteps[:, step]

        # Extract text token IDs from prompt data
        prompts = micro_batch["prompts"]  # (B, L_prompt) padded
        attention_mask = micro_batch["attention_mask"]  # (B, L_prompt)

        # Build per-sample text_token_ids (remove padding)
        unpadded_ids = [prompts[i][attention_mask[i].bool()] for i in range(batch_size)]

        # Pad to same length within batch; ``default=0`` keeps an empty batch
        # from raising in ``max``.
        max_text_len = max((ids.shape[0] for ids in unpadded_ids), default=0)
        text_token_ids = torch.zeros(batch_size, max_text_len, dtype=torch.long, device=device)
        for i, ids in enumerate(unpadded_ids):
            text_token_ids[i, : ids.shape[0]] = ids

        # Compute latent position IDs (one row, shared by every sample)
        latent_pos_ids = cls._get_latent_pos_ids(model_config, module, device)
        latent_pos_ids = latent_pos_ids.unsqueeze(0).expand(batch_size, -1)

        model_inputs = {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "text_token_ids": text_token_ids,
            "latent_pos_ids": latent_pos_ids,
        }

        # For BAGEL, unconditional pass uses text_token_ids=None
        negative_model_inputs = {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "text_token_ids": None,
            "latent_pos_ids": latent_pos_ids,
        }

        return model_inputs, negative_model_inputs

    # ------------------------------------------------------------------
    # Forward + scheduler step
    # ------------------------------------------------------------------

    @classmethod
    def forward_and_sample_previous_step(
        cls,
        module,
        scheduler: FlowMatchSDEDiscreteScheduler,
        model_config: DiffusionModelConfig,
        model_inputs: dict[str, torch.Tensor],
        negative_model_inputs: Optional[dict[str, torch.Tensor]],
        scheduler_inputs: Optional[TensorDict | dict[str, torch.Tensor]],
        step: int,
    ):
        """Run the model for one step and score the stored transition.

        Args:
            module: Callable model; ``module(**inputs)[0]`` is the prediction.
            scheduler: Scheduler providing ``sample_previous_step``.
            model_config: Config whose ``extra_configs`` may hold
                ``true_cfg_scale``, ``noise_level`` and ``sde_type``.
            model_inputs: Keyword inputs for the conditional forward pass.
            negative_model_inputs: Keyword inputs for the unconditional pass;
                required when ``true_cfg_scale > 1.0``.
            scheduler_inputs: Must contain ``"all_latents"`` and
                ``"all_timesteps"``.
            step: Index of the transition ``step -> step + 1`` to score.

        Returns:
            ``(log_prob, prev_sample_mean, std_dev_t)`` from the scheduler.
        """
        assert scheduler_inputs is not None, "scheduler_inputs is required"
        latents = scheduler_inputs["all_latents"]
        timesteps = scheduler_inputs["all_timesteps"]
        extra_configs = model_config.extra_configs

        noise_pred = module(**model_inputs)[0]

        # CFG during training (if configured)
        true_cfg_scale = extra_configs.get("true_cfg_scale", 1.0)
        if true_cfg_scale > 1.0:
            assert negative_model_inputs is not None, "negative_model_inputs is required when true_cfg_scale > 1.0"
            neg_noise_pred = module(**negative_model_inputs)[0]
            comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
            # Rescale the combined prediction to the conditional prediction's norm.
            cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
            comb_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
            # Floor the denominator: an all-zero combined row would otherwise
            # produce 0/0 = NaN and poison the log-prob.
            noise_pred = comb_pred * (cond_norm / torch.clamp(comb_norm, min=1e-8))

        _, log_prob, prev_sample_mean, std_dev_t = scheduler.sample_previous_step(
            sample=latents[:, step].float(),
            model_output=noise_pred.float(),
            timestep=timesteps[:, step],
            noise_level=extra_configs.get("noise_level", None),
            prev_sample=latents[:, step + 1].float(),
            sde_type=extra_configs.get("sde_type", None),
            return_logprobs=True,
        )
        return log_prob, prev_sample_mean, std_dev_t
18 changes: 18 additions & 0 deletions examples/flowgrpo_trainer/diffusers/bagel_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ def head_dim(self) -> int:
def patch_latent_dim(self) -> int:
    """Return ``latent_patch_size`` squared times ``latent_channel``."""
    patch_area = self.latent_patch_size ** 2
    return patch_area * self.latent_channel

def save_pretrained(self, save_directory: str):
    """Write this dataclass config to ``<save_directory>/config.json``.

    The directory is created if missing. Fields are dumped as JSON with
    4-space indentation and sorted keys.
    """
    import dataclasses

    os.makedirs(save_directory, exist_ok=True)
    config_file = os.path.join(save_directory, "config.json")
    with open(config_file, "w") as handle:
        json.dump(dataclasses.asdict(self), handle, indent=4, sort_keys=True)

@classmethod
def from_model_path(cls, model_path: str) -> "BagelTrainingConfig":
cfg_path = os.path.join(model_path, "config.json")
Expand Down Expand Up @@ -751,6 +759,16 @@ def forward(

return (velocity,)

# ------------------------------------------------------------------
# PEFT / LoRA compatibility
# ------------------------------------------------------------------

def add_adapter(self, adapter_config, adapter_name: str = "default"):
    """Inject a PEFT adapter into this module (matches diffusers.ModelMixin API).

    Args:
        adapter_config: PEFT config forwarded to ``inject_adapter_in_model``.
        adapter_name: Name under which the adapter is registered.
    """
    # Imported lazily so peft is only needed when adapters are used.
    import peft

    peft.inject_adapter_in_model(adapter_config, self, adapter_name)

# ------------------------------------------------------------------
# Checkpoint loading
# ------------------------------------------------------------------
Expand Down
32 changes: 21 additions & 11 deletions examples/flowgrpo_trainer/prepare_ocr_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@
import pandas as pd


TEMPLATES = [
"Generate an image that clearly displays the text: '{text}'",
"Create a picture with the word '{text}' written on it",
"Produce an image containing the text '{text}' in a readable font",
"Design an image where the text '{text}' is prominently shown",
"Make an image with '{text}' written clearly in the center",
"Generate a clean image that shows the text: '{text}'",
"Create a visually clear image displaying '{text}'",
"Render an image with the following text: '{text}'",
# System-role message content attached to both the prompt and the negative
# prompt of every generated sample.
SYSTEM_PROMPT = (
"Describe the image by detailing the color, shape, size, "
"texture, quantity, text, spatial relationships of the objects and background:"
)

# User-prompt templates used when the target text contains no space.
# ``{text}`` is filled via ``str.format``.
TEMPLATES_WORD = [
"Create a picture with the word '{text}' written on it.",
]

# User-prompt templates used when the target text contains a space.
# ``{text}`` is filled via ``str.format``.
TEMPLATES_PHRASE = [
"Create a picture with the sentence '{text}' written on it.",
]

# Simple words/phrases of varying difficulty
Expand Down Expand Up @@ -81,12 +83,20 @@ def generate_samples(n: int, seed: int = 42) -> list[dict]:
else:
text = random_alphanum(rng)

template = rng.choice(TEMPLATES)
is_phrase = " " in text
template = rng.choice(TEMPLATES_PHRASE if is_phrase else TEMPLATES_WORD)
prompt_text = template.format(text=text)

sample = {
"data_source": "ocr",
"prompt": [{"role": "user", "content": prompt_text}],
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": prompt_text},
],
"negative_prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": " "},
],
"reward_model": {"style": "rule", "ground_truth": text},
}
samples.append(sample)
Expand Down
15 changes: 13 additions & 2 deletions examples/flowgrpo_trainer/reward_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
import os
from typing import Optional
Expand All @@ -27,14 +28,20 @@
async def chat_complete(router_address: str, chat_complete_request: dict):
    """POST a chat-completion request to the router and parse the reply.

    Args:
        router_address: Host (and port) of the router; the request is sent to
            ``http://{router_address}/v1/chat/completions``.
        chat_complete_request: Request payload, sent as the JSON body.

    Returns:
        A ``ChatCompletion`` built from the JSON response, or ``None`` when the
        response body is empty/blank or when any error occurs. Errors are
        printed rather than raised, so callers must handle ``None``.
    """
    url = f"http://{router_address}/v1/chat/completions"
    try:
        # Bound the whole request so a stalled router cannot hang forever.
        timeout = aiohttp.ClientTimeout(total=120)
        # ``async with`` closes the session on every path, including when
        # session construction itself fails (no unbound name in cleanup).
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=chat_complete_request) as resp:
                output = await resp.text()
        if not output or not output.strip():
            return None
        return ChatCompletion(**json.loads(output))
    except (json.JSONDecodeError, aiohttp.ClientError, asyncio.TimeoutError) as e:
        print(f"[reward_fn] chat_complete failed: {type(e).__name__}: {e}")
        return None
    except Exception as e:
        print(f"[reward_fn] chat_complete unexpected error: {type(e).__name__}: {e}")
        return None

Expand Down Expand Up @@ -127,7 +134,11 @@ async def compute_score_ocr(
router_address=reward_router_address,
chat_complete_request=chat_complete_request,
)
if result is None or not result.choices:
return 0.0
grm_response = result.choices[0].message.content
if not grm_response:
return 0.0

# compute OCR score
text = grm_response
Expand Down
Loading
Loading