177 changes: 170 additions & 7 deletions vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
@@ -5,6 +5,8 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
import vllm._custom_ops as ops
from diffusers.models.activations import get_activation
from diffusers.models.embeddings import Timesteps, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
@@ -16,6 +18,7 @@
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from vllm_omni.diffusion.attention.layer import Attention
@@ -24,13 +27,114 @@
logger = logging.getLogger(__name__)


def _patch_cutlass_padded_fp8():
"""Monkey-patch vllm._custom_ops.cutlass_scaled_mm to pad tensors whose
dimensions are not multiples of 16, so the CUTLASS FP8 kernel is used.

OmniGen2 has hidden_size=2520 (2520 % 16 == 8). Without this patch,
vLLM's cutlass_scaled_mm falls back to a Triton scaled_mm kernel for
every FP8 linear layer (QKV, attn output, gate_up_proj, down_proj),
which is dramatically slower than the native CUTLASS FP8 tensor-core
path on H100/H200 GPUs.

Weight tensors (b) are constant across forward passes, so padded
versions are computed once and cached by data_ptr to avoid repeated
allocation and column-major conversion overhead.
"""
_orig_cutlass_scaled_mm = ops.cutlass_scaled_mm
# Cache: data_ptr → (padded_b, padded_bias, padded_scale_b, pad_k, pad_n, orig_n)
_weight_cache: dict[int, tuple] = {}

def _padded_cutlass_scaled_mm(
a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0:
return _orig_cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

# Reshape to 2D (mirrors the original function)
target_shape = (*a.shape[:-1], b.shape[1])
a = a.view(-1, a.shape[-1])
orig_n = b.shape[1]

# Cache the padded weight — it's a model parameter that never changes.
key = b.data_ptr()
if key not in _weight_cache:
pad_k = (16 - b.shape[0] % 16) % 16
pad_n = (16 - orig_n % 16) % 16
b_pad = b
if pad_k > 0:
b_pad = F.pad(b_pad, (0, 0, 0, pad_k))
if pad_n > 0:
b_pad = F.pad(b_pad, (0, pad_n))
# CUTLASS requires b column-major (stride(0)==1).
b_pad = b_pad.t().contiguous().t()

bias_pad = None
if bias is not None and pad_n > 0:
bias_pad = F.pad(bias, (0, pad_n))

scale_b_pad = scale_b
if scale_b.numel() > 1 and pad_n > 0:
scale_b_pad = F.pad(
scale_b.view(-1, scale_b.shape[-1]),
(0, pad_n),
value=1.0,
)

_weight_cache[key] = (
b_pad,
bias_pad,
scale_b_pad,
pad_k,
pad_n,
orig_n,
)

b_pad, bias_pad, scale_b_pad, pad_k, pad_n, orig_n = _weight_cache[key]

# Pad activations on K dimension (cheap — activations are small).
if pad_k > 0:
a = F.pad(a, (0, pad_k)).contiguous()

out = torch.empty((a.shape[0], b_pad.shape[1]), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_mm(
out,
a,
b_pad,
scale_a,
scale_b_pad,
bias_pad if bias is not None else None,
)

if pad_n > 0:
out = out[:, :orig_n]

return out.view(*target_shape)

ops.cutlass_scaled_mm = _padded_cutlass_scaled_mm
logger.info(
"Patched vllm._custom_ops.cutlass_scaled_mm with CUTLASS-padded FP8 "
"variant (avoids slow Triton fallback for non-%%16 dimensions)"
)


_patch_cutlass_padded_fp8()


class OmniGen2Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
num_kv_heads: int,
eps: float = 1e-5,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.dim = dim
@@ -46,12 +150,26 @@ def __init__(
total_num_kv_heads=num_kv_heads,
disable_tp=True,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.to_qkv",
)

self.norm_q = RMSNorm(self.head_dim, eps=eps)
self.norm_k = RMSNorm(self.head_dim, eps=eps)

self.to_out = nn.ModuleList([nn.Linear(dim, dim, bias=False)])
self.to_out = nn.ModuleList(
[
RowParallelLinear(
dim,
dim,
bias=False,
input_is_parallel=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.to_out.0",
)
]
)
self.attn = Attention(
num_heads=num_heads,
head_size=self.head_dim,
@@ -78,6 +196,9 @@ def forward(
"""
batch_size = hidden_states.shape[0]

# Contiguous layout for FP8 quantized linear GEMMs (matches FLUX DiT).
hidden_states = hidden_states.contiguous()

# Get Query-Key-Value Pair
qkv, _ = self.to_qkv(hidden_states)

@@ -121,7 +242,7 @@ def forward(
hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(dtype)

hidden_states = self.to_out[0](hidden_states)
hidden_states = self.to_out[0](hidden_states.contiguous())

return hidden_states

@@ -233,6 +354,7 @@ def __init__(
embedding_dim: int,
norm_eps: float,
norm_elementwise_affine: bool,
**kwargs,
):
super().__init__()
self.silu = nn.SiLU()
@@ -325,6 +447,8 @@ def __init__(
inner_dim: int,
multiple_of: int | None = 256,
ffn_dim_multiplier: float | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()

@@ -338,6 +462,8 @@
[inner_dim, inner_dim],
bias=False,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.act_fn = get_act_and_mul_fn("silu")
self.down_proj = RowParallelLinear(
@@ -346,6 +472,8 @@
bias=False,
input_is_parallel=True,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)

def forward(self, x):
@@ -591,6 +719,8 @@ def __init__(
ffn_dim_multiplier: float,
norm_eps: float,
modulation: bool = True,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
"""Initialize the transformer block."""
super().__init__()
@@ -602,6 +732,8 @@
num_heads=num_attention_heads,
num_kv_heads=num_kv_heads,
eps=1e-5,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)

# Initialize feed-forward network
@@ -610,11 +742,19 @@
inner_dim=4 * dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
quant_config=quant_config,
prefix=f"{prefix}.feed_forward",
)

# Initialize normalization layers
if modulation:
self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True)
self.norm1 = LuminaRMSNormZero(
embedding_dim=dim,
norm_eps=norm_eps,
norm_elementwise_affine=True,
quant_config=quant_config,
prefix=f"{prefix}.norm1",
)
else:
self.norm1 = RMSNorm(dim, eps=norm_eps)

@@ -713,6 +853,7 @@ def __init__(
axes_lens: tuple[int, int, int] = (1024, 1664, 1664),
text_feat_dim: int = 2048,
timestep_scale: float = 1000.0,
quant_config: QuantizationConfig | None = None,
) -> None:
"""Initialize the OmniGen2 transformer model."""
super().__init__()
@@ -770,8 +911,10 @@ def __init__(
ffn_dim_multiplier,
norm_eps,
modulation=True,
quant_config=quant_config,
prefix=f"noise_refiner.{i}",
)
for _ in range(num_refiner_layers)
for i in range(num_refiner_layers)
]
)

@@ -785,8 +928,10 @@
ffn_dim_multiplier,
norm_eps,
modulation=True,
quant_config=quant_config,
prefix=f"ref_image_refiner.{i}",
)
for _ in range(num_refiner_layers)
for i in range(num_refiner_layers)
]
)

@@ -800,8 +945,10 @@
ffn_dim_multiplier,
norm_eps,
modulation=False,
quant_config=quant_config,
prefix=f"context_refiner.{i}",
)
for _ in range(num_refiner_layers)
for i in range(num_refiner_layers)
]
)

@@ -816,8 +963,10 @@
ffn_dim_multiplier,
norm_eps,
modulation=True,
quant_config=quant_config,
prefix=f"layers.{i}",
)
for _ in range(num_layers)
for i in range(num_layers)
]
)

@@ -847,11 +996,25 @@ def img_patch_embed_and_refine(
temb,
):
batch_size = len(hidden_states)
has_ref_tokens = any(ref_img_len > 0 for ref_lens in l_effective_ref_img_len for ref_img_len in ref_lens)
max_combined_img_len = max(
[img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)]
)

hidden_states = self.x_embedder(hidden_states)
if not has_ref_tokens:
# FP8 kernels do not support a zero-token GEMM, so skip ref_image_patch_embedder when
# no reference tokens are present. Still run noise_refiner and return the same combined
# layout the reference-image path below produces, (batch, max_combined_img_len, hidden),
# not raw noise tokens alone.
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
combined_img_hidden_states = hidden_states.new_zeros(
batch_size, max_combined_img_len, self.config.hidden_size
)
for i, img_len in enumerate(l_effective_img_len):
combined_img_hidden_states[i, :img_len] = hidden_states[i, :img_len]
return combined_img_hidden_states

ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)

for i in range(batch_size):
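The padding arithmetic the patch applies can be sanity-checked in isolation. The sketch below is illustrative only: it uses plain PyTorch rather than the vLLM CUTLASS kernel, and the K = N = 2520 weight is a stand-in (real layers pair 2520 with whatever output width the projection has). Padding both GEMM dimensions from 2520 to 2528 adds roughly (2528/2520)^2 - 1, about 0.6%, extra FLOPs, well below the cost of the Triton fallback that the docstring describes.

import torch
import torch.nn.functional as F

# Illustrative sketch: pad a (K, N) weight so both dimensions are multiples of 16,
# mirroring the arithmetic inside _patch_cutlass_padded_fp8 above.
K, N = 2520, 2520                        # OmniGen2 hidden size; 2520 % 16 == 8
b = torch.randn(K, N)
pad_k = (16 - K % 16) % 16               # 8
pad_n = (16 - N % 16) % 16               # 8
b_pad = F.pad(b, (0, pad_n, 0, pad_k))   # last dim padded by pad_n, first dim by pad_k
assert b_pad.shape == (2528, 2528)
# CUTLASS expects the B operand column-major (stride(0) == 1).
b_pad = b_pad.t().contiguous().t()
assert b_pad.stride(0) == 1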
7 changes: 4 additions & 3 deletions vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
@@ -676,7 +676,10 @@ def __init__(
)

transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, OmniGen2Transformer2DModel)
self.transformer = OmniGen2Transformer2DModel(**transformer_kwargs)
self.transformer = OmniGen2Transformer2DModel(
**transformer_kwargs,
quant_config=od_config.quantization_config,
)
self.mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model, subfolder="mllm", local_files_only=local_files_only
).to(self.device)
@@ -1253,8 +1256,6 @@ def predict(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)

batch_size, num_channels_latents, height, width = latents.shape

optional_kwargs = {}
if "ref_image_hidden_states" in set(inspect.signature(self.transformer.forward).parameters.keys()):
optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states