From c143ff0a19883acd7d559068c941909086f276b8 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao Date: Thu, 19 Mar 2026 22:01:29 +0800 Subject: [PATCH 1/8] enable layerwise offload for LTX-2 Signed-off-by: Yuanheng Zhao --- vllm_omni/diffusion/models/ltx2/ltx2_transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py index a1bf7f7809c..744470740ad 100644 --- a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py +++ b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py @@ -1264,6 +1264,7 @@ class LTX2VideoTransformer3DModel(nn.Module): _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] _repeated_blocks = ["LTX2VideoTransformerBlock"] + _layerwise_offload_blocks_attr = "transformer_blocks" _sp_plan: dict[str, Any] | None = None @staticmethod From 7aa5f8cf059fd6b251c80e6149df1171fe39be67 Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao Date: Thu, 19 Mar 2026 22:28:55 +0800 Subject: [PATCH 2/8] try to support layerwise offload for DreamID-Omni Signed-off-by: Yuanheng Zhao --- .../diffusion/models/dreamid_omni/fusion.py | 228 ++++++++++-------- 1 file changed, 126 insertions(+), 102 deletions(-) diff --git a/vllm_omni/diffusion/models/dreamid_omni/fusion.py b/vllm_omni/diffusion/models/dreamid_omni/fusion.py index a534f5a76fa..20ae1d8227a 100644 --- a/vllm_omni/diffusion/models/dreamid_omni/fusion.py +++ b/vllm_omni/diffusion/models/dreamid_omni/fusion.py @@ -15,77 +15,28 @@ logger = init_logger(__name__) -class FusionModel(nn.Module): - def __init__(self, video_config=None, audio_config=None): - super().__init__() - has_video = True - has_audio = True - if video_config is not None: - self.video_model = WanModel(**video_config) - else: - has_video = False - self.video_model = None - logger.warning("No video model is provided!") +class FusedBlock(nn.Module): + """Wrapper pairing a video block and audio block for layerwise offloading. - if audio_config is not None: - self.audio_model = WanModel(**audio_config) - else: - has_audio = False - self.audio_model = None - logger.warning("No audio model is provided!") + Registers both blocks as submodules so their parameters are visible to the offload hooks. + """ - if has_video and has_audio: - assert len(self.video_model.blocks) == len(self.audio_model.blocks) - self.num_blocks = len(self.video_model.blocks) - - self.inject_cross_attention_kv_projections() - self.device = get_local_device() - - self.num_heads = self.video_model.num_heads - self.head_dim = self.video_model.dim // self.video_model.num_heads - self.attn = Attention( - num_heads=self.num_heads, - head_size=self.head_dim, - num_kv_heads=self.num_heads, - softmax_scale=1.0 / (self.head_dim**0.5), - causal=False, - ) - - def inject_cross_attention_kv_projections(self): - for vid_block in self.video_model.blocks: - vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim) - vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim) - vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True) - vid_block.cross_attn.norm_k_fusion = ( - WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity() - ) - - for audio_block in self.audio_model.blocks: - audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim) - audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim) - audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True) - audio_block.cross_attn.norm_k_fusion = ( - WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity() - ) - - def merge_kwargs(self, vid_kwargs, audio_kwargs): - """ - keys in each kwarg: - e - seq_lens - grid_sizes - freqs - context - context_lens - """ - merged_kwargs = {} - for key in vid_kwargs: - merged_kwargs[f"vid_{key}"] = vid_kwargs[key] - for key in audio_kwargs: - merged_kwargs[f"audio_{key}"] = audio_kwargs[key] - return merged_kwargs - - def single_fusion_cross_attention_forward( + def __init__( + self, + vid_block: nn.Module, + audio_block: nn.Module, + attn: Attention, + device: torch.device, + ): + super().__init__() + self.vid_block = vid_block + self.audio_block = audio_block + # `attn` is not registered as submodules, but shared across all FusedBlocks + # and stay GPU-resident as direct children of FusionModel. + self._attn = attn + self._device = device + + def _cross_attention_forward( self, cross_attn_block, src_seq, @@ -104,21 +55,17 @@ def single_fusion_cross_attention_forward( ): b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim if hasattr(cross_attn_block, "k_img"): - ## means is i2v block q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context) else: - ## means is t2v block q, k, v = cross_attn_block.qkv_fn(src_seq, context) k_img = v_img = None - x = self.attn(q, k, v) + x = self._attn(q, k, v) if k_img is not None: - img_x = self.attn(q, k_img, v_img) + img_x = self._attn(q, k_img, v_img) x = x + img_x - # is_vid = src_grid_sizes.shape[1] > 1 - # compute target attention target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq) k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d) v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d) @@ -132,16 +79,14 @@ def single_fusion_cross_attention_forward( freqs_scaling=target_freqs_scaling, ) - target_x = self.attn(q, k_target, v_target) + target_x = self._attn(q, k_target, v_target) x = x + target_x - - x = x.flatten(2) # [B, L/P, C] - + x = x.flatten(2) x = cross_attn_block.o(x) return x - def single_fusion_cross_attention_ffn_forward( + def _cross_attention_ffn_forward( self, attn_block, src_seq, @@ -159,7 +104,7 @@ def single_fusion_cross_attention_ffn_forward( target_ref_lengths=None, target_freqs_scaling=None, ): - src_seq = src_seq + self.single_fusion_cross_attention_forward( + src_seq = src_seq + self._cross_attention_forward( attn_block.cross_attn, attn_block.norm3(src_seq), src_grid_sizes=src_grid_sizes, @@ -176,14 +121,12 @@ def single_fusion_cross_attention_ffn_forward( target_freqs_scaling=target_freqs_scaling, ) y = attn_block.ffn(attn_block.norm2(src_seq).bfloat16() * (1 + src_e[4].squeeze(2)) + src_e[3].squeeze(2)) - with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True): + with torch.autocast(device_type=self._device.type, dtype=torch.bfloat16, enabled=True): src_seq = src_seq + y * src_e[5].squeeze(2) return src_seq - def single_fusion_block_forward( + def forward( self, - vid_block, - audio_block, vid, audio, vid_e, @@ -203,12 +146,15 @@ def single_fusion_block_forward( audio_ref_lengths, audio_freqs_scaling, ): + vid_block = self.vid_block + audio_block = self.audio_block + ## audio modulation assert audio_e.dtype == torch.bfloat16 assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], ( f"{audio_e.shape}, {audio.shape}" ) - with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True): + with torch.autocast(device_type=self._device.type, dtype=torch.bfloat16, enabled=True): audio_e = audio_block.modulation(audio_e).chunk(6, dim=2) assert audio_e[0].dtype == torch.bfloat16 @@ -221,14 +167,14 @@ def single_fusion_block_forward( ref_lengths=audio_ref_lengths, freqs_scaling=audio_freqs_scaling, ) - with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True): + with torch.autocast(device_type=self._device.type, dtype=torch.bfloat16, enabled=True): audio = audio + audio_y * audio_e[2].squeeze(2) ## video modulation assert len(vid_e.shape) == 4 and vid_e.size(2) == 6 and vid_e.shape[1] == vid.shape[1], ( f"{vid_e.shape}, {vid.shape}" ) - with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True): + with torch.autocast(device_type=self._device.type, dtype=torch.bfloat16, enabled=True): vid_e = vid_block.modulation(vid_e).chunk(6, dim=2) # video self-attention @@ -240,13 +186,13 @@ def single_fusion_block_forward( ref_lengths=vid_ref_lengths, ) - with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True): + with torch.autocast(device_type=self._device.type, dtype=torch.bfloat16, enabled=True): vid = vid + vid_y * vid_e[2].squeeze(2) og_audio = audio # audio cross-attention - audio = self.single_fusion_cross_attention_ffn_forward( + audio = self._cross_attention_ffn_forward( audio_block, audio, audio_grid_sizes, @@ -267,7 +213,7 @@ def single_fusion_block_forward( assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!" # video cross-attention - vid = self.single_fusion_cross_attention_ffn_forward( + vid = self._cross_attention_ffn_forward( vid_block, vid, vid_grid_sizes, @@ -287,6 +233,93 @@ def single_fusion_block_forward( return vid, audio + +class FusionModel(nn.Module): + _layerwise_offload_blocks_attr = "fused_blocks" + + def __init__(self, video_config=None, audio_config=None): + super().__init__() + has_video = True + has_audio = True + if video_config is not None: + self.video_model = WanModel(**video_config) + else: + has_video = False + self.video_model = None + logger.warning("No video model is provided!") + + if audio_config is not None: + self.audio_model = WanModel(**audio_config) + else: + has_audio = False + self.audio_model = None + logger.warning("No audio model is provided!") + + if has_video and has_audio: + assert len(self.video_model.blocks) == len(self.audio_model.blocks) + self.num_blocks = len(self.video_model.blocks) + + self.inject_cross_attention_kv_projections() + self.device = get_local_device() + + self.num_heads = self.video_model.num_heads + self.head_dim = self.video_model.dim // self.video_model.num_heads + self.attn = Attention( + num_heads=self.num_heads, + head_size=self.head_dim, + num_kv_heads=self.num_heads, + softmax_scale=1.0 / (self.head_dim**0.5), + causal=False, + ) + + # NOTE: `attn` is a single shared instance across blocks + if has_video and has_audio: + self.fused_blocks = nn.ModuleList( + [ + FusedBlock( + self.video_model.blocks[i], + self.audio_model.blocks[i], + self.attn, + self.device, + ) + for i in range(self.num_blocks) + ] + ) + + def inject_cross_attention_kv_projections(self): + for vid_block in self.video_model.blocks: + vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim) + vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim) + vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True) + vid_block.cross_attn.norm_k_fusion = ( + WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity() + ) + + for audio_block in self.audio_model.blocks: + audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim) + audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim) + audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True) + audio_block.cross_attn.norm_k_fusion = ( + WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity() + ) + + def merge_kwargs(self, vid_kwargs, audio_kwargs): + """ + keys in each kwarg: + e + seq_lens + grid_sizes + freqs + context + context_lens + """ + merged_kwargs = {} + for key in vid_kwargs: + merged_kwargs[f"vid_{key}"] = vid_kwargs[key] + for key in audio_kwargs: + merged_kwargs[f"audio_{key}"] = audio_kwargs[key] + return merged_kwargs + def forward( self, vid, @@ -316,17 +349,8 @@ def forward( kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs) - for i in range(self.num_blocks): - """ - 1 fusion block refers to 1 audio block with 1 video block. - """ - - vid_block = self.video_model.blocks[i] - audio_block = self.audio_model.blocks[i] - - vid, audio = self.single_fusion_block_forward( - vid_block=vid_block, audio_block=audio_block, vid=vid, audio=audio, **kwargs - ) + for fused_block in self.fused_blocks: + vid, audio = fused_block(vid, audio, **kwargs) vid = self.video_model.post_transformer_block_out(vid, vid_kwargs["grid_sizes"], vid_e) audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs["grid_sizes"], audio_e) From 0fb1bc610339203adf9565df4e1f75b995f11b2f Mon Sep 17 00:00:00 2001 From: yuanheng Date: Tue, 14 Apr 2026 13:46:11 +0000 Subject: [PATCH 3/8] upd Signed-off-by: yuanheng --- vllm_omni/diffusion/models/ltx2/ltx2_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py index 744470740ad..95ef919c24e 100644 --- a/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py +++ b/vllm_omni/diffusion/models/ltx2/ltx2_transformer.py @@ -1264,7 +1264,7 @@ class LTX2VideoTransformer3DModel(nn.Module): _supports_gradient_checkpointing = True _skip_layerwise_casting_patterns = ["norm"] _repeated_blocks = ["LTX2VideoTransformerBlock"] - _layerwise_offload_blocks_attr = "transformer_blocks" + _layerwise_offload_blocks_attrs = ["transformer_blocks"] _sp_plan: dict[str, Any] | None = None @staticmethod From c343f32cb996de5692b81fe00cc0d38e5a33de91 Mon Sep 17 00:00:00 2001 From: yuanheng Date: Tue, 14 Apr 2026 16:42:36 +0000 Subject: [PATCH 4/8] remap and impl load_weights Signed-off-by: yuanheng --- .../diffusion/models/dreamid_omni/fusion.py | 22 ++++++++++++ .../dreamid_omni/pipeline_dreamid_omni.py | 34 +++++++++++++------ 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/vllm_omni/diffusion/models/dreamid_omni/fusion.py b/vllm_omni/diffusion/models/dreamid_omni/fusion.py index 20ae1d8227a..2bddb02cc17 100644 --- a/vllm_omni/diffusion/models/dreamid_omni/fusion.py +++ b/vllm_omni/diffusion/models/dreamid_omni/fusion.py @@ -1,3 +1,5 @@ +import re + import torch import torch.nn as nn from vllm.logger import init_logger @@ -286,6 +288,26 @@ def __init__(self, video_config=None, audio_config=None): ] ) + def load_state_dict(self, state_dict, strict=True, assign=False): + """Remap checkpoints where blocks are stored under + `video_model.blocks.N.*` / `audio_model.blocks.N.*` to the current + `fused_blocks.N.vid_block.*` / `fused_blocks.N.audio_block.*`. + """ + needs_remap = any(re.match(r"^(video_model|audio_model)\.blocks\.\d+\.", k) for k in state_dict) + if needs_remap: + remapped = {} + for k, v in state_dict.items(): + remapped[k] = v + new_k = re.sub(r"^video_model\.blocks\.(\d+)\.", r"fused_blocks.\1.vid_block.", k) + if new_k != k: + remapped[new_k] = v + continue + new_k = re.sub(r"^audio_model\.blocks\.(\d+)\.", r"fused_blocks.\1.audio_block.", k) + if new_k != k: + remapped[new_k] = v + state_dict = remapped + return super().load_state_dict(state_dict, strict=strict, assign=assign) + def inject_cross_attention_kv_projections(self): for vid_block in self.video_model.blocks: vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim) diff --git a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py index c7ab4662d14..cc932f8c1f8 100644 --- a/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py +++ b/vllm_omni/diffusion/models/dreamid_omni/pipeline_dreamid_omni.py @@ -4,6 +4,7 @@ import logging import math import os +from collections.abc import Iterable import torch import torch.distributed @@ -16,6 +17,7 @@ from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.models.interface import SupportAudioInput, SupportImageInput from vllm_omni.diffusion.request import OmniDiffusionRequest @@ -27,7 +29,6 @@ init_mmaudio_vae, init_text_model, init_wan_vae_2_2, - load_fusion_checkpoint, ) from dreamid_omni.utils.rearrange import Rearrange from dreamid_omni.utils.resize import NaResize @@ -122,16 +123,24 @@ def __init__( self.text_model = init_text_model(model, rank=self.device) self.text_encoder = self.text_model.model - # Fusion model - ## load audio/video model config - Fusion_model = FusionModel(VIDEO_CONFIG, AUDIO_CONFIG) - - checkpoint_path = self.od_config.tf_model_config.get("fusion", None) - assert checkpoint_path is not None, "fusion checkpoint path is None" - load_fusion_checkpoint(Fusion_model, checkpoint_path=os.path.join(model, checkpoint_path)) - self.model = Fusion_model + # Fusion model — weights are loaded later via load_weights() + self.model = FusionModel(VIDEO_CONFIG, AUDIO_CONFIG) self.transformer = self.model + fusion_path = self.od_config.tf_model_config.get("fusion", None) + assert fusion_path is not None, "fusion checkpoint path is None in transformer config" + fusion_subfolder = os.path.dirname(fusion_path) or None + fusion_filename = os.path.basename(fusion_path) + self.weights_sources = [ + DiffusersPipelineLoader.ComponentSource( + model_or_path=model, + subfolder=fusion_subfolder, + revision=None, + prefix="model.", + allow_patterns_overrides=[fusion_filename], + ) + ] + # Fixed attributes, non-configurable self.audio_latent_channel = AUDIO_CONFIG.get("in_dim") self.video_latent_channel = VIDEO_CONFIG.get("in_dim") @@ -226,8 +235,11 @@ def load_image_latent_ref_ip_video( return ref_vae_latents, ref_audio_lengths - def load_weights(self, weights): - pass + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + prefix = "model." + state_dict = {name[len(prefix) :]: tensor for name, tensor in weights if name.startswith(prefix)} + self.model.load_state_dict(state_dict, strict=True) + return {prefix + k for k in state_dict} def get_scheduler_time_steps(self, sampling_steps, solver_name="unipc", device=0, shift=5.0): torch.manual_seed(4) From c1cf8828ad87c54dac1dc1f6becaf4ab6857a706 Mon Sep 17 00:00:00 2001 From: yuanheng Date: Wed, 15 Apr 2026 15:55:18 +0000 Subject: [PATCH 5/8] upd fusion model modeling Signed-off-by: yuanheng --- .../diffusion/models/dreamid_omni/fusion.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/vllm_omni/diffusion/models/dreamid_omni/fusion.py b/vllm_omni/diffusion/models/dreamid_omni/fusion.py index 2bddb02cc17..b33e54be742 100644 --- a/vllm_omni/diffusion/models/dreamid_omni/fusion.py +++ b/vllm_omni/diffusion/models/dreamid_omni/fusion.py @@ -237,7 +237,7 @@ def forward( class FusionModel(nn.Module): - _layerwise_offload_blocks_attr = "fused_blocks" + _layerwise_offload_blocks_attrs = ["fused_blocks"] def __init__(self, video_config=None, audio_config=None): super().__init__() @@ -287,6 +287,7 @@ def __init__(self, video_config=None, audio_config=None): for i in range(self.num_blocks) ] ) + self._blocks_detached_from_backbones = False def load_state_dict(self, state_dict, strict=True, assign=False): """Remap checkpoints where blocks are stored under @@ -301,12 +302,20 @@ def load_state_dict(self, state_dict, strict=True, assign=False): new_k = re.sub(r"^video_model\.blocks\.(\d+)\.", r"fused_blocks.\1.vid_block.", k) if new_k != k: remapped[new_k] = v + if self._blocks_detached_from_backbones: + remapped.pop(k, None) continue new_k = re.sub(r"^audio_model\.blocks\.(\d+)\.", r"fused_blocks.\1.audio_block.", k) if new_k != k: remapped[new_k] = v + if self._blocks_detached_from_backbones: + remapped.pop(k, None) state_dict = remapped - return super().load_state_dict(state_dict, strict=strict, assign=assign) + result = super().load_state_dict(state_dict, strict=strict, assign=assign) + if not self._blocks_detached_from_backbones: + self._detach_blocks_from_backbones() + + return result def inject_cross_attention_kv_projections(self): for vid_block in self.video_model.blocks: @@ -325,6 +334,29 @@ def inject_cross_attention_kv_projections(self): WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity() ) + def _detach_blocks_from_backbones(self) -> None: + """Keep offloadable blocks owned only by a single place. + + NOTE: This is a special workaround to support layerwise offloading. + The model registers the same Wan blocks under both the video/audio + backbones and `fused_blocks` which is a wrapper for unified blocks + walking through. However, layerwise offloading will only consider + `fused_blocks` as offloadable components and will materialize all + other modules onto device, including the same blocks owned by both + `fused_blocks` and `video_model` and `audio_model`. + """ + if self._blocks_detached_from_backbones: + return + + video_blocks = list(self.video_model.blocks) + audio_blocks = list(self.audio_model.blocks) + self.video_model._modules.pop("blocks", None) + self.audio_model._modules.pop("blocks", None) + self.video_model.blocks = tuple(video_blocks) + self.audio_model.blocks = tuple(audio_blocks) + + self._blocks_detached_from_backbones = True + def merge_kwargs(self, vid_kwargs, audio_kwargs): """ keys in each kwarg: From 1d10d0ad38ce77efe5931cc4bcd56661d577529d Mon Sep 17 00:00:00 2001 From: yuanheng Date: Wed, 15 Apr 2026 15:55:41 +0000 Subject: [PATCH 6/8] upd gen script Signed-off-by: yuanheng --- .../offline_inference/x_to_video_audio/x_to_video_audio.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py index 322b184e520..497284ceb96 100644 --- a/examples/offline_inference/x_to_video_audio/x_to_video_audio.py +++ b/examples/offline_inference/x_to_video_audio/x_to_video_audio.py @@ -58,6 +58,11 @@ def parse_args() -> argparse.Namespace: default=False, help="Enable CPU offloading for diffusion models.", ) + parser.add_argument( + "--enable-layerwise-offload", + action="store_true", + help="Enable layerwise (blockwise) offloading on DiT modules.", + ) return parser.parse_args() @@ -126,6 +131,7 @@ def main() -> None: parallel_config=parallel_config, model_type=args.model_type, enable_cpu_offload=args.enable_cpu_offload, + enable_layerwise_offload=args.enable_layerwise_offload, ) start = time.perf_counter() outputs = omni.generate(prompt, sampling_params) From 78f0b6e462c8d1db7c4484adf77982fc4b7f92e5 Mon Sep 17 00:00:00 2001 From: yuanheng Date: Wed, 15 Apr 2026 16:05:44 +0000 Subject: [PATCH 7/8] upd diffusion feat doc Signed-off-by: yuanheng --- docs/user_guide/diffusion_features.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/user_guide/diffusion_features.md b/docs/user_guide/diffusion_features.md index 45953b85299..2a521db27f2 100644 --- a/docs/user_guide/diffusion_features.md +++ b/docs/user_guide/diffusion_features.md @@ -139,10 +139,10 @@ The following tables show which models support each feature: |-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:| | **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ | | **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ | -| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ | -| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | +| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | **Frame Interpolation Support** From 1b7a36b865a9caae652cc798e055fde788379cee Mon Sep 17 00:00:00 2001 From: Yuanheng Zhao Date: Sat, 18 Apr 2026 11:12:26 +0000 Subject: [PATCH 8/8] upd fusion model Signed-off-by: Yuanheng Zhao --- .../diffusion/models/dreamid_omni/fusion.py | 61 +++++++------------ 1 file changed, 23 insertions(+), 38 deletions(-) diff --git a/vllm_omni/diffusion/models/dreamid_omni/fusion.py b/vllm_omni/diffusion/models/dreamid_omni/fusion.py index b33e54be742..abca4c9474f 100644 --- a/vllm_omni/diffusion/models/dreamid_omni/fusion.py +++ b/vllm_omni/diffusion/models/dreamid_omni/fusion.py @@ -27,19 +27,16 @@ def __init__( self, vid_block: nn.Module, audio_block: nn.Module, - attn: Attention, device: torch.device, ): super().__init__() self.vid_block = vid_block self.audio_block = audio_block - # `attn` is not registered as submodules, but shared across all FusedBlocks - # and stay GPU-resident as direct children of FusionModel. - self._attn = attn - self._device = device + self.device = device def _cross_attention_forward( self, + attn: Attention, cross_attn_block, src_seq, src_grid_sizes, @@ -62,10 +59,10 @@ def _cross_attention_forward( q, k, v = cross_attn_block.qkv_fn(src_seq, context) k_img = v_img = None - x = self._attn(q, k, v) + x = attn(q, k, v) if k_img is not None: - img_x = self._attn(q, k_img, v_img) + img_x = attn(q, k_img, v_img) x = x + img_x target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq) @@ -81,7 +78,7 @@ def _cross_attention_forward( freqs_scaling=target_freqs_scaling, ) - target_x = self._attn(q, k_target, v_target) + target_x = attn(q, k_target, v_target) x = x + target_x x = x.flatten(2) @@ -90,6 +87,7 @@ def _cross_attention_forward( def _cross_attention_ffn_forward( self, + attn: Attention, attn_block, src_seq, src_grid_sizes, @@ -107,6 +105,7 @@ def _cross_attention_ffn_forward( target_freqs_scaling=None, ): src_seq = src_seq + self._cross_attention_forward( + attn, attn_block.cross_attn, attn_block.norm3(src_seq), src_grid_sizes=src_grid_sizes, @@ -123,7 +122,7 @@ def _cross_attention_ffn_forward( target_freqs_scaling=target_freqs_scaling, ) y = attn_block.ffn(attn_block.norm2(src_seq).bfloat16() * (1 + src_e[4].squeeze(2)) + src_e[3].squeeze(2)) - with torch.autocast(device_type=self._device.type, dtype=torch.bfloat16, enabled=True): + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True): src_seq = src_seq + y * src_e[5].squeeze(2) return src_seq @@ -131,6 +130,7 @@ def forward( self, vid, audio, + attn: Attention, vid_e, vid_seq_lens, vid_grid_sizes, @@ -156,7 +156,7 @@ def forward( assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], ( f"{audio_e.shape}, {audio.shape}" ) - with torch.autocast(device_type=self._device.type, dtype=torch.bfloat16, enabled=True): + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True): audio_e = audio_block.modulation(audio_e).chunk(6, dim=2) assert audio_e[0].dtype == torch.bfloat16 @@ -169,14 +169,14 @@ def forward( ref_lengths=audio_ref_lengths, freqs_scaling=audio_freqs_scaling, ) - with torch.autocast(device_type=self._device.type, dtype=torch.bfloat16, enabled=True): + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True): audio = audio + audio_y * audio_e[2].squeeze(2) ## video modulation assert len(vid_e.shape) == 4 and vid_e.size(2) == 6 and vid_e.shape[1] == vid.shape[1], ( f"{vid_e.shape}, {vid.shape}" ) - with torch.autocast(device_type=self._device.type, dtype=torch.bfloat16, enabled=True): + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True): vid_e = vid_block.modulation(vid_e).chunk(6, dim=2) # video self-attention @@ -188,13 +188,14 @@ def forward( ref_lengths=vid_ref_lengths, ) - with torch.autocast(device_type=self._device.type, dtype=torch.bfloat16, enabled=True): + with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16, enabled=True): vid = vid + vid_y * vid_e[2].squeeze(2) og_audio = audio # audio cross-attention audio = self._cross_attention_ffn_forward( + attn, audio_block, audio, audio_grid_sizes, @@ -216,6 +217,7 @@ def forward( # video cross-attention vid = self._cross_attention_ffn_forward( + attn, vid_block, vid, vid_grid_sizes, @@ -243,6 +245,7 @@ def __init__(self, video_config=None, audio_config=None): super().__init__() has_video = True has_audio = True + self.device = get_local_device() if video_config is not None: self.video_model = WanModel(**video_config) else: @@ -262,10 +265,10 @@ def __init__(self, video_config=None, audio_config=None): self.num_blocks = len(self.video_model.blocks) self.inject_cross_attention_kv_projections() - self.device = get_local_device() self.num_heads = self.video_model.num_heads self.head_dim = self.video_model.dim // self.video_model.num_heads + # Make a single shared instance to pass in at forward time self.attn = Attention( num_heads=self.num_heads, head_size=self.head_dim, @@ -274,20 +277,17 @@ def __init__(self, video_config=None, audio_config=None): causal=False, ) - # NOTE: `attn` is a single shared instance across blocks if has_video and has_audio: self.fused_blocks = nn.ModuleList( [ FusedBlock( self.video_model.blocks[i], self.audio_model.blocks[i], - self.attn, self.device, ) for i in range(self.num_blocks) ] ) - self._blocks_detached_from_backbones = False def load_state_dict(self, state_dict, strict=True, assign=False): """Remap checkpoints where blocks are stored under @@ -298,24 +298,14 @@ def load_state_dict(self, state_dict, strict=True, assign=False): if needs_remap: remapped = {} for k, v in state_dict.items(): - remapped[k] = v new_k = re.sub(r"^video_model\.blocks\.(\d+)\.", r"fused_blocks.\1.vid_block.", k) - if new_k != k: - remapped[new_k] = v - if self._blocks_detached_from_backbones: - remapped.pop(k, None) - continue - new_k = re.sub(r"^audio_model\.blocks\.(\d+)\.", r"fused_blocks.\1.audio_block.", k) - if new_k != k: - remapped[new_k] = v - if self._blocks_detached_from_backbones: - remapped.pop(k, None) + new_k = re.sub(r"^audio_model\.blocks\.(\d+)\.", r"fused_blocks.\1.audio_block.", new_k) + remapped[new_k] = v state_dict = remapped - result = super().load_state_dict(state_dict, strict=strict, assign=assign) - if not self._blocks_detached_from_backbones: - self._detach_blocks_from_backbones() - return result + self._detach_blocks_from_backbones() + + return super().load_state_dict(state_dict, strict=strict, assign=assign) def inject_cross_attention_kv_projections(self): for vid_block in self.video_model.blocks: @@ -345,9 +335,6 @@ def _detach_blocks_from_backbones(self) -> None: other modules onto device, including the same blocks owned by both `fused_blocks` and `video_model` and `audio_model`. """ - if self._blocks_detached_from_backbones: - return - video_blocks = list(self.video_model.blocks) audio_blocks = list(self.audio_model.blocks) self.video_model._modules.pop("blocks", None) @@ -355,8 +342,6 @@ def _detach_blocks_from_backbones(self) -> None: self.video_model.blocks = tuple(video_blocks) self.audio_model.blocks = tuple(audio_blocks) - self._blocks_detached_from_backbones = True - def merge_kwargs(self, vid_kwargs, audio_kwargs): """ keys in each kwarg: @@ -404,7 +389,7 @@ def forward( kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs) for fused_block in self.fused_blocks: - vid, audio = fused_block(vid, audio, **kwargs) + vid, audio = fused_block(vid, audio, self.attn, **kwargs) vid = self.video_model.post_transformer_block_out(vid, vid_kwargs["grid_sizes"], vid_e) audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs["grid_sizes"], audio_e)