diff --git a/verl_omni/agent_loop/diffusion_agent_loop.py b/verl_omni/agent_loop/diffusion_agent_loop.py index f470895a..a69f882a 100644 --- a/verl_omni/agent_loop/diffusion_agent_loop.py +++ b/verl_omni/agent_loop/diffusion_agent_loop.py @@ -235,10 +235,11 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionA padding="max_length", max_length=self.rollout_config.prompt_length, return_tensors="pt", - return_attention_mask=False, + return_attention_mask=True, ) if prompt_output["input_ids"].dim() == 1: prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0) + prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0) response_diffusion_output = output.response_diffusion_output.unsqueeze(0) @@ -247,6 +248,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionA response_logprobs = output.response_logprobs.unsqueeze(0) prompt_ids = prompt_output["input_ids"] + extra_fields["attention_mask"] = prompt_output["attention_mask"] await self._compute_score( output, diff --git a/verl_omni/pipelines/model_base.py b/verl_omni/pipelines/model_base.py index 378459f0..a13e741e 100644 --- a/verl_omni/pipelines/model_base.py +++ b/verl_omni/pipelines/model_base.py @@ -76,6 +76,15 @@ def get_class(cls, model_config: DiffusionModelConfig) -> type["DiffusionModelBa f"Set ``external_lib`` in DiffusionModelConfig to load your implementation." ) from None + @classmethod + def build_module(cls, model_config: DiffusionModelConfig, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]: + """Load the model without ``diffusers.AutoModel``. + + Return ``None`` to use the default ``AutoModel`` path. + Override this for models that diffusers cannot load. 
+ """ + return None + @classmethod @abstractmethod def build_scheduler(cls, model_config: DiffusionModelConfig) -> SchedulerMixin: diff --git a/verl_omni/trainer/config/_generated_diffusion_trainer.yaml b/verl_omni/trainer/config/_generated_diffusion_trainer.yaml index bb92285a..67c757f0 100644 --- a/verl_omni/trainer/config/_generated_diffusion_trainer.yaml +++ b/verl_omni/trainer/config/_generated_diffusion_trainer.yaml @@ -284,6 +284,8 @@ actor_rollout_ref: policy_state_adapters: - default lora_dtype: null + fsdp_layer_prefixes: + - transformer_blocks. mtp: _target_: verl.workers.config.MtpConfig enable: false diff --git a/verl_omni/trainer/config/diffusion/model/diffusion_model.yaml b/verl_omni/trainer/config/diffusion/model/diffusion_model.yaml index 31aaaa9b..f5d2ece1 100644 --- a/verl_omni/trainer/config/diffusion/model/diffusion_model.yaml +++ b/verl_omni/trainer/config/diffusion/model/diffusion_model.yaml @@ -62,6 +62,9 @@ policy_state_adapters: ["default"] # Default null means no conversion. lora_dtype: null +# FSDP layer name prefixes for LoRA parameter layered summon. 
+fsdp_layer_prefixes: ["transformer_blocks."] + # MTP mtp: diff --git a/verl_omni/utils/fsdp_utils.py b/verl_omni/utils/fsdp_utils.py index 5564b148..65be3917 100644 --- a/verl_omni/utils/fsdp_utils.py +++ b/verl_omni/utils/fsdp_utils.py @@ -16,8 +16,9 @@ """ from collections import OrderedDict -from collections.abc import Callable +from collections.abc import Callable, Sequence from contextlib import ExitStack, contextmanager +from functools import partial from peft.utils.save_and_load import get_peft_model_state_dict from verl.utils.fsdp_utils import collect_lora_params as _upstream_collect_lora_params @@ -167,8 +168,17 @@ def _collect_lora_params_with_adapter( return _collect_base_weights_to_cpu(peft_model) -def _layered_summon_lora_params_diffusers(fsdp_module, adapter_name: str = "default") -> OrderedDict: - """Layered LoRA param collection for diffusers transformer-block models.""" +def _layered_summon_lora_params_diffusers( + fsdp_module, adapter_name: str = "default", layer_prefixes: Sequence[str] = ("transformer_blocks.",) +) -> OrderedDict: + """Layered LoRA param collection for diffusers transformer-block models. + + Args: + fsdp_module: The FSDP-wrapped module. + adapter_name: LoRA adapter name. + layer_prefixes: FSDP layer name prefixes. Defaults to + ``["transformer_blocks."]``. 
+ """ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from verl.utils.device import get_torch_device @@ -178,12 +188,12 @@ def _prefix_submodules(module, prefix): yield name, submodule lora_params = OrderedDict() - prefix_list = [ + prefix_list = [] + for lp in layer_prefixes: # FSDP1 - "_fsdp_wrapped_module.transformer_blocks.", + prefix_list.append(f"_fsdp_wrapped_module.{lp}") # FSDP2 - "transformer_blocks.", - ] + prefix_list.append(lp) peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module) for prefix in prefix_list: for name, submodule in _prefix_submodules(fsdp_module, prefix): @@ -211,17 +221,42 @@ def collect_lora_params( base_sync_done: bool, is_diffusers: bool = False, adapter_name: str = "default", + layer_prefixes: Sequence[str] = ("transformer_blocks.",), ) -> OrderedDict: - """Extended version of ``verl.utils.fsdp_utils.collect_lora_params``.""" + """Collect LoRA or base parameters for weight sync to the rollout worker. + + Raises ``RuntimeError`` when no parameters were collected + (e.g. mismatched ``layer_prefixes``). + + Args: + module: The FSDP-wrapped or plain module. + layered_summon: Summon one FSDP unit at a time instead of the full model. + base_sync_done: If ``True``, collect only LoRA weights; else full base weights. + is_diffusers: Use the diffusers-specific layered summon helper. + adapter_name: LoRA adapter name (usually ``"default"``). 
+ layer_prefixes: FSDP layer name prefixes (``["transformer_blocks."]`` by default). + """ use_diffusers_layered = is_diffusers and layered_summon and fsdp_version(module) > 0 if adapter_name == "default" and not use_diffusers_layered and fsdp_version(module) != 2: return _upstream_collect_lora_params(module, layered_summon=layered_summon, base_sync_done=base_sync_done) - layered_summon_fn = _layered_summon_lora_params_diffusers if is_diffusers else _upstream_layered_summon_lora_params - return _collect_lora_params_with_adapter( + if is_diffusers: + layered_summon_fn = partial( + _layered_summon_lora_params_diffusers, adapter_name=adapter_name, layer_prefixes=layer_prefixes + ) + else: + layered_summon_fn = _upstream_layered_summon_lora_params + lora_params = _collect_lora_params_with_adapter( module, layered_summon=layered_summon, base_sync_done=base_sync_done, adapter_name=adapter_name, layered_summon_fn=layered_summon_fn, ) + if not lora_params: + raise RuntimeError( + f"collect_lora_params collected 0 parameters with prefixes={layer_prefixes}. " + "Check ``fsdp_layer_prefixes`` in the model config matches the model's " + "FSDP layer naming (e.g. ``['transformer_blocks.']`` for DiT models)." 
+ ) + return lora_params diff --git a/verl_omni/workers/config/diffusion/model.py b/verl_omni/workers/config/diffusion/model.py index d405876a..a85c5404 100644 --- a/verl_omni/workers/config/diffusion/model.py +++ b/verl_omni/workers/config/diffusion/model.py @@ -95,6 +95,8 @@ class DiffusionModelConfig(BaseConfig): algo: Optional[DiffusionRolloutAlgoConfig] = field(default_factory=DiffusionRolloutAlgoConfig) + fsdp_layer_prefixes: list[str] = field(default_factory=lambda: ["transformer_blocks."]) + def __post_init__(self): import_external_libs(self.external_lib) diff --git a/verl_omni/workers/engine/fsdp/diffusers_impl.py b/verl_omni/workers/engine/fsdp/diffusers_impl.py index 50b20991..139efa14 100644 --- a/verl_omni/workers/engine/fsdp/diffusers_impl.py +++ b/verl_omni/workers/engine/fsdp/diffusers_impl.py @@ -189,6 +189,52 @@ def _init_device_mesh(self): self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 + def _build_module_from_registry(self, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]: + """Try loading via ``DiffusionModelBase.build_module()``. + + Returns ``None`` if the registry has no custom loader, so the + caller falls back to ``diffusers.AutoModel``. + """ + from verl_omni.pipelines.model_base import DiffusionModelBase + + model_cls = DiffusionModelBase.get_class(self.model_config) + module = model_cls.build_module(self.model_config, torch_dtype) + if module is None: + return None + + logger.warning( + "Built %s via DiffusionModelBase custom loader; engine-level hooks " + "(attention processors, gradient-checkpointing wrappers, LoRA, " + "dtype upcast) may be partially effective or silently inactive. " + "See the docstring of _build_module_from_registry.", + type(module).__name__, + ) + + try: + module.to(torch_dtype) + except AttributeError: + raise TypeError( + f"{type(module).__name__} returned by build_module() has no to() method. " + "Custom models must be torch.nn.Module instances." 
+ ) from None + + if self.model_config.enable_gradient_checkpointing: + try: + module.enable_gradient_checkpointing() + except AttributeError: + raise NotImplementedError( + f"Gradient checkpointing is enabled in config, but {type(module).__name__} " + "does not implement enable_gradient_checkpointing(). " + "Either implement it or set enable_gradient_checkpointing=False." + ) from None + logger.info( + "Gradient checkpointing enabled on %s via enable_gradient_checkpointing().", + type(module).__name__, + ) + + module.can_generate = lambda: False + return module + def _build_module(self): from diffusers import AutoModel from verl.utils.torch_dtypes import PrecisionType @@ -201,6 +247,11 @@ def _build_module(self): torch_dtype = PrecisionType.to_dtype(torch_dtype) + module = self._build_module_from_registry(torch_dtype) + if module is not None: + return module + + # Default path: load via diffusers AutoModel init_context = get_init_weight_context_manager(use_meta_tensor=True, mesh=self.device_mesh) with init_context(), warnings.catch_warnings(): @@ -668,6 +719,7 @@ def get_per_tensor_param( base_sync_done=base_sync_done, is_diffusers=True, adapter_name=adapter_name or "default", + layer_prefixes=self.model_config.fsdp_layer_prefixes, ) if not base_sync_done: params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()} @@ -760,16 +812,16 @@ def prepare_model_inputs(self, micro_batch: TensorDict, step: int): """ latents = micro_batch["all_latents"] timesteps = micro_batch["all_timesteps"] - prompt_embeds = micro_batch["prompt_embeds"] - prompt_embeds_mask = micro_batch["prompt_embeds_mask"] - negative_prompt_embeds = micro_batch["negative_prompt_embeds"] - negative_prompt_embeds_mask = micro_batch["negative_prompt_embeds_mask"] + prompt_embeds = micro_batch.get("prompt_embeds", None) + prompt_embeds_mask = micro_batch.get("prompt_embeds_mask", None) + negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None) + 
negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None) sp_size = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1 - if prompt_embeds.is_nested: + if isinstance(prompt_embeds, torch.Tensor) and prompt_embeds.is_nested: prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask) - if sp_size > 1: + if isinstance(prompt_embeds, torch.Tensor) and sp_size > 1: prompt_embeds, prompt_embeds_mask = self._pad_embeds_for_sp(prompt_embeds, prompt_embeds_mask, sp_size) if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested: