diff --git a/examples/flowgrpo_trainer/diffusers/bagel.py b/examples/flowgrpo_trainer/diffusers/bagel.py new file mode 100644 index 00000000000..4b97ebd3c47 --- /dev/null +++ b/examples/flowgrpo_trainer/diffusers/bagel.py @@ -0,0 +1,196 @@ +"""BAGEL (MoT) diffusion model implementation for FlowGRPO training. + +Registers as ``OmniBagelForConditionalGeneration`` so the FSDP engine +can load and train the model via the DiffusionModelBase registry. + +Key differences from standard diffusion models (e.g. QwenImage): + * BAGEL is a *Mixture-of-Thought* transformer that processes text token + IDs and noisy latent patches in a single forward pass (no separate + text encoder). + * ``prompt_embeds`` are not used. Instead, the raw prompt token IDs + (available as ``micro_batch["prompts"]``) are passed directly to the + model as ``text_token_ids``. + * CFG uses a 3-branch scheme during rollout, but for FSDP training + (computing log-probs of the rollout trajectory) only the conditional + forward is needed. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +import numpy as np +import torch +from tensordict import TensorDict + +from verl.models.diffusers_model import DiffusionModelBase +from verl.utils import tensordict_utils as tu +from verl.utils.device import get_device_name +from verl.workers.config import DiffusionModelConfig + +from ..scheduler import FlowMatchSDEDiscreteScheduler +from .bagel_model import BagelForTraining, get_flattened_position_ids + +logger = logging.getLogger(__name__) + +TIMESTEP_SHIFT = 3.0 # must match BagelPipeline.forward() hardcoded value + + +@DiffusionModelBase.register("OmniBagelForConditionalGeneration") +class BagelDiffusion(DiffusionModelBase): + """DiffusionModelBase wrapper for BagelForTraining (MoT).""" + + # ------------------------------------------------------------------ + # Custom model loading (BAGEL can't be loaded via diffusers.AutoModel) + # ------------------------------------------------------------------ + + @classmethod + def build_module(cls, model_config: DiffusionModelConfig, torch_dtype: torch.dtype): + logger.info("Loading BagelForTraining from %s", model_config.local_path) + module = BagelForTraining.from_pretrained( + model_config.local_path, torch_dtype=torch_dtype + ) + return module + + # ------------------------------------------------------------------ + # Scheduler + # ------------------------------------------------------------------ + + @classmethod + def build_scheduler(cls, model_config: DiffusionModelConfig): + scheduler = FlowMatchSDEDiscreteScheduler() + cls.set_timesteps(scheduler, model_config, get_device_name()) + return scheduler + + @classmethod + def set_timesteps(cls, scheduler: FlowMatchSDEDiscreteScheduler, model_config: DiffusionModelConfig, device: str): + num_inference_steps = model_config.num_inference_steps + t = np.linspace(1, 0, num_inference_steps) + t_shifted = TIMESTEP_SHIFT * t / (1 + (TIMESTEP_SHIFT - 1) * t) + sigmas = t_shifted[:-1].tolist() + + scheduler.set_shift(1.0) # identity — sigmas already shifted + scheduler.set_timesteps(sigmas=sigmas) + scheduler.set_begin_index(0) + + # ------------------------------------------------------------------ + # Prepare model inputs + # ------------------------------------------------------------------ + + @classmethod + def _get_latent_pos_ids(cls, model_config: DiffusionModelConfig, module, device) -> torch.Tensor: + """Compute latent position IDs from model config / image dimensions.""" + config = module.config + img_h = model_config.height // (config.latent_patch_size * config.vae_downsample) + img_w = model_config.width // (config.latent_patch_size * config.vae_downsample) + # Clamp to max_latent_size + img_h = min(img_h, config.max_latent_size) + img_w = min(img_w, config.max_latent_size) + latent_ds = config.latent_patch_size * config.vae_downsample + H_px = img_h * latent_ds + W_px = img_w * latent_ds + pos_ids = get_flattened_position_ids( + H_px, W_px, latent_ds, config.max_latent_size, + ) + return pos_ids.to(device) + + @classmethod + def prepare_model_inputs( + cls, + module, + model_config: DiffusionModelConfig, + latents: torch.Tensor, + timesteps: torch.Tensor, + prompt_embeds: torch.Tensor, + prompt_embeds_mask: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + negative_prompt_embeds_mask: torch.Tensor, + micro_batch: TensorDict, + step: int, + ) -> tuple[dict, dict]: + B = latents.shape[0] + device = latents.device + + hidden_states = latents[:, step] + timestep = timesteps[:, step] + + # Extract text token IDs from prompt data + prompts = micro_batch["prompts"] # (B, L_prompt) padded + attention_mask = micro_batch["attention_mask"] # (B, L_prompt) + + # Build per-sample text_token_ids (remove padding) + text_token_ids_list = [] + for i in range(B): + mask = attention_mask[i].bool() + ids = prompts[i][mask] + text_token_ids_list.append(ids) + + # Pad to same length within batch + max_text_len = max(ids.shape[0] for ids in text_token_ids_list) + text_token_ids = torch.zeros(B, max_text_len, dtype=torch.long, device=device) + for i, ids in enumerate(text_token_ids_list): + text_token_ids[i, :ids.shape[0]] = ids + + # Compute latent position IDs + latent_pos_ids = cls._get_latent_pos_ids(model_config, module, device) + latent_pos_ids = latent_pos_ids.unsqueeze(0).expand(B, -1) + + model_inputs = { + "hidden_states": hidden_states, + "timestep": timestep, + "text_token_ids": text_token_ids, + "latent_pos_ids": latent_pos_ids, + } + + # For BAGEL, unconditional pass uses text_token_ids=None + negative_model_inputs = { + "hidden_states": hidden_states, + "timestep": timestep, + "text_token_ids": None, + "latent_pos_ids": latent_pos_ids, + } + + return model_inputs, negative_model_inputs + + # ------------------------------------------------------------------ + # Forward + scheduler step + # ------------------------------------------------------------------ + + @classmethod + def forward_and_sample_previous_step( + cls, + module, + scheduler: FlowMatchSDEDiscreteScheduler, + model_config: DiffusionModelConfig, + model_inputs: dict[str, torch.Tensor], + negative_model_inputs: Optional[dict[str, torch.Tensor]], + scheduler_inputs: Optional[TensorDict | dict[str, torch.Tensor]], + step: int, + ): + assert scheduler_inputs is not None + latents = scheduler_inputs["all_latents"] + timesteps = scheduler_inputs["all_timesteps"] + + noise_pred = module(**model_inputs)[0] + + # CFG during training (if configured) + true_cfg_scale = model_config.extra_configs.get("true_cfg_scale", 1.0) + if true_cfg_scale > 1.0: + assert negative_model_inputs is not None + neg_noise_pred = module(**negative_model_inputs)[0] + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + _, log_prob, prev_sample_mean, std_dev_t = scheduler.sample_previous_step( + sample=latents[:, step].float(), + model_output=noise_pred.float(), + timestep=timesteps[:, step], + noise_level=model_config.extra_configs.get("noise_level", None), + prev_sample=latents[:, step + 1].float(), + sde_type=model_config.extra_configs.get("sde_type", None), + return_logprobs=True, + ) + return log_prob, prev_sample_mean, std_dev_t diff --git a/examples/flowgrpo_trainer/diffusers/bagel_model.py b/examples/flowgrpo_trainer/diffusers/bagel_model.py index c8e2ee70aeb..e8aa3b63227 100644 --- a/examples/flowgrpo_trainer/diffusers/bagel_model.py +++ b/examples/flowgrpo_trainer/diffusers/bagel_model.py @@ -58,6 +58,14 @@ def head_dim(self) -> int: def patch_latent_dim(self) -> int: return self.latent_patch_size ** 2 * self.latent_channel + def save_pretrained(self, save_directory: str): + """Save config as JSON (compatible with diffusers checkpoint manager).""" + from dataclasses import asdict + output_path = os.path.join(save_directory, "config.json") + os.makedirs(save_directory, exist_ok=True) + with open(output_path, "w") as f: + json.dump(asdict(self), f, indent=4, sort_keys=True) + @classmethod def from_model_path(cls, model_path: str) -> "BagelTrainingConfig": cfg_path = os.path.join(model_path, "config.json") @@ -751,6 +759,16 @@ def forward( return (velocity,) + # ------------------------------------------------------------------ + # PEFT / LoRA compatibility + # ------------------------------------------------------------------ + + def add_adapter(self, adapter_config, adapter_name: str = "default"): + """Add a PEFT LoRA adapter (matches diffusers.ModelMixin API).""" + from peft import inject_adapter_in_model + + inject_adapter_in_model(adapter_config, self, adapter_name) + # ------------------------------------------------------------------ # Checkpoint loading # ------------------------------------------------------------------ diff --git a/examples/flowgrpo_trainer/prepare_ocr_data.py b/examples/flowgrpo_trainer/prepare_ocr_data.py index e129fdec203..60a8a1369b1 100644 --- a/examples/flowgrpo_trainer/prepare_ocr_data.py +++ b/examples/flowgrpo_trainer/prepare_ocr_data.py @@ -19,15 +19,17 @@ import pandas as pd -TEMPLATES = [ - "Generate an image that clearly displays the text: '{text}'", - "Create a picture with the word '{text}' written on it", - "Produce an image containing the text '{text}' in a readable font", - "Design an image where the text '{text}' is prominently shown", - "Make an image with '{text}' written clearly in the center", - "Generate a clean image that shows the text: '{text}'", - "Create a visually clear image displaying '{text}'", - "Render an image with the following text: '{text}'", +SYSTEM_PROMPT = ( + "Describe the image by detailing the color, shape, size, " + "texture, quantity, text, spatial relationships of the objects and background:" +) + +TEMPLATES_WORD = [ + "Create a picture with the word '{text}' written on it.", +] + +TEMPLATES_PHRASE = [ + "Create a picture with the sentence '{text}' written on it.", ] # Simple words/phrases of varying difficulty @@ -81,12 +83,20 @@ def generate_samples(n: int, seed: int = 42) -> list[dict]: else: text = random_alphanum(rng) - template = rng.choice(TEMPLATES) + is_phrase = " " in text + template = rng.choice(TEMPLATES_PHRASE if is_phrase else TEMPLATES_WORD) prompt_text = template.format(text=text) sample = { "data_source": "ocr", - "prompt": [{"role": "user", "content": prompt_text}], + "prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt_text}, + ], + "negative_prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": " "}, + ], "reward_model": {"style": "rule", "ground_truth": text}, } samples.append(sample) diff --git a/examples/flowgrpo_trainer/reward_fn.py b/examples/flowgrpo_trainer/reward_fn.py index a9f6bcf4a4a..ab3b1d720f0 100644 --- a/examples/flowgrpo_trainer/reward_fn.py +++ b/examples/flowgrpo_trainer/reward_fn.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import json import os from typing import Optional @@ -27,14 +28,20 @@ async def chat_complete(router_address: str, chat_complete_request: dict): url = f"http://{router_address}/v1/chat/completions" try: - timeout = aiohttp.ClientTimeout(total=None) + timeout = aiohttp.ClientTimeout(total=120) session = aiohttp.ClientSession(timeout=timeout) async with session.post(url, json=chat_complete_request) as resp: output = await resp.text() + if not output or not output.strip(): + return None output = json.loads(output) return ChatCompletion(**output) + except (json.JSONDecodeError, aiohttp.ClientError, asyncio.TimeoutError) as e: + print(f"[reward_fn] chat_complete failed: {type(e).__name__}: {e}") + return None except Exception as e: - raise e + print(f"[reward_fn] chat_complete unexpected error: {type(e).__name__}: {e}") + return None finally: await session.close() @@ -127,7 +134,11 @@ async def compute_score_ocr( router_address=reward_router_address, chat_complete_request=chat_complete_request, ) + if result is None or not result.choices: + return 0.0 grm_response = result.choices[0].message.content + if not grm_response: + return 0.0 # compute OCR score text = grm_response diff --git a/examples/flowgrpo_trainer/run_bagel_flowgrpo.sh b/examples/flowgrpo_trainer/run_bagel_flowgrpo.sh index 18f7e9be6d2..e9c586ad1c7 100644 --- a/examples/flowgrpo_trainer/run_bagel_flowgrpo.sh +++ b/examples/flowgrpo_trainer/run_bagel_flowgrpo.sh @@ -1,18 +1,17 @@ # Bagel LoRA RL, vllm_omni rollout (FlowGRPO) # # Prerequisites: -# 1. A Bagel model (e.g. BAGEL-8B-MoT) at $BAGEL_MODEL_PATH -# 2. A stage config JSON at $BAGEL_STAGE_CONFIG that describes multi-GPU -# placement (thinker on GPU 0, DiT on GPU 1, etc.) -# 3. A DiffusionModelBase implementation registered for Bagel's architecture -# at examples/flowgrpo_trainer/diffusers/bagel.py (see qwen_image.py for reference) +# 1. A Bagel model (e.g. BAGEL-7B-MoT) at $BAGEL_MODEL_PATH +# 2. A stage config YAML at $BAGEL_STAGE_CONFIG for vllm-omni +# 3. DiffusionModelBase registered as "OmniBagelForConditionalGeneration" +# at examples/flowgrpo_trainer/diffusers/bagel.py # 4. A reward VLM model at $REWARD_MODEL_PATH +# 5. OCR training data at $OCR_TRAIN_PATH / $OCR_TEST_PATH +# (generate via: python examples/flowgrpo_trainer/prepare_ocr_data.py) # # Usage: -# # Minimal (set paths via env vars): -# export BAGEL_MODEL_PATH=$HOME/models/BAGEL-8B-MoT -# export BAGEL_STAGE_CONFIG=$HOME/models/BAGEL-8B-MoT/stage_configs.json -# export REWARD_MODEL_PATH=$HOME/models/Qwen/Qwen3-VL-8B-Instruct +# export BAGEL_MODEL_PATH=/path/to/BAGEL-7B-MoT +# export REWARD_MODEL_PATH=/path/to/Qwen3-VL-8B-Instruct # bash examples/flowgrpo_trainer/run_bagel_flowgrpo.sh # # # Override any param via CLI: @@ -25,8 +24,7 @@ BAGEL_MODEL_PATH=${BAGEL_MODEL_PATH:-$HOME/models/BAGEL-7B-MoT} SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BAGEL_STAGE_CONFIG=${BAGEL_STAGE_CONFIG:-$SCRIPT_DIR/bagel_stage_config.yaml} - -REWARD_MODEL_PATH=${REWARD_MODEL_PATH:-$HOME/models/Qwen/Qwen3-VL-8B-Instruct} +REWARD_MODEL_PATH=${REWARD_MODEL_PATH:-$HOME/models/Qwen3-VL-8B-Instruct} ocr_train_path=${OCR_TRAIN_PATH:-$HOME/data/ocr/train.parquet} ocr_test_path=${OCR_TEST_PATH:-$HOME/data/ocr/test.parquet} @@ -40,7 +38,7 @@ python3 -m verl.trainer.main_flowgrpo \ algorithm.adv_estimator=flow_grpo \ data.train_files=$ocr_train_path \ data.val_files=$ocr_test_path \ - data.train_batch_size=32 \ + data.train_batch_size=16 \ data.max_prompt_length=256 \ data.trust_remote_code=True \ actor_rollout_ref.model.path=$BAGEL_MODEL_PATH \ @@ -48,26 +46,30 @@ python3 -m verl.trainer.main_flowgrpo \ +actor_rollout_ref.model.architecture=OmniBagelForConditionalGeneration \ actor_rollout_ref.model.trust_remote_code=True \ actor_rollout_ref.model.external_lib="examples.flowgrpo_trainer.diffusers.bagel" \ + actor_rollout_ref.model.height=512 \ + actor_rollout_ref.model.width=512 \ + actor_rollout_ref.model.num_inference_steps=15 \ actor_rollout_ref.model.lora_rank=64 \ actor_rollout_ref.model.lora_alpha=128 \ actor_rollout_ref.model.target_modules="['q_proj_moe_gen','k_proj_moe_gen','v_proj_moe_gen','o_proj_moe_gen']" \ - actor_rollout_ref.actor.optim.lr=3e-4 \ + actor_rollout_ref.actor.optim.lr=1e-3 \ actor_rollout_ref.actor.optim.weight_decay=0.0001 \ actor_rollout_ref.actor.ppo_mini_batch_size=16 \ - actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ - actor_rollout_ref.actor.fsdp_config.param_offload=True \ - actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \ actor_rollout_ref.actor.policy_loss.loss_mode=flow_grpo \ - actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ actor_rollout_ref.rollout.name=$ENGINE \ - actor_rollout_ref.rollout.n=16 \ + actor_rollout_ref.rollout.n=8 \ actor_rollout_ref.rollout.agent.default_agent_loop=diffusion_single_turn_agent \ - actor_rollout_ref.rollout.agent.num_workers=4 \ + actor_rollout_ref.rollout.agent.num_workers=2 \ actor_rollout_ref.rollout.load_format=auto \ actor_rollout_ref.rollout.layered_summon=True \ - actor_rollout_ref.rollout.val_kwargs.num_inference_steps=50 \ + actor_rollout_ref.rollout.num_inference_steps=15 \ + actor_rollout_ref.rollout.val_kwargs.num_inference_steps=15 \ +actor_rollout_ref.rollout.extra_configs.noise_level=1.2 \ +actor_rollout_ref.rollout.extra_configs.sde_type="sde" \ +actor_rollout_ref.rollout.extra_configs.sde_window_size=2 \ @@ -76,10 +78,10 @@ python3 -m verl.trainer.main_flowgrpo \ +actor_rollout_ref.rollout.val_kwargs.extra_configs.noise_level=0.0 \ +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.custom_pipeline=examples.flowgrpo_trainer.vllm_omni.pipeline_bagel.BagelPipelineWithLogProb \ +actor_rollout_ref.rollout.engine_kwargs.vllm_omni.stage_configs_path=$BAGEL_STAGE_CONFIG \ - actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ - reward.num_workers=4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + reward.num_workers=1 \ reward.reward_manager.name=visual \ - reward.reward_model.enable=False \ + reward.reward_model.enable=True \ reward.reward_model.model_path=$REWARD_MODEL_PATH \ reward.reward_model.rollout.name=$REWARD_ENGINE \ reward.reward_model.rollout.tensor_model_parallel_size=4 \ @@ -89,11 +91,11 @@ python3 -m verl.trainer.main_flowgrpo \ trainer.logger='["console", "wandb"]' \ trainer.project_name=flow_grpo \ trainer.experiment_name=bagel_ocr_lora \ - trainer.log_val_generations=8 \ + trainer.log_val_generations=4 \ trainer.val_before_train=False \ - trainer.n_gpus_per_node=4 \ + trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ - trainer.save_freq=30 \ - trainer.test_freq=30 \ - trainer.total_epochs=15 \ - trainer.total_training_steps=300 $@ + trainer.save_freq=10 \ + trainer.test_freq=10 \ + trainer.total_epochs=5 \ + trainer.total_training_steps=50 $@ diff --git a/examples/flowgrpo_trainer/scheduler/scheduling_flow_match_sde_discrete.py b/examples/flowgrpo_trainer/scheduler/scheduling_flow_match_sde_discrete.py index 3f0764e52ae..cbd4f68db75 100644 --- a/examples/flowgrpo_trainer/scheduler/scheduling_flow_match_sde_discrete.py +++ b/examples/flowgrpo_trainer/scheduler/scheduling_flow_match_sde_discrete.py @@ -51,6 +51,14 @@ class FlowMatchSDEDiscreteScheduler(FlowMatchEulerDiscreteScheduler): and diffusers v0.37 branch. """ + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + # Use nearest-neighbor matching to avoid float32/float64 precision mismatch + # between timesteps stored during rollout and scheduler's internal timesteps. + diffs = (schedule_timesteps - timestep).abs() + return diffs.argmin().item() + def step( self, model_output: torch.FloatTensor, @@ -169,11 +177,14 @@ def sample_previous_step( sigma = self.sigmas[sigma_idx] sigma_prev = self.sigmas[sigma_idx + 1] else: - sigma_idx = torch.tensor([self.index_for_timestep(t) for t in timestep]) + sigma_idx = torch.tensor([self.index_for_timestep(t.cpu()) for t in timestep]) sigma = self.sigmas[sigma_idx].view(-1, *([1] * (len(sample.shape) - 1))) sigma_prev = self.sigmas[sigma_idx + 1].view(-1, *([1] * (len(sample.shape) - 1))) - sigma_max = self.sigmas[1] + # Move scheduler tensors to the same device as sample + sigma = sigma.to(device=sample.device, dtype=sample.dtype) + sigma_prev = sigma_prev.to(device=sample.device, dtype=sample.dtype) + sigma_max = self.sigmas[1].to(device=sample.device, dtype=sample.dtype) dt = sigma_prev - sigma if sde_type == "sde": diff --git a/examples/flowgrpo_trainer/vllm_omni/pipeline_bagel.py b/examples/flowgrpo_trainer/vllm_omni/pipeline_bagel.py index 7184f46949d..a4f49764f37 100644 --- a/examples/flowgrpo_trainer/vllm_omni/pipeline_bagel.py +++ b/examples/flowgrpo_trainer/vllm_omni/pipeline_bagel.py @@ -45,11 +45,13 @@ logger = logging.getLogger(__name__) -def _maybe_to_cpu(v): +def _to_cpu_tensor(v): + """Convert to a single CPU tensor, stacking a list of tensors if needed.""" if isinstance(v, torch.Tensor): return v.detach().cpu() if isinstance(v, list): - return [_maybe_to_cpu(x) for x in v] + tensors = [x.detach().cpu() if isinstance(x, torch.Tensor) else torch.tensor(x) for x in v] + return torch.stack(tensors) if tensors else None return v @@ -134,14 +136,14 @@ def forward(self, req: OmniDiffusionRequest) -> DiffusionOutput: output = super().forward(req) - # Enrich custom_output with RL-specific fields + # Enrich custom_output with RL-specific fields (must be tensors for batch stacking) custom = output.custom_output or {} if output.trajectory_latents is not None: - custom["all_latents"] = _maybe_to_cpu(output.trajectory_latents) + custom["all_latents"] = _to_cpu_tensor(output.trajectory_latents) if output.trajectory_timesteps is not None: - custom["all_timesteps"] = _maybe_to_cpu(output.trajectory_timesteps) + custom["all_timesteps"] = _to_cpu_tensor(output.trajectory_timesteps) if output.trajectory_log_probs is not None: - custom["all_log_probs"] = _maybe_to_cpu(output.trajectory_log_probs) + custom["all_log_probs"] = _to_cpu_tensor(output.trajectory_log_probs) output.custom_output = custom return output diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 6b9503ac181..f33b1615777 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -165,7 +165,7 @@ async def generate( try: output: TokenOutput | DiffusionOutput = await server.generate.remote( request_id=uuid4().hex, # use new request_id for each turn - prompt_ids=prompt_ids, + prompt_token_ids=prompt_ids, sampling_params=sampling_params, image_data=image_data, video_data=video_data, diff --git a/verl/models/diffusers_model/base.py b/verl/models/diffusers_model/base.py index 33a7bcfcc44..63ab9857034 100644 --- a/verl/models/diffusers_model/base.py +++ b/verl/models/diffusers_model/base.py @@ -85,6 +85,16 @@ def get_class(cls, model_config: DiffusionModelConfig) -> type["DiffusionModelBa f"Set ``external_lib`` in DiffusionModelConfig to load your implementation." ) from None + @classmethod + def build_module(cls, model_config: DiffusionModelConfig, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]: + """Optional hook for custom model loading. + + Override this to load non-standard models (e.g. models not loadable + via ``diffusers.AutoModel``). Return ``None`` to fall back to the + default ``AutoModel.from_pretrained`` path in the FSDP engine. + """ + return None + @classmethod @abstractmethod def build_scheduler(cls, model_config: DiffusionModelConfig) -> SchedulerMixin: diff --git a/verl/trainer/config/diffusion_trainer.yaml b/verl/trainer/config/diffusion_trainer.yaml index 656a8172ecd..d480c6c5498 100644 --- a/verl/trainer/config/diffusion_trainer.yaml +++ b/verl/trainer/config/diffusion_trainer.yaml @@ -177,6 +177,89 @@ trainer: # mode: "auto", "enable", or "disable" use_legacy_worker_impl: disable +# profiler configs +global_profiler: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.ProfilerConfig + + # Profiling tool: choose between nsys, npu, torch, torch_memory, precision_debugger + tool: null + + # profile steps + steps: null + + # Whether to combine continuous steps into one database. + profile_continuous_steps: False + + # Path to save profiling contents + save_path: "outputs/profile" + + # Specific tool configs + global_tool_config: + + # nsys config + nsys: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.NsightToolConfig + + # True for each task has its own database, False for all tasks in one training step share one database. + discrete: False + + # controller Nvidia Nsight Systems Options. + controller_nsight_options: + + trace: "cuda,nvtx,cublas,ucx" + + cuda-memory-usage: "true" + + cuda-graph-trace: "graph" + + # worker Nvidia Nsight Systems Options. + worker_nsight_options: + + trace: "cuda,nvtx,cublas,ucx" + + cuda-memory-usage: "true" + + cuda-graph-trace: "graph" + + capture-range: "cudaProfilerApi" + + capture-range-end: null + + kill: none + + # enable memory visualization for debugging memory usage + torch_memory: + + _target_: verl.utils.profiler.config.TorchMemoryToolConfig + + trace_alloc_max_entries: 100_000 + + stack_depth: 32 + + context: "all" + + stacks: "all" + + kw_args: {} + + # msprobe precision debugger + precision_debugger: + + # Required when using verl.utils.omega_conf_to_dataclass to instantiate dataclass configs + _target_: verl.utils.profiler.config.PrecisionDebuggerToolConfig + + config_path: null + + steps: null + + stages: null + + strict: False + # configs related to ray ray_kwargs: diff --git a/verl/trainer/diffusion/ray_diffusion_trainer.py b/verl/trainer/diffusion/ray_diffusion_trainer.py index fe0be9d15c1..dd505060bc1 100644 --- a/verl/trainer/diffusion/ray_diffusion_trainer.py +++ b/verl/trainer/diffusion/ray_diffusion_trainer.py @@ -588,6 +588,16 @@ def init_workers(self): wg_kwargs = {} # Setting up kwargs for RayWorkerGroup if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None: wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout + if OmegaConf.select(self.config, "global_profiler.steps") is not None: + wg_kwargs["profile_steps"] = OmegaConf.select(self.config, "global_profiler.steps") + if OmegaConf.select(self.config, "global_profiler.tool") == "nsys": + assert ( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + is not None + ), "worker_nsight_options must be set when using nsys with profile_steps" + wg_kwargs["worker_nsight_options"] = OmegaConf.to_container( + OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options") + ) wg_kwargs["device_name"] = self.device_name for resource_pool, class_dict in self.resource_pool_to_cls.items(): @@ -666,6 +676,20 @@ def init_workers(self): # sleep all replicas to load checkpoint self.checkpoint_manager.sleep_replicas() + def _start_profiling(self, do_profile: bool) -> None: + """Start profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.start_profile(role="e2e", profile_step=self.global_steps) + if self.use_reference_policy: + self.ref_policy_wg.start_profile(profile_step=self.global_steps) + + def _stop_profiling(self, do_profile: bool) -> None: + """Stop profiling for all worker groups if profiling is enabled.""" + if do_profile: + self.actor_rollout_wg.stop_profile() + if self.use_reference_policy: + self.ref_policy_wg.stop_profile() + def _save_checkpoint(self): from verl.utils.fs import local_mkdir_safe @@ -879,6 +903,7 @@ def fit(self): self.global_steps += 1 last_val_metrics = None self.max_steps_duration = 0 + prev_step_profile = False for epoch in range(current_epoch, self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: @@ -903,7 +928,21 @@ def fit(self): ) is_last_step = self.global_steps >= self.total_training_steps + + profile_steps = ( + self.config.global_profiler.steps + if OmegaConf.select(self.config, "global_profiler.steps") is not None + else None + ) + curr_step_profile = self.global_steps in profile_steps if profile_steps else False + profile_continuous = OmegaConf.select(self.config, "global_profiler.profile_continuous_steps") or False + with marked_timer("step", timing_raw): + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile if profile_continuous else curr_step_profile + ) + # generate a batch with marked_timer("gen", timing_raw, color="red"): gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output) @@ -1016,6 +1055,15 @@ def fit(self): with marked_timer("update_weights", timing_raw, color="red"): self.checkpoint_manager.update_weights(self.global_steps) + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in profile_steps if profile_steps else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile if profile_continuous else curr_step_profile + ) + prev_step_profile = curr_step_profile + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) diff --git a/verl/workers/engine/fsdp/diffusers_impl.py b/verl/workers/engine/fsdp/diffusers_impl.py index 47dcebeefb8..a1eb5fad8f0 100644 --- a/verl/workers/engine/fsdp/diffusers_impl.py +++ b/verl/workers/engine/fsdp/diffusers_impl.py @@ -163,8 +163,6 @@ def _init_device_mesh(self): raise NotImplementedError("Ulysses sequence parallel for Diffusers backend is not supported currently.") def _build_module(self): - from diffusers import AutoModel - from verl.utils.torch_dtypes import PrecisionType torch_dtype = self.engine_config.model_dtype @@ -175,6 +173,21 @@ def _build_module(self): torch_dtype = PrecisionType.to_dtype(torch_dtype) + # Allow registered DiffusionModelBase subclass to provide custom loading + from verl.models.diffusers_model import DiffusionModelBase + + model_cls = DiffusionModelBase.get_class(self.model_config) + module = model_cls.build_module(self.model_config, torch_dtype) + + if module is not None: + module.to(torch_dtype) + if not hasattr(module, "can_generate"): + module.can_generate = lambda: False + return module + + # Default path: load via diffusers AutoModel + from diffusers import AutoModel + init_context = get_init_weight_context_manager(use_meta_tensor=True, mesh=self.device_mesh) with init_context(), warnings.catch_warnings(): @@ -529,12 +542,12 @@ def prepare_model_inputs(self, micro_batch: TensorDict, step: int): """ latents = micro_batch["all_latents"] timesteps = micro_batch["all_timesteps"] - prompt_embeds = micro_batch["prompt_embeds"] - prompt_embeds_mask = micro_batch["prompt_embeds_mask"] - negative_prompt_embeds = micro_batch["negative_prompt_embeds"] - negative_prompt_embeds_mask = micro_batch["negative_prompt_embeds_mask"] + prompt_embeds = micro_batch.get("prompt_embeds", None) + prompt_embeds_mask = micro_batch.get("prompt_embeds_mask", None) + negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None) + negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None) - if prompt_embeds.is_nested: + if isinstance(prompt_embeds, torch.Tensor) and prompt_embeds.is_nested: prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask) if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested: diff --git a/verl/workers/utils/padding.py b/verl/workers/utils/padding.py index 73ad94fec59..ca62bf980e9 100644 --- a/verl/workers/utils/padding.py +++ b/verl/workers/utils/padding.py @@ -173,7 +173,8 @@ def _to_nested(embeds: torch.Tensor, mask: torch.Tensor): torch.nested.as_nested_tensor(mask_list, layout=torch.jagged), ) - data["prompt_embeds"], data["prompt_embeds_mask"] = _to_nested(data["prompt_embeds"], data["prompt_embeds_mask"]) + if isinstance(data.get("prompt_embeds", None), torch.Tensor): + data["prompt_embeds"], data["prompt_embeds_mask"] = _to_nested(data["prompt_embeds"], data["prompt_embeds_mask"]) if isinstance(data.get("negative_prompt_embeds", None), torch.Tensor): data["negative_prompt_embeds"], data["negative_prompt_embeds_mask"] = _to_nested(