Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/user_guide/diffusion_features.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,10 @@ The following tables show which models support each feature:
|-------|:----------:|:-----------:|:---------------------:|:--------------:|:-----------------:|:------:|:------------------------:|:--------------------:|:--------------:|:----------------:|
| **Wan2.2** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (encode/decode) | ❌ | ❌ |
| **Wan2.1-VACE** | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ❌ | ❌ |
| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | | ❌ | ❌ | ❌ |
| **LTX-2** | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | | ❌ | ❌ | ❌ |
| **Helios** | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| **HunyuanVideo-1.5 T2V I2V** | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ (decode) | ✅ | ❌ |
| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | ❌ | ❌ | ❌ |
| **DreamID-Omni** | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | ❌ | ❌ | ❌ |

**Frame Interpolation Support**

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ def parse_args() -> argparse.Namespace:
default=False,
help="Enable CPU offloading for diffusion models.",
)
parser.add_argument(
"--enable-layerwise-offload",
action="store_true",
help="Enable layerwise (blockwise) offloading on DiT modules.",
)
return parser.parse_args()


Expand Down Expand Up @@ -126,6 +131,7 @@ def main() -> None:
parallel_config=parallel_config,
model_type=args.model_type,
enable_cpu_offload=args.enable_cpu_offload,
enable_layerwise_offload=args.enable_layerwise_offload,
)
start = time.perf_counter()
outputs = omni.generate(prompt, sampling_params)
Expand Down
255 changes: 159 additions & 96 deletions vllm_omni/diffusion/models/dreamid_omni/fusion.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import torch
import torch.nn as nn
from vllm.logger import init_logger
Expand All @@ -15,78 +17,26 @@
logger = init_logger(__name__)


class FusionModel(nn.Module):
def __init__(self, video_config=None, audio_config=None):
    """Build the video and audio Wan backbones plus the shared attention op.

    Args:
        video_config: Keyword arguments for the video ``WanModel``, or ``None``.
        audio_config: Keyword arguments for the audio ``WanModel``, or ``None``.
    """
    super().__init__()
    video_present = video_config is not None
    audio_present = audio_config is not None

    if video_present:
        self.video_model = WanModel(**video_config)
    else:
        self.video_model = None
        logger.warning("No video model is provided!")

    if audio_present:
        self.audio_model = WanModel(**audio_config)
    else:
        self.audio_model = None
        logger.warning("No audio model is provided!")

    if video_present and audio_present:
        # The fusion loop walks both backbones in lockstep, one block each.
        block_count = len(self.video_model.blocks)
        assert block_count == len(self.audio_model.blocks)
        self.num_blocks = block_count

    # NOTE(review): everything below dereferences self.video_model (and the
    # injection also self.audio_model) unconditionally, so a missing config
    # fails here with AttributeError right after the warning above.
    self.inject_cross_attention_kv_projections()
    self.device = get_local_device()

    heads = self.video_model.num_heads
    self.num_heads = heads
    self.head_dim = self.video_model.dim // heads
    self.attn = Attention(
        num_heads=self.num_heads,
        head_size=self.head_dim,
        num_kv_heads=self.num_heads,
        softmax_scale=1.0 / (self.head_dim**0.5),
        causal=False,
    )

def inject_cross_attention_kv_projections(self):
for vid_block in self.video_model.blocks:
vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
vid_block.cross_attn.norm_k_fusion = (
WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()
)
class FusedBlock(nn.Module):
"""Wrapper pairing a video block and audio block for layerwise offloading.

for audio_block in self.audio_model.blocks:
audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
audio_block.cross_attn.norm_k_fusion = (
WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()
)
Registers both blocks as submodules so their parameters are visible to the offload hooks.
"""

