Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion verl_omni/agent_loop/diffusion_agent_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,11 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionA
padding="max_length",
max_length=self.rollout_config.prompt_length,
return_tensors="pt",
return_attention_mask=False,
return_attention_mask=True,
)
if prompt_output["input_ids"].dim() == 1:
prompt_output["input_ids"] = prompt_output["input_ids"].unsqueeze(0)
prompt_output["attention_mask"] = prompt_output["attention_mask"].unsqueeze(0)

response_diffusion_output = output.response_diffusion_output.unsqueeze(0)

Expand All @@ -247,6 +248,7 @@ async def _agent_loop_postprocess(self, output, **kwargs) -> _InternalDiffusionA
response_logprobs = output.response_logprobs.unsqueeze(0)

prompt_ids = prompt_output["input_ids"]
extra_fields["attention_mask"] = prompt_output["attention_mask"]

await self._compute_score(
output,
Expand Down
9 changes: 9 additions & 0 deletions verl_omni/pipelines/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ def get_class(cls, model_config: DiffusionModelConfig) -> type["DiffusionModelBa
f"Set ``external_lib`` in DiffusionModelConfig to load your implementation."
) from None

@classmethod
def build_module(cls, model_config: DiffusionModelConfig, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]:
    """Optional hook for loading the model without ``diffusers.AutoModel``.

    The base implementation ignores its arguments and returns ``None``,
    which selects the default ``AutoModel`` loading path. Override this in
    a subclass for models that diffusers cannot load.

    Args:
        model_config: Diffusion model configuration describing what to load.
        torch_dtype: Dtype requested for the loaded module.

    Returns:
        A custom-built module, or ``None`` to use the default ``AutoModel``
        path.
    """
    # No custom loader by default.
    return None

@classmethod
@abstractmethod
def build_scheduler(cls, model_config: DiffusionModelConfig) -> SchedulerMixin:
Expand Down
2 changes: 2 additions & 0 deletions verl_omni/trainer/config/_generated_diffusion_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ actor_rollout_ref:
policy_state_adapters:
- default
lora_dtype: null
fsdp_layer_prefixes:
- transformer_blocks.
mtp:
_target_: verl.workers.config.MtpConfig
enable: false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ policy_state_adapters: ["default"]
# Default null means no conversion.
lora_dtype: null

# FSDP layer name prefixes for LoRA parameter layered summon.
fsdp_layer_prefixes: ["transformer_blocks."]

# MTP
mtp:

Expand Down
55 changes: 45 additions & 10 deletions verl_omni/utils/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
"""

from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Callable, Sequence
from contextlib import ExitStack, contextmanager
from functools import partial

from peft.utils.save_and_load import get_peft_model_state_dict
from verl.utils.fsdp_utils import collect_lora_params as _upstream_collect_lora_params
Expand Down Expand Up @@ -167,8 +168,17 @@ def _collect_lora_params_with_adapter(
return _collect_base_weights_to_cpu(peft_model)


def _layered_summon_lora_params_diffusers(fsdp_module, adapter_name: str = "default") -> OrderedDict:
"""Layered LoRA param collection for diffusers transformer-block models."""
def _layered_summon_lora_params_diffusers(
fsdp_module, adapter_name: str = "default", layer_prefixes: Sequence[str] = ("transformer_blocks.",)
) -> OrderedDict:
"""Layered LoRA param collection for diffusers transformer-block models.

Args:
fsdp_module: The FSDP-wrapped module.
adapter_name: LoRA adapter name.
layer_prefixes: FSDP layer name prefixes. Defaults to
``("transformer_blocks.",)``.
"""
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl.utils.device import get_torch_device

Expand All @@ -178,12 +188,12 @@ def _prefix_submodules(module, prefix):
yield name, submodule

lora_params = OrderedDict()
prefix_list = [
prefix_list = []
for lp in layer_prefixes:
# FSDP1
"_fsdp_wrapped_module.transformer_blocks.",
prefix_list.append(f"_fsdp_wrapped_module.{lp}")
# FSDP2
"transformer_blocks.",
]
prefix_list.append(lp)
peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module)
for prefix in prefix_list:
for name, submodule in _prefix_submodules(fsdp_module, prefix):
Expand Down Expand Up @@ -211,17 +221,42 @@ def collect_lora_params(
base_sync_done: bool,
is_diffusers: bool = False,
adapter_name: str = "default",
layer_prefixes: Sequence[str] = ("transformer_blocks.",),
) -> OrderedDict:
"""Extended version of ``verl.utils.fsdp_utils.collect_lora_params``."""
"""Collect LoRA or base parameters for weight sync to the rollout worker.

Raises ``RuntimeError`` when no parameters were collected
(e.g. mismatched ``layer_prefixes``).

Args:
module: The FSDP-wrapped or plain module.
layered_summon: Summon one FSDP unit at a time instead of the full model.
base_sync_done: If ``True``, collect only LoRA weights; else full base weights.
is_diffusers: Use the diffusers-specific layered summon helper.
adapter_name: LoRA adapter name (usually ``"default"``).
layer_prefixes: FSDP layer name prefixes (e.g. ``["transformer_blocks."]``).
"""
use_diffusers_layered = is_diffusers and layered_summon and fsdp_version(module) > 0
if adapter_name == "default" and not use_diffusers_layered and fsdp_version(module) != 2:
return _upstream_collect_lora_params(module, layered_summon=layered_summon, base_sync_done=base_sync_done)

layered_summon_fn = _layered_summon_lora_params_diffusers if is_diffusers else _upstream_layered_summon_lora_params
return _collect_lora_params_with_adapter(
if is_diffusers:
layered_summon_fn = partial(
_layered_summon_lora_params_diffusers, adapter_name=adapter_name, layer_prefixes=layer_prefixes
)
else:
layered_summon_fn = _upstream_layered_summon_lora_params
lora_params = _collect_lora_params_with_adapter(
module,
layered_summon=layered_summon,
base_sync_done=base_sync_done,
adapter_name=adapter_name,
layered_summon_fn=layered_summon_fn,
)
if not lora_params:
raise RuntimeError(
f"collect_lora_params collected 0 parameters with prefixes={layer_prefixes}. "
"Check ``fsdp_layer_prefixes`` in the model config matches the model's "
"FSDP layer naming (e.g. ``['transformer_blocks.']`` for DiT models)."
)
return lora_params
2 changes: 2 additions & 0 deletions verl_omni/workers/config/diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class DiffusionModelConfig(BaseConfig):

algo: Optional[DiffusionRolloutAlgoConfig] = field(default_factory=DiffusionRolloutAlgoConfig)

fsdp_layer_prefixes: list[str] = field(default_factory=lambda: ["transformer_blocks."])

def __post_init__(self):
import_external_libs(self.external_lib)

Expand Down
64 changes: 58 additions & 6 deletions verl_omni/workers/engine/fsdp/diffusers_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,52 @@ def _init_device_mesh(self):

self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1

def _build_module_from_registry(self, torch_dtype: torch.dtype) -> Optional[torch.nn.Module]:
    """Try loading the module via ``DiffusionModelBase.build_module()``.

    The model class is resolved with ``DiffusionModelBase.get_class`` and
    its ``build_module()`` is called with the model config and
    ``torch_dtype``. A module obtained this way is then cast with
    ``module.to(torch_dtype)``, optionally has gradient checkpointing
    enabled, and gets a ``can_generate`` attribute that returns ``False``.

    Caveat (this is what the warning logged below refers to): a module
    built by a custom loader does not go through the default
    ``diffusers.AutoModel`` path, so engine-level hooks -- attention
    processors, gradient-checkpointing wrappers, LoRA, dtype upcast -- may
    be only partially effective or silently inactive for it.

    Args:
        torch_dtype: Dtype handed to ``build_module()`` and then applied to
            the returned module.

    Returns:
        The custom-built module, or ``None`` if the registry has no custom
        loader (``build_module()`` returned ``None``), so the caller falls
        back to ``diffusers.AutoModel``.

    Raises:
        TypeError: If the object returned by ``build_module()`` has no
            callable ``to()``.
        NotImplementedError: If gradient checkpointing is enabled in the
            config but the module has no callable
            ``enable_gradient_checkpointing()``.
    """
    from verl_omni.pipelines.model_base import DiffusionModelBase

    model_cls = DiffusionModelBase.get_class(self.model_config)
    module = model_cls.build_module(self.model_config, torch_dtype)
    if module is None:
        return None

    module_name = type(module).__name__

    logger.warning(
        "Built %s via DiffusionModelBase custom loader; engine-level hooks "
        "(attention processors, gradient-checkpointing wrappers, LoRA, "
        "dtype upcast) may be partially effective or silently inactive. "
        "See the docstring of _build_module_from_registry.",
        module_name,
    )

    # Resolve the method before calling it rather than wrapping the call in
    # ``except AttributeError``: that way an AttributeError raised *inside*
    # ``to()`` propagates with its own traceback instead of being reported
    # as "has no to() method".
    cast_to = getattr(module, "to", None)
    if not callable(cast_to):
        raise TypeError(
            f"{module_name} returned by build_module() has no to() method. "
            "Custom models must be torch.nn.Module instances."
        )
    cast_to(torch_dtype)

    if self.model_config.enable_gradient_checkpointing:
        # Same reasoning as above: only a genuinely missing method is a
        # NotImplementedError; failures inside the method must surface as-is.
        enable_ckpt = getattr(module, "enable_gradient_checkpointing", None)
        if not callable(enable_ckpt):
            raise NotImplementedError(
                f"Gradient checkpointing is enabled in config, but {module_name} "
                "does not implement enable_gradient_checkpointing(). "
                "Either implement it or set enable_gradient_checkpointing=False."
            )
        enable_ckpt()
        logger.info(
            "Gradient checkpointing enabled on %s via enable_gradient_checkpointing().",
            module_name,
        )

    # Make ``module.can_generate()`` report False. NOTE(review): which
    # downstream consumer queries this is not visible here -- confirm.
    module.can_generate = lambda: False
    return module

def _build_module(self):
from diffusers import AutoModel
from verl.utils.torch_dtypes import PrecisionType
Expand All @@ -201,6 +247,11 @@ def _build_module(self):

torch_dtype = PrecisionType.to_dtype(torch_dtype)

module = self._build_module_from_registry(torch_dtype)
if module is not None:
return module

# Default path: load via diffusers AutoModel
init_context = get_init_weight_context_manager(use_meta_tensor=True, mesh=self.device_mesh)

with init_context(), warnings.catch_warnings():
Expand Down Expand Up @@ -668,6 +719,7 @@ def get_per_tensor_param(
base_sync_done=base_sync_done,
is_diffusers=True,
adapter_name=adapter_name or "default",
layer_prefixes=self.model_config.fsdp_layer_prefixes,
)
if not base_sync_done:
params = {replace_lora_wrapper(k, peft_config): v for k, v in params.items()}
Expand Down Expand Up @@ -760,16 +812,16 @@ def prepare_model_inputs(self, micro_batch: TensorDict, step: int):
"""
latents = micro_batch["all_latents"]
timesteps = micro_batch["all_timesteps"]
prompt_embeds = micro_batch["prompt_embeds"]
prompt_embeds_mask = micro_batch["prompt_embeds_mask"]
negative_prompt_embeds = micro_batch["negative_prompt_embeds"]
negative_prompt_embeds_mask = micro_batch["negative_prompt_embeds_mask"]
prompt_embeds = micro_batch.get("prompt_embeds", None)
prompt_embeds_mask = micro_batch.get("prompt_embeds_mask", None)
negative_prompt_embeds = micro_batch.get("negative_prompt_embeds", None)
negative_prompt_embeds_mask = micro_batch.get("negative_prompt_embeds_mask", None)
sp_size = self.ulysses_sequence_parallel_size if self.use_ulysses_sp else 1

if prompt_embeds.is_nested:
if isinstance(prompt_embeds, torch.Tensor) and prompt_embeds.is_nested:
prompt_embeds, prompt_embeds_mask = self._unpad_nested_embeds(prompt_embeds, prompt_embeds_mask)

if sp_size > 1:
if isinstance(prompt_embeds, torch.Tensor) and sp_size > 1:
prompt_embeds, prompt_embeds_mask = self._pad_embeds_for_sp(prompt_embeds, prompt_embeds_mask, sp_size)

if isinstance(negative_prompt_embeds, torch.Tensor) and negative_prompt_embeds.is_nested:
Expand Down
Loading