diff --git a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
index 9ff681a3c0..3f03563a1c 100644
--- a/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
+++ b/vllm_omni/diffusion/models/omnigen2/omnigen2_transformer.py
@@ -5,6 +5,8 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+import vllm._custom_ops as ops
 
 from diffusers.models.activations import get_activation
 from diffusers.models.embeddings import Timesteps, get_1d_rotary_pos_embed
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
@@ -16,6 +18,7 @@
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from vllm_omni.diffusion.attention.layer import Attention
@@ -24,6 +27,105 @@
 logger = logging.getLogger(__name__)
 
 
+def _patch_cutlass_padded_fp8():
+    """Monkey-patch vllm._custom_ops.cutlass_scaled_mm to pad tensors whose
+    dimensions are not multiples of 16, so the CUTLASS FP8 kernel is used.
+
+    OmniGen2 has hidden_size=2520 (2520 % 16 == 8). Without this patch,
+    vLLM's cutlass_scaled_mm falls back to a Triton scaled_mm kernel for
+    every FP8 linear layer (QKV, attn output, gate_up_proj, down_proj),
+    which is dramatically slower than the native CUTLASS FP8 tensor-core
+    path on H100/H200 GPUs.
+
+    Weight tensors (b) are constant across forward passes, so padded
+    versions are computed once and cached by data_ptr to avoid repeated
+    allocation and column-major conversion overhead.
+    """
+    _orig_cutlass_scaled_mm = ops.cutlass_scaled_mm
+    # Cache: data_ptr → (padded_b, padded_bias, padded_scale_b, pad_k, pad_n, orig_n)
+    _weight_cache: dict[int, tuple] = {}
+
+    def _padded_cutlass_scaled_mm(
+        a: torch.Tensor,
+        b: torch.Tensor,
+        scale_a: torch.Tensor,
+        scale_b: torch.Tensor,
+        out_dtype: torch.dtype,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        if b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0:
+            return _orig_cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
+
+        # Reshape to 2D (mirrors the original function)
+        target_shape = (*a.shape[:-1], b.shape[1])
+        a = a.view(-1, a.shape[-1])
+        orig_n = b.shape[1]
+
+        # Cache the padded weight — it's a model parameter that never changes.
+        key = b.data_ptr()
+        if key not in _weight_cache:
+            pad_k = (16 - b.shape[0] % 16) % 16
+            pad_n = (16 - orig_n % 16) % 16
+            b_pad = b
+            if pad_k > 0:
+                b_pad = F.pad(b_pad, (0, 0, 0, pad_k))
+            if pad_n > 0:
+                b_pad = F.pad(b_pad, (0, pad_n))
+            # CUTLASS requires b column-major (stride(0)==1).
+            b_pad = b_pad.t().contiguous().t()
+
+            bias_pad = None
+            if bias is not None and pad_n > 0:
+                bias_pad = F.pad(bias, (0, pad_n))
+
+            scale_b_pad = scale_b
+            if scale_b.numel() > 1 and pad_n > 0:
+                scale_b_pad = F.pad(
+                    scale_b.view(-1, scale_b.shape[-1]),
+                    (0, pad_n),
+                    value=1.0,
+                )
+
+            _weight_cache[key] = (
+                b_pad,
+                bias_pad,
+                scale_b_pad,
+                pad_k,
+                pad_n,
+                orig_n,
+            )
+
+        b_pad, bias_pad, scale_b_pad, pad_k, pad_n, orig_n = _weight_cache[key]
+
+        # Pad activations on K dimension (cheap — activations are small).
+        if pad_k > 0:
+            a = F.pad(a, (0, pad_k)).contiguous()
+
+        out = torch.empty((a.shape[0], b_pad.shape[1]), dtype=out_dtype, device=a.device)
+        torch.ops._C.cutlass_scaled_mm(
+            out,
+            a,
+            b_pad,
+            scale_a,
+            scale_b_pad,
+            bias_pad if bias is not None else None,
+        )
+
+        if pad_n > 0:
+            out = out[:, :orig_n]
+
+        return out.view(*target_shape)
+
+    ops.cutlass_scaled_mm = _padded_cutlass_scaled_mm
+    logger.info(
+        "Patched vllm._custom_ops.cutlass_scaled_mm with CUTLASS-padded FP8 "
+        "variant (avoids slow Triton fallback for non-%%16 dimensions)"
+    )
+
+
+_patch_cutlass_padded_fp8()
+
+
 class OmniGen2Attention(nn.Module):
     def __init__(
         self,
@@ -31,6 +133,8 @@
         num_heads: int,
         num_kv_heads: int,
         eps: float = 1e-5,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.dim = dim
@@ -46,12 +150,26 @@
             total_num_kv_heads=num_kv_heads,
             disable_tp=True,
             bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.to_qkv",
         )
         self.norm_q = RMSNorm(self.head_dim, eps=eps)
         self.norm_k = RMSNorm(self.head_dim, eps=eps)
 
-        self.to_out = nn.ModuleList([nn.Linear(dim, dim, bias=False)])
+        self.to_out = nn.ModuleList(
+            [
+                RowParallelLinear(
+                    dim,
+                    dim,
+                    bias=False,
+                    input_is_parallel=False,
+                    quant_config=quant_config,
+                    return_bias=False,
+                    prefix=f"{prefix}.to_out.0",
+                )
+            ]
+        )
 
         self.attn = Attention(
             num_heads=num_heads,
             head_size=self.head_dim,
@@ -78,6 +196,9 @@
         """
         batch_size = hidden_states.shape[0]
 
+        # Contiguous layout for FP8 quantized linear GEMMs (matches FLUX DiT).
+        hidden_states = hidden_states.contiguous()
+
         # Get Query-Key-Value Pair
         qkv, _ = self.to_qkv(hidden_states)
 
@@ -121,7 +242,7 @@
         hidden_states = hidden_states.reshape(batch_size, -1, self.num_heads * self.head_dim)
         hidden_states = hidden_states.to(dtype)
 
-        hidden_states = self.to_out[0](hidden_states)
+        hidden_states = self.to_out[0](hidden_states.contiguous())
 
         return hidden_states
 
@@ -233,6 +354,7 @@
         embedding_dim: int,
         norm_eps: float,
         norm_elementwise_affine: bool,
+        **kwargs,
    ):
         super().__init__()
         self.silu = nn.SiLU()
@@ -325,6 +447,8 @@
         inner_dim: int,
         multiple_of: int | None = 256,
         ffn_dim_multiplier: float | None = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ):
         super().__init__()
 
@@ -338,6 +462,8 @@
             [inner_dim, inner_dim],
             bias=False,
             return_bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
         )
         self.act_fn = get_act_and_mul_fn("silu")
         self.down_proj = RowParallelLinear(
@@ -346,6 +472,8 @@
             bias=False,
             input_is_parallel=True,
             return_bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
 
     def forward(self, x):
@@ -591,6 +719,8 @@
         ffn_dim_multiplier: float,
         norm_eps: float,
         modulation: bool = True,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
     ) -> None:
         """Initialize the transformer block."""
         super().__init__()
@@ -602,6 +732,8 @@
             num_heads=num_attention_heads,
             num_kv_heads=num_kv_heads,
             eps=1e-5,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
         )
 
         # Initialize feed-forward network
@@ -610,11 +742,19 @@
             inner_dim=4 * dim,
             multiple_of=multiple_of,
             ffn_dim_multiplier=ffn_dim_multiplier,
+            quant_config=quant_config,
+            prefix=f"{prefix}.feed_forward",
         )
 
         # Initialize normalization layers
         if modulation:
-            self.norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, norm_elementwise_affine=True)
+            self.norm1 = LuminaRMSNormZero(
+                embedding_dim=dim,
+                norm_eps=norm_eps,
+                norm_elementwise_affine=True,
+                quant_config=quant_config,
+                prefix=f"{prefix}.norm1",
+            )
         else:
             self.norm1 = RMSNorm(dim, eps=norm_eps)
 
@@ -713,6 +853,7 @@
         axes_lens: tuple[int, int, int] = (1024, 1664, 1664),
         text_feat_dim: int = 2048,
         timestep_scale: float = 1000.0,
+        quant_config: QuantizationConfig | None = None,
     ) -> None:
         """Initialize the OmniGen2 transformer model."""
         super().__init__()
@@ -770,8 +911,10 @@
                     ffn_dim_multiplier,
                     norm_eps,
                     modulation=True,
+                    quant_config=quant_config,
+                    prefix=f"noise_refiner.{i}",
                 )
-                for _ in range(num_refiner_layers)
+                for i in range(num_refiner_layers)
             ]
         )
 
@@ -785,8 +928,10 @@
                     ffn_dim_multiplier,
                     norm_eps,
                     modulation=True,
+                    quant_config=quant_config,
+                    prefix=f"ref_image_refiner.{i}",
                 )
-                for _ in range(num_refiner_layers)
+                for i in range(num_refiner_layers)
             ]
         )
 
@@ -800,8 +945,10 @@
                     ffn_dim_multiplier,
                     norm_eps,
                     modulation=False,
+                    quant_config=quant_config,
+                    prefix=f"context_refiner.{i}",
                 )
-                for _ in range(num_refiner_layers)
+                for i in range(num_refiner_layers)
             ]
         )
 
@@ -816,8 +963,10 @@
                     ffn_dim_multiplier,
                     norm_eps,
                     modulation=True,
+                    quant_config=quant_config,
+                    prefix=f"layers.{i}",
                 )
-                for _ in range(num_layers)
+                for i in range(num_layers)
             ]
         )
 
@@ -847,11 +996,25 @@
         temb,
     ):
         batch_size = len(hidden_states)
+        has_ref_tokens = any(ref_img_len > 0 for ref_lens in l_effective_ref_img_len for ref_img_len in ref_lens)
         max_combined_img_len = max(
             [img_len + sum(ref_img_len) for img_len, ref_img_len in zip(l_effective_img_len, l_effective_ref_img_len)]
         )
 
         hidden_states = self.x_embedder(hidden_states)
 
+        if not has_ref_tokens:
+            # FP8 kernels do not support zero-token GEMM on ref_image_patch_embedder; skip that path only.
+            # Still run noise_refiner and return the same combined layout as the no-ref case below
+            # (batch, max_combined_img_len, hidden) — not raw noise tokens alone.
+            for layer in self.noise_refiner:
+                hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
+            combined_img_hidden_states = hidden_states.new_zeros(
+                batch_size, max_combined_img_len, self.config.hidden_size
+            )
+            for i, img_len in enumerate(l_effective_img_len):
+                combined_img_hidden_states[i, :img_len] = hidden_states[i, :img_len]
+            return combined_img_hidden_states
+
         ref_image_hidden_states = self.ref_image_patch_embedder(ref_image_hidden_states)
         for i in range(batch_size):
diff --git a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
index e8e307b878..04720c932f 100644
--- a/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
+++ b/vllm_omni/diffusion/models/omnigen2/pipeline_omnigen2.py
@@ -676,7 +676,10 @@
         )
 
         transformer_kwargs = get_transformer_config_kwargs(od_config.tf_model_config, OmniGen2Transformer2DModel)
-        self.transformer = OmniGen2Transformer2DModel(**transformer_kwargs)
+        self.transformer = OmniGen2Transformer2DModel(
+            **transformer_kwargs,
+            quant_config=od_config.quantization_config,
+        )
         self.mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
             model, subfolder="mllm", local_files_only=local_files_only
         ).to(self.device)
@@ -1253,8 +1256,6 @@
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-        batch_size, num_channels_latents, height, width = latents.shape
-
         optional_kwargs = {}
         if "ref_image_hidden_states" in set(inspect.signature(self.transformer.forward).parameters.keys()):
             optional_kwargs["ref_image_hidden_states"] = ref_image_hidden_states
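
Note: the monkey-patch above relies on a simple GEMM invariant: appending zero rows/columns on the K dimension and extra columns on the N dimension, then slicing the output back to the original N, leaves the result unchanged. The sketch below illustrates that invariant with plain torch.matmul rather than the actual torch.ops._C.cutlass_scaled_mm kernel; the padded_mm helper and the shapes used are illustrative only and not part of the patch.

import torch
import torch.nn.functional as F


def padded_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Pad K and N up to multiples of 16, multiply, then slice N back."""
    k, n = b.shape
    pad_k = (16 - k % 16) % 16
    pad_n = (16 - n % 16) % 16
    # Zero rows appended on K and zero columns on N contribute nothing to
    # the first n output columns, so slicing recovers a @ b (up to
    # floating-point reduction order).
    b_pad = F.pad(b, (0, pad_n, 0, pad_k))
    a_pad = F.pad(a, (0, pad_k))
    return (a_pad @ b_pad)[:, :n]


# hidden_size=2520 is the OmniGen2 case: 2520 % 16 == 8, so pad_k = pad_n = 8.
a = torch.randn(4, 2520, dtype=torch.float64)
b = torch.randn(2520, 2520, dtype=torch.float64)
assert torch.allclose(padded_mm(a, b), a @ b)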