def merge_kwargs(self, vid_kwargs, audio_kwargs):
    """Combine per-modality kwargs into one flat dict with prefixed keys.

    Every key of ``vid_kwargs`` is emitted as ``vid_<key>`` and every key of
    ``audio_kwargs`` as ``audio_<key>``; video entries come first. Neither
    input mapping is modified.

    Args:
        vid_kwargs: Keyword arguments for the video branch (e.g. ``e``,
            ``seq_lens``, ``grid_sizes``, ``freqs``, ``context``,
            ``context_lens``; the exact key set is decided by the caller).
        audio_kwargs: Keyword arguments for the audio branch, same convention.

    Returns:
        A new ``dict`` holding the prefixed entries of both inputs.
    """
    merged_kwargs = {f"vid_{key}": value for key, value in vid_kwargs.items()}
    merged_kwargs.update((f"audio_{key}", value) for key, value in audio_kwargs.items())
    return merged_kwargs
# Store one video block, one audio block and the target device on this wrapper.
#
# vid_block / audio_block: the two blocks being paired. They are assigned as
#   attributes after super().__init__(), video first, so both are held by this
#   wrapper module in that order.
# device: stored as-is on the instance.
def __init__(
self,
vid_block: nn.Module,
audio_block: nn.Module,
device: torch.device,
):
super().__init__()
self.vid_block = vid_block
self.audio_block = audio_block
self.device = device

