Skip to content
Merged
22 changes: 22 additions & 0 deletions vllm_omni/diffusion/diffusion_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ def __init__(self, od_config: OmniDiffusionConfig):
self._processes: list[mp.Process] = []
self._closed = False
self._make_client()
try:
self._dummy_run()
except Exception as e:
logger.error(f"Dummy run failed: {e}")
self.close()
raise e

def step(self, requests: list[OmniDiffusionRequest]):
try:
Expand Down Expand Up @@ -151,6 +157,22 @@ def _launch_workers(self, broadcast_handle):
def add_req_and_wait_for_response(self, requests: list[OmniDiffusionRequest]):
    """Hand ``requests`` to ``scheduler.add_req`` and return whatever it returns."""
    response = scheduler.add_req(requests)
    return response

def _dummy_run(self):
    """Send one minimal request through the engine to warm up the model.

    Builds a single request (placeholder prompt, 1024x1024, one inference
    step, one output per prompt) and passes it to
    ``add_req_and_wait_for_response``. The result is discarded; any
    exception raised along the way propagates to the caller.
    """
    warmup_request = OmniDiffusionRequest(
        prompt="dummy run",
        height=1024,
        width=1024,
        # A single inference step is requested for the warm-up.
        num_inference_steps=1,
        num_outputs_per_prompt=1,
    )
    logger.info("dummy run to warm up the model")
    self.add_req_and_wait_for_response([warmup_request])

def close(self, *, timeout_s: float = 30.0) -> None:
if self._closed:
return
Expand Down
38 changes: 38 additions & 0 deletions vllm_omni/diffusion/layers/custom_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from collections.abc import Callable
from typing import Any

import torch.nn as nn

from vllm_omni.utils.platform_utils import detect_device_type


class CustomOp(nn.Module):
    """
    Module base class that routes ``forward`` to a backend-specific method.

    The backend is picked once, at construction time: ``forward_cuda`` when
    ``detect_device_type()`` returns ``"cuda"``, ``forward_native`` otherwise.
    """

    def __init__(self) -> None:
        super().__init__()
        device_type = detect_device_type()
        self.is_cuda = device_type == "cuda"
        # Resolve the bound backend method once so ``forward`` only delegates.
        self._forward_method = self.dispatch_forward()

    def dispatch_forward(self) -> Callable:
        """Return the bound backend method that matches ``self.is_cuda``."""
        return self.forward_cuda if self.is_cuda else self.forward_native

    def forward(self, *args, **kwargs) -> Any:
        """Pass every argument through to the method chosen in ``__init__``."""
        backend = self._forward_method
        return backend(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.

        Implementing this method is optional. When present it can be used
        with compilers such as torch.compile or PyTorch XLA, and for testing.
        The base implementation raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        """CUDA implementation hook; raises ``NotImplementedError`` unless overridden."""
        raise NotImplementedError
80 changes: 80 additions & 0 deletions vllm_omni/diffusion/layers/rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
from einops import rearrange, repeat

from vllm_omni.diffusion.layers.custom_op import CustomOp


def rotate_half(x, interleaved=False):
    """Map each rotary pair ``(a, b)`` along the last dim of ``x`` to ``(-b, a)``.

    With ``interleaved=False`` the pair members are the first and second
    halves of the last dimension; with ``interleaved=True`` they are the
    adjacent even/odd elements.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        paired = torch.stack((-odds, evens), dim=-1)
        # Merge the trailing (d, 2) axes back into one interleaved axis.
        return rearrange(paired, "... d two -> ... (d two)", two=2)

    lower, upper = x.chunk(2, dim=-1)
    return torch.cat((-upper, lower), dim=-1)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    Apply rotary embedding to the leading ``rotary_dim`` features of ``x``.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    ``rotary_dim`` is twice the last dimension of ``cos``; features past it
    are concatenated back unchanged.
    """
    ro_dim = 2 * cos.shape[-1]
    assert ro_dim <= x.shape[-1]

    # Expand cos/sin to rotary_dim and add a singleton axis just before it.
    pattern = "... d -> ... 1 (d 2)" if interleaved else "... d -> ... 1 (2 d)"
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)

    x_rot = x[..., :ro_dim]
    x_pass = x[..., ro_dim:]
    rotated = x_rot * cos + rotate_half(x_rot, interleaved) * sin
    return torch.cat([rotated, x_pass], dim=-1)


class RotaryEmbedding(CustomOp):
    """
    Rotary positional embedding.

    Args:
        is_neox_style: if True, rotate the 1st half against the 2nd half of
            the rotary dimensions (GPT-NeoX style). If False, rotate pairs of
            even and odd dimensions (GPT-J style); this is stored as
            ``self.interleaved = True``.
    """

    def __init__(
        self,
        is_neox_style: bool = False,
    ) -> None:
        super().__init__()
        self.is_neox_style = is_neox_style
        self.interleaved = not is_neox_style

    def forward_cuda(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary embedding via the flash-attn ``apply_rotary_emb`` op.

        ``cos``/``sin`` may be 2-D ``(S, D/2)`` or 3-D ``(B, S, D/2)``. A 3-D
        table with a batch axis of 1 is reduced to 2-D before the call. A 3-D
        table with a larger batch axis is routed to ``forward_native``.
        """
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        if cos.dim() == 3:
            if cos.shape[0] != 1:
                # The fused op is handed a single (S, D/2) table. Taking
                # cos[0] here would silently apply batch element 0's table to
                # every sample, so use the native path, which broadcasts a
                # per-sample (B, S, D/2) table.
                return self.forward_native(x, cos, sin)
            # (1, S, D/2) -> (S, D/2)
            cos = cos[0]
            sin = sin[0]

        return apply_rotary_emb(
            x,
            cos,
            sin,
            interleaved=self.interleaved,
        )

    def forward_native(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary embedding with the pure-PyTorch ``apply_rotary_emb_torch``."""
        return apply_rotary_emb_torch(
            x,
            cos,
            sin,
            interleaved=self.interleaved,
        )
81 changes: 16 additions & 65 deletions vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,58 +26,11 @@
get_sequence_parallel_world_size,
get_sp_group,
)
from vllm_omni.diffusion.layers.rope import RotaryEmbedding

