diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index a4ae3118a79..20e2b9fea83 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -52,10 +52,14 @@ def apply_rotary_emb_wan( x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) cos = freqs_cos[..., 0::2] sin = freqs_sin[..., 1::2] - out = torch.empty_like(hidden_states) - out[..., 0::2] = x1 * cos - x2 * sin - out[..., 1::2] = x1 * sin + x2 * cos - return out.type_as(hidden_states) + rotated = torch.stack( + ( + x1 * cos - x2 * sin, + x1 * sin + x2 * cos, + ), + dim=-1, + ) + return rotated.flatten(-2, -1).to(hidden_states.dtype) class DistributedRMSNorm(nn.Module):