def single_fusion_cross_attention_forward(
def _cross_attention_forward(
self,
attn: Attention,
cross_attn_block,
src_seq,
src_grid_sizes,
Expand All @@ -104,21 +54,17 @@ def single_fusion_cross_attention_forward(
):
b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim
if hasattr(cross_attn_block, "k_img"):
## means is i2v block
q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
else:
## means is t2v block
q, k, v = cross_attn_block.qkv_fn(src_seq, context)
k_img = v_img = None

x = self.attn(q, k, v)
x = attn(q, k, v)

if k_img is not None:
img_x = self.attn(q, k_img, v_img)
img_x = attn(q, k_img, v_img)
x = x + img_x

# is_vid = src_grid_sizes.shape[1] > 1
# compute target attention
target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
Expand All @@ -132,17 +78,16 @@ def single_fusion_cross_attention_forward(
freqs_scaling=target_freqs_scaling,
)

target_x = self.attn(q, k_target, v_target)
target_x = attn(q, k_target, v_target)

x = x + target_x

x = x.flatten(2) # [B, L/P, C]

x = x.flatten(2)
x = cross_attn_block.o(x)
return x

def single_fusion_cross_attention_ffn_forward(
def _cross_attention_ffn_forward(
self,
attn: Attention,
attn_block,
src_seq,
src_grid_sizes,
Expand All @@ -159,7 +104,8 @@ def single_fusion_cross_attention_ffn_forward(
target_ref_lengths=None,
target_freqs_scaling=None,
):
src_seq = src_seq + self.single_fusion_cross_attention_forward(
src_seq = src_seq + self._cross_attention_forward(
attn,
attn_block.cross_attn,
attn_block.norm3(src_seq),
src_grid_sizes=src_grid_sizes,
Expand All @@ -180,12 +126,11 @@ def single_fusion_cross_attention_ffn_forward(
src_seq = src_seq + y * src_e[5].squeeze(2)
return src_seq

def single_fusion_block_forward(
def forward(
self,
vid_block,
audio_block,
vid,
audio,
attn: Attention,
vid_e,
vid_seq_lens,
vid_grid_sizes,
Expand All @@ -203,6 +148,9 @@ def single_fusion_block_forward(
audio_ref_lengths,
audio_freqs_scaling,
):
vid_block = self.vid_block
audio_block = self.audio_block

## audio modulation
assert audio_e.dtype == torch.bfloat16
assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], (
Expand Down Expand Up @@ -246,7 +194,8 @@ def single_fusion_block_forward(
og_audio = audio

# audio cross-attention
audio = self.single_fusion_cross_attention_ffn_forward(
audio = self._cross_attention_ffn_forward(
attn,
audio_block,
audio,
audio_grid_sizes,
Expand All @@ -267,7 +216,8 @@ def single_fusion_block_forward(
assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"

# video cross-attention
vid = self.single_fusion_cross_attention_ffn_forward(
vid = self._cross_attention_ffn_forward(
attn,
vid_block,
vid,
vid_grid_sizes,
Expand All @@ -287,6 +237,128 @@ def single_fusion_block_forward(

return vid, audio


class FusionModel(nn.Module):
_layerwise_offload_blocks_attrs = ["fused_blocks"]

def __init__(self, video_config=None, audio_config=None):
    """Build both Wan backbones, the shared attention op and the fused block list.

    Args:
        video_config: Keyword arguments for the video ``WanModel``, or ``None``.
        audio_config: Keyword arguments for the audio ``WanModel``, or ``None``.
    """
    super().__init__()
    video_present = video_config is not None
    audio_present = audio_config is not None
    self.device = get_local_device()

    if video_present:
        self.video_model = WanModel(**video_config)
    else:
        self.video_model = None
        logger.warning("No video model is provided!")

    if audio_present:
        self.audio_model = WanModel(**audio_config)
    else:
        self.audio_model = None
        logger.warning("No audio model is provided!")

    both_present = video_present and audio_present
    if both_present:
        # Blocks are paired index by index, so the depths have to agree.
        block_count = len(self.video_model.blocks)
        assert block_count == len(self.audio_model.blocks)
        self.num_blocks = block_count

    # NOTE(review): everything below dereferences self.video_model (and the
    # injection also self.audio_model) unconditionally, so a missing config
    # fails here with AttributeError right after the warning above.
    self.inject_cross_attention_kv_projections()

    heads = self.video_model.num_heads
    self.num_heads = heads
    self.head_dim = self.video_model.dim // heads
    # One Attention instance shared by all fused blocks; it is handed to each
    # block as an argument at forward time instead of being owned by the block.
    self.attn = Attention(
        num_heads=self.num_heads,
        head_size=self.head_dim,
        num_kv_heads=self.num_heads,
        softmax_scale=1.0 / (self.head_dim**0.5),
        causal=False,
    )

    if both_present:
        video_blocks = self.video_model.blocks
        audio_blocks = self.audio_model.blocks
        self.fused_blocks = nn.ModuleList(
            [FusedBlock(video_blocks[idx], audio_blocks[idx], self.device) for idx in range(self.num_blocks)]
        )

def load_state_dict(self, state_dict, strict=True, assign=False):
    """Load weights, accepting checkpoints keyed by the per-backbone block layout.

    Keys of the form ``video_model.blocks.N.*`` / ``audio_model.blocks.N.*``
    are renamed to ``fused_blocks.N.vid_block.*`` /
    ``fused_blocks.N.audio_block.*``; every other key passes through
    unchanged. The caller's mapping is never mutated.

    Args:
        state_dict: Mapping from parameter/buffer name to value.
        strict: Forwarded to the parent ``load_state_dict``.
        assign: Forwarded to the parent ``load_state_dict``.

    Returns:
        Whatever the parent ``load_state_dict`` returns.
    """
    # One anchored pattern covers both backbones, so each key is scanned once
    # instead of being run through a pre-check and two separate substitutions.
    legacy_key = re.compile(r"^(video_model|audio_model)\.blocks\.(\d+)\.")
    wrapper_attr = {"video_model": "vid_block", "audio_model": "audio_block"}

    def _to_fused(match):
        return f"fused_blocks.{match.group(2)}.{wrapper_attr[match.group(1)]}."

    if any(legacy_key.match(key) for key in state_dict):
        # NOTE(review): the remapped mapping is a plain dict, so attributes set
        # on the incoming object (e.g. a ``_metadata`` attribute, if present)
        # are not carried over -- confirm this is acceptable for the
        # checkpoints in use.
        state_dict = {legacy_key.sub(_to_fused, key): value for key, value in state_dict.items()}

    # Drop the backbones' own registration of the blocks before delegating, so
    # the blocks are owned through `fused_blocks` only when the load happens.
    self._detach_blocks_from_backbones()

    return super().load_state_dict(state_dict, strict=strict, assign=assign)

def inject_cross_attention_kv_projections(self):
    """Attach the fusion K/V projections and norms to every block's cross-attention.

    All video blocks are handled first, then all audio blocks. Each block's
    ``cross_attn`` gains ``k_fusion``, ``v_fusion``, ``pre_attn_norm_fusion``
    and ``norm_k_fusion``, all sized by that block's ``dim``.
    """
    for backbone in (self.video_model, self.audio_model):
        for block in backbone.blocks:
            cross_attn = block.cross_attn
            width = block.dim
            cross_attn.k_fusion = nn.Linear(width, width)
            cross_attn.v_fusion = nn.Linear(width, width)
            cross_attn.pre_attn_norm_fusion = WanLayerNorm(width, elementwise_affine=True)
            # The key norm is only a real norm when the block uses QK
            # normalization; otherwise it is a pass-through.
            if block.qk_norm:
                cross_attn.norm_k_fusion = WanRMSNorm(width, eps=1e-6)
            else:
                cross_attn.norm_k_fusion = nn.Identity()

def _detach_blocks_from_backbones(self) -> None:
    """Make `fused_blocks` the only registered owner of the Wan blocks.

    NOTE: This is a workaround to support layerwise offloading. The same Wan
    blocks are reachable both through each backbone's `blocks` and through
    `fused_blocks`, the wrapper used to walk video/audio blocks in pairs.
    Layerwise offloading treats only `fused_blocks` as offloadable and
    materializes every other module on device, which would otherwise include
    these shared blocks via `video_model` and `audio_model`.

    The `blocks` entry is dropped from each backbone's `_modules` and the same
    block objects are re-attached as a plain tuple, so `backbone.blocks[i]`
    keeps working.
    """
    backbones = (self.video_model, self.audio_model)
    # Snapshot both block sequences before touching either backbone.
    snapshots = [tuple(backbone.blocks) for backbone in backbones]
    for backbone in backbones:
        backbone._modules.pop("blocks", None)
    for backbone, blocks in zip(backbones, snapshots):
        backbone.blocks = blocks

def merge_kwargs(self, vid_kwargs, audio_kwargs):
    """Combine per-modality kwargs into one flat dict with prefixed keys.

    Every key of ``vid_kwargs`` is emitted as ``vid_<key>`` and every key of
    ``audio_kwargs`` as ``audio_<key>``; video entries come first. Neither
    input mapping is modified.

    Args:
        vid_kwargs: Keyword arguments for the video branch (e.g. ``e``,
            ``seq_lens``, ``grid_sizes``, ``freqs``, ``context``,
            ``context_lens``; the exact key set is decided by the caller).
        audio_kwargs: Keyword arguments for the audio branch, same convention.

    Returns:
        A new ``dict`` holding the prefixed entries of both inputs.
    """
    merged_kwargs = {f"vid_{key}": value for key, value in vid_kwargs.items()}
    merged_kwargs.update((f"audio_{key}", value) for key, value in audio_kwargs.items())
    return merged_kwargs

def forward(
self,
vid,
Expand Down Expand Up @@ -316,17 +388,8 @@ def forward(

kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)

for i in range(self.num_blocks):
"""
1 fusion block refers to 1 audio block with 1 video block.
"""

vid_block = self.video_model.blocks[i]
audio_block = self.audio_model.blocks[i]

vid, audio = self.single_fusion_block_forward(
vid_block=vid_block, audio_block=audio_block, vid=vid, audio=audio, **kwargs
)
for fused_block in self.fused_blocks:
vid, audio = fused_block(vid, audio, self.attn, **kwargs)

vid = self.video_model.post_transformer_block_out(vid, vid_kwargs["grid_sizes"], vid_e)
audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs["grid_sizes"], audio_e)
Expand Down
Loading
Loading