logger = init_logger(__name__)


def apply_rotary_emb_qwen(
x: torch.Tensor,
freqs_cis: torch.Tensor | tuple[torch.Tensor],
use_real: bool = True,
use_real_unbind_dim: int = -1,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.

Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

Returns:
tuple[torch.Tensor, torch.Tensor]: tuple of modified query tensor and key tensor with rotary embeddings.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None, None]
sin = sin[None, None]
cos, sin = cos.to(x.device), sin.to(x.device)

if use_real_unbind_dim == -1:
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

return out
else:
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(1)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

return x_out.type_as(x)


class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim):
super().__init__()
Expand Down Expand Up @@ -277,13 +230,14 @@ def __init__(
softmax_scale=1.0 / (self.head_dim**0.5),
causal=False,
)
self.rope = RotaryEmbedding(is_neox_style=False)

def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor],
**cross_attention_kwargs,
vid_freqs: torch.Tensor,
txt_freqs: torch.Tensor,
):
seq_len_txt = encoder_hidden_states.shape[1]

Expand Down Expand Up @@ -311,12 +265,14 @@ def forward(
txt_key = self.norm_added_k(txt_key)

# Apply RoPE
if image_rotary_emb is not None:
img_freqs, txt_freqs = image_rotary_emb
img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
img_cos = vid_freqs.real.to(img_query.dtype)
img_sin = vid_freqs.imag.to(img_query.dtype)
txt_cos = txt_freqs.real.to(txt_query.dtype)
txt_sin = txt_freqs.imag.to(txt_query.dtype)
img_query = self.rope(img_query, img_cos, img_sin)
img_key = self.rope(img_key, img_cos, img_sin)
txt_query = self.rope(txt_query, txt_cos, txt_sin)
txt_key = self.rope(txt_key, txt_cos, txt_sin)

# Concatenate for joint attention
# Order: [text, image]
Expand Down Expand Up @@ -347,6 +303,7 @@ def forward(
return img_attn_output, txt_attn_output


@torch.compile(dynamic=True)
class QwenImageTransformerBlock(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -433,7 +390,7 @@ def forward(
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor],
joint_attention_kwargs: dict[str, Any] | None = None,
modulate_index: list[int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -463,13 +420,11 @@ def forward(
# 2. Applies QK normalization and RoPE
# 3. Concatenates and runs joint attention
# 4. Splits results back to separate streams
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=img_modulated, # Image stream (will be processed as "sample")
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
vid_freqs=image_rotary_emb[0],
txt_freqs=image_rotary_emb[1],
)

# QwenAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
Expand Down Expand Up @@ -526,10 +481,6 @@ class QwenImageTransformer2DModel(nn.Module):
The dimensions to use for the rotary positional embeddings.
"""

# _supports_gradient_checkpointing = True
# _no_split_modules = ["QwenImageTransformerBlock"]
# _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
# _repeated_blocks = ["QwenImageTransformerBlock"]
def __init__(
self,
od_config: OmniDiffusionConfig,
Expand Down
19 changes: 8 additions & 11 deletions vllm_omni/diffusion/models/z_image/z_image_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from vllm_omni.diffusion.attention.layer import Attention
from vllm_omni.diffusion.layers.rope import RotaryEmbedding

ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
Expand Down Expand Up @@ -114,12 +115,13 @@ def __init__(
softmax_scale=1.0 / (self.head_dim**0.5),
causal=False,
)
self.rope = RotaryEmbedding(is_neox_style=False)

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
freqs_cis: tuple[torch.Tensor, torch.Tensor] | None = None,
freqs_cis: torch.Tensor,
):
qkv, _ = self.to_qkv(hidden_states)
query, key, value = qkv.chunk(3, dim=-1)
Expand All @@ -131,16 +133,10 @@ def forward(
query = self.norm_q(query)
key = self.norm_k(key)

def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate ``x_in`` by the complex ``freqs_cis`` and return it in ``x_in``'s dtype."""
    # View adjacent feature pairs as complex numbers (computed in float32).
    pairs = x_in.float().reshape(*x_in.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(pairs)
    # Add a singleton axis at dim 2 so the frequencies broadcast across it.
    rotated = as_complex * freqs_cis.unsqueeze(2)
    return torch.view_as_real(rotated).flatten(3).type_as(x_in)

if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)

cos = freqs_cis.real.squeeze(0).to(query.dtype)
sin = freqs_cis.imag.squeeze(0).to(query.dtype)
query = self.rope(query, cos, sin)
key = self.rope(key, cos, sin)
# Cast to correct dtype
dtype = query.dtype
query, key = query.to(dtype), key.to(dtype)
Expand Down Expand Up @@ -189,6 +185,7 @@ def forward(self, x):
return self.w2(self.act(self.w13(x)))


@torch.compile(dynamic=True)
class ZImageTransformerBlock(nn.Module):
def __init__(
self,
Expand Down