From fc5d50fd41d4ca3db77a993e200b2bc8ef553db5 Mon Sep 17 00:00:00 2001 From: lcukyfuture Date: Thu, 9 Apr 2026 15:52:54 +0800 Subject: [PATCH 1/6] [Feat] FP8 quantization support for LongCat-Image and LongCat-Image-Edit Add FP8 quantization support to LongCat-Image and LongCat-Image-Edit pipelines, following the unified quantization framework introduced in #1764. Changes: - Replace plain `nn.Linear` layers in `LongCatImageTransformer2DModel` with quantization-aware vLLM linear layers (`ReplicatedLinear`, `QKVParallelLinear`, `RowParallelLinear`, `ColumnParallelLinear`) and propagate `quant_config` through `FeedForward`, `LongCatImageAttention`, `LongCatImageTransformerBlock`, and `LongCatImageSingleTransformerBlock` - Pass `quant_config=od_config.quantization_config` to the transformer in both `LongCatImagePipeline` and `LongCatImageEditPipeline` - Fix `load_weights` in both pipelines to include VAE and text encoder parameters in the returned loaded-weights set - Fix `TypeError`: `LongCatImageSingleTransformerBlock.__init__` was receiving an unsupported `prefix` keyword argument, causing a crash on startup with any quantization config Signed-off-by: lcukyfuture --- .../longcat_image_transformer.py | 93 +++++++++++++++---- .../longcat_image/pipeline_longcat_image.py | 7 +- .../pipeline_longcat_image_edit.py | 10 +- 3 files changed, 90 insertions(+), 20 deletions(-) diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py index 8d8e523d60e..f4c5a180be2 100644 --- a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py +++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Iterable -from typing import Any +from typing import TYPE_CHECKING, Any import torch import torch.nn as nn @@ -13,9 +13,19 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ColumnParallelLinear, QKVParallelLinear, RowParallelLinear +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + ) + from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.layer import Attention from vllm_omni.diffusion.data import DiffusionParallelConfig, OmniDiffusionConfig @@ -30,14 +40,14 @@ class FeedForward(nn.Module): - def __init__(self, dim: int, dim_out: int | None = None, mult: int = 4, bias: bool = True): + def __init__(self, dim: int, dim_out: int | None = None, mult: int = 4, bias: bool = True, quant_config: "QuantizationConfig | None" = None, prefix: str = ""): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - self.w_in = ColumnParallelLinear(dim, inner_dim, bias=bias, return_bias=False) + self.w_in = ColumnParallelLinear(dim, inner_dim, bias=bias, return_bias=False, quant_config=quant_config, prefix=f"{prefix}.w_in") self.act = get_act_fn("gelu_pytorch_tanh") - self.w_out = RowParallelLinear(inner_dim, dim_out, bias=bias, return_bias=False) + self.w_out = RowParallelLinear(inner_dim, dim_out, bias=bias, return_bias=False, quant_config=quant_config, prefix=f"{prefix}.w_out") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.w_in(hidden_states) @@ -62,6 +72,7 @@ def __init__( out_dim: int = None, context_pre_only: bool | None = None, pre_only: bool = False, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.parallel_config = parallel_config @@ -85,10 +96,12 @@ def __init__( head_size=self.head_dim, total_num_heads=self.heads, bias=bias, + quant_config=quant_config, + prefix="to_qkv", ) if not self.pre_only: - self.to_out = RowParallelLinear(self.inner_dim, self.out_dim, bias=out_bias) + self.to_out = RowParallelLinear(self.inner_dim, self.out_dim, bias=out_bias, quant_config=quant_config, prefix="to_out") if self.added_kv_proj_dim is not None: self.norm_added_q = RMSNorm(dim_head, eps=eps) @@ -99,9 +112,11 @@ def __init__( head_size=self.head_dim, total_num_heads=self.heads, bias=added_proj_bias, + quant_config=quant_config, + prefix="add_kv_proj", ) - self.to_add_out = RowParallelLinear(self.inner_dim, query_dim, bias=out_bias) + self.to_add_out = RowParallelLinear(self.inner_dim, query_dim, bias=out_bias, quant_config=quant_config, prefix="to_add_out") self.attn = Attention( num_heads=heads, @@ -182,6 +197,8 @@ def forward( - Standard concatenation of text + image Q/K/V - Regular attention over the full sequence """ + # Ensure contiguous for FP8 quantized linear layers + hidden_states = hidden_states.contiguous() qkv, _ = self.to_qkv(hidden_states) q_size = self.to_qkv.num_heads * self.head_dim @@ -196,6 +213,7 @@ def forward( key = self.norm_k(key) if self.added_kv_proj_dim is not None: + encoder_hidden_states = encoder_hidden_states.contiguous() encoder_qkv, _ = self.add_kv_proj(encoder_hidden_states) q_size = self.add_kv_proj.num_heads * self.head_dim kv_size = self.add_kv_proj.num_kv_heads * self.head_dim @@ -293,8 +311,9 @@ def forward( encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 ) - hidden_states, _ = self.to_out(hidden_states) - encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states) + # Contiguous for FP8 quantization in RowParallelLinear + hidden_states, _ = self.to_out(hidden_states.contiguous()) + encoder_hidden_states, _ = self.to_add_out(encoder_hidden_states.contiguous()) return hidden_states, encoder_hidden_states else: @@ -313,6 +332,7 @@ def __init__( attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() @@ -330,13 +350,14 @@ def __init__( context_pre_only=False, bias=True, eps=eps, + quant_config=quant_config, ) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, dim_out=dim) + self.ff = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="ff") self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=dim, dim_out=dim) + self.ff_context = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="ff_context") def forward( self, @@ -508,14 +529,29 @@ def __init__( num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() self.mlp_hidden_dim = int(dim * mlp_ratio) self.norm = AdaLayerNormZeroSingle(dim) - self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim) + self.proj_mlp = ReplicatedLinear( + dim, + self.mlp_hidden_dim, + bias=True, + return_bias=False, + quant_config=quant_config, + prefix="proj_mlp", + ) self.act_mlp = nn.GELU(approximate="tanh") - self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + self.proj_out = ReplicatedLinear( + dim + self.mlp_hidden_dim, + dim, + bias=True, + return_bias=False, + quant_config=quant_config, + prefix="proj_out", + ) # SP handling is delegated to LongCatImageAttention via text_seq_len kwarg self.attn = LongCatImageAttention( @@ -526,6 +562,7 @@ def __init__( out_dim=dim, bias=True, eps=1e-6, + quant_config=quant_config, pre_only=True, ) @@ -603,6 +640,7 @@ class LongCatImageTransformer2DModel(nn.Module): def __init__( self, od_config: OmniDiffusionConfig, + quant_config: "QuantizationConfig | None" = None, ): super().__init__() model_config = od_config.tf_model_config @@ -627,8 +665,22 @@ def __init__( self.time_embed = LongCatImageTimestepEmbeddings(embedding_dim=self.inner_dim) - self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) - self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) + self.context_embedder = ReplicatedLinear( + joint_attention_dim, + self.inner_dim, + bias=True, + return_bias=False, + quant_config=quant_config, + prefix="context_embedder", + ) + self.x_embedder = ReplicatedLinear( + in_channels, + self.inner_dim, + bias=True, + return_bias=False, + quant_config=quant_config, + prefix="x_embedder", + ) self.transformer_blocks = nn.ModuleList( [ @@ -637,6 +689,7 @@ def __init__( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, + quant_config=quant_config, ) for i in range(num_layers) ] @@ -649,13 +702,21 @@ def __init__( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, + quant_config=quant_config, ) for i in range(num_single_layers) ] ) self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + self.proj_out = ReplicatedLinear( + self.inner_dim, + patch_size * patch_size * self.out_channels, + bias=True, + return_bias=False, + quant_config=quant_config, + prefix="proj_out", + ) self.gradient_checkpointing = False diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index 76d3efa2f85..6605a8bcb54 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -236,7 +236,7 @@ def __init__( self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( self.device ) - self.transformer = LongCatImageTransformer2DModel(od_config=od_config) + self.transformer = LongCatImageTransformer2DModel(od_config=od_config, quant_config=od_config.quantization_config) self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 @@ -681,4 +681,7 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + loaded_weights = loader.load_weights(weights) + loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()} + loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()} + return loaded_weights diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py index 7eccf68636a..81ee2dcbe58 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image_edit.py @@ -257,7 +257,10 @@ def __init__( self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( self.device ) - self.transformer = LongCatImageTransformer2DModel(od_config=od_config) + self.transformer = LongCatImageTransformer2DModel( + od_config=od_config, + quant_config=od_config.quantization_config, + ) self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 @@ -714,4 +717,7 @@ def forward( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load weights using AutoWeightsLoader for vLLM integration.""" loader = AutoWeightsLoader(self) - return loader.load_weights(weights) + loaded_weights = loader.load_weights(weights) + loaded_weights |= {f"vae.{name}" for name, _ in self.vae.named_parameters()} + loaded_weights |= {f"text_encoder.{name}" for name, _ in self.text_encoder.named_parameters()} + return loaded_weights From 93611b71ab4acf6f5c626ce01dde799655cd36f1 Mon Sep 17 00:00:00 2001 From: lcukyfuture Date: Thu, 9 Apr 2026 16:30:08 +0800 Subject: [PATCH 2/6] [Style] Fix ruff E501 line-too-long errors in longcat_image models Signed-off-by: lcukyfuture --- .../longcat_image_transformer.py | 26 +++++++++++++++---- .../longcat_image/pipeline_longcat_image.py | 4 ++- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py index f4c5a180be2..fc977c6d8a1 100644 --- a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py +++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py @@ -40,14 +40,26 @@ class FeedForward(nn.Module): - def __init__(self, dim: int, dim_out: int | None = None, mult: int = 4, bias: bool = True, quant_config: "QuantizationConfig | None" = None, prefix: str = ""): + def __init__( + self, + dim: int, + dim_out: int | None = None, + mult: int = 4, + bias: bool = True, + quant_config: "QuantizationConfig | None" = None, + prefix: str = "", + ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim - self.w_in = ColumnParallelLinear(dim, inner_dim, bias=bias, return_bias=False, quant_config=quant_config, prefix=f"{prefix}.w_in") + self.w_in = ColumnParallelLinear( + dim, inner_dim, bias=bias, return_bias=False, quant_config=quant_config, prefix=f"{prefix}.w_in" + ) self.act = get_act_fn("gelu_pytorch_tanh") - self.w_out = RowParallelLinear(inner_dim, dim_out, bias=bias, return_bias=False, quant_config=quant_config, prefix=f"{prefix}.w_out") + self.w_out = RowParallelLinear( + inner_dim, dim_out, bias=bias, return_bias=False, quant_config=quant_config, prefix=f"{prefix}.w_out" + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.w_in(hidden_states) @@ -101,7 +113,9 @@ def __init__( ) if not self.pre_only: - self.to_out = RowParallelLinear(self.inner_dim, self.out_dim, bias=out_bias, quant_config=quant_config, prefix="to_out") + self.to_out = RowParallelLinear( + self.inner_dim, self.out_dim, bias=out_bias, quant_config=quant_config, prefix="to_out" + ) if self.added_kv_proj_dim is not None: self.norm_added_q = RMSNorm(dim_head, eps=eps) @@ -116,7 +130,9 @@ def __init__( prefix="add_kv_proj", ) - self.to_add_out = RowParallelLinear(self.inner_dim, query_dim, bias=out_bias, quant_config=quant_config, prefix="to_add_out") + self.to_add_out = RowParallelLinear( + self.inner_dim, query_dim, bias=out_bias, quant_config=quant_config, prefix="to_add_out" + ) self.attn = Attention( num_heads=heads, diff --git a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py index 6605a8bcb54..6bcfa89bd1d 100644 --- a/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py +++ b/vllm_omni/diffusion/models/longcat_image/pipeline_longcat_image.py @@ -236,7 +236,9 @@ def __init__( self.vae = AutoencoderKL.from_pretrained(model, subfolder="vae", local_files_only=local_files_only).to( self.device ) - self.transformer = LongCatImageTransformer2DModel(od_config=od_config, quant_config=od_config.quantization_config) + self.transformer = LongCatImageTransformer2DModel( + od_config=od_config, quant_config=od_config.quantization_config + ) self.tokenizer = AutoTokenizer.from_pretrained(model, subfolder="tokenizer", local_files_only=local_files_only) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 From 90c7ebe11d63a7022d4d138e3fac7839c56c63e7 Mon Sep 17 00:00:00 2001 From: lcukyfuture Date: Thu, 9 Apr 2026 16:51:44 +0800 Subject: [PATCH 3/6] [Feat] FP8 quantization: add benchmarks and quality tests for longcat-image Signed-off-by: lcukyfuture --- benchmarks/diffusion/quantization_quality.py | 295 ++++++++++++------ .../quantization/test_quantization_quality.py | 72 ++++- 2 files changed, 265 insertions(+), 102 deletions(-) diff --git a/benchmarks/diffusion/quantization_quality.py b/benchmarks/diffusion/quantization_quality.py index 4a916e7ea62..2025caf8a1c 100644 --- a/benchmarks/diffusion/quantization_quality.py +++ b/benchmarks/diffusion/quantization_quality.py @@ -34,15 +34,6 @@ --height 720 --width 1280 \ --num-frames 81 --num-inference-steps 40 --seed 42 -Multiple quantization methods: - python benchmarks/diffusion/quantization_quality.py \ - --model Tongyi-MAI/Z-Image-Turbo \ - --task t2i \ - --quantization fp8 int8 bitsandbytes \ - --prompts "a cup of coffee on the table" \ - --height 1024 --width 1024 \ - --num-inference-steps 50 --seed 42 - Output directory structure (--output-dir, default: ./quant_bench_output): quant_bench_output/ baseline/ # BF16 outputs @@ -52,6 +43,9 @@ import argparse import gc +import hashlib +import json +import re import time from pathlib import Path @@ -145,13 +139,83 @@ def _build_omni_kwargs(args, quantization=None): return kwargs +def _sanitize_label(text: str) -> str: + """Convert a display label into a filesystem-safe slug.""" + slug = re.sub(r"[^A-Za-z0-9._-]+", "_", text).strip("_") + return slug or "quantized" + + +def _short_ignored_layer_name(layer_name: str) -> str: + """Build a readable short name for ignored_layers labels.""" + short = layer_name + replacements = { + "single_transformer_blocks": "single_blocks", + "transformer_blocks": "blocks", + "context_embedder": "context", + "x_embedder": "x", + } + for src, dst in replacements.items(): + short = short.replace(src, dst) + short = short.replace(".", "_") + return _sanitize_label(short) + + +def _build_quantization_label(spec) -> str: + """Build a short human-readable label for a quantization config.""" + if isinstance(spec, str): + return spec + + if isinstance(spec, dict): + method = str(spec.get("method", "quant")) + ignored_layers = spec.get("ignored_layers") or spec.get("modules_to_not_convert") or [] + if ignored_layers: + ignored_suffix = "_".join(_short_ignored_layer_name(layer) for layer in ignored_layers) + label = f"{method}_skip_{ignored_suffix}" + else: + label = method + + extra_keys = sorted(k for k in spec.keys() if k not in {"method", "ignored_layers", "modules_to_not_convert"}) + if extra_keys: + canonical = json.dumps(spec, sort_keys=True, separators=(",", ":")) + label = f"{label}_{hashlib.sha1(canonical.encode('utf-8')).hexdigest()[:8]}" + + return label + + raise TypeError(f"Unsupported quantization spec type: {type(spec)!r}") + + +def _parse_quantization_spec(raw_spec: str): + """Parse a CLI quantization spec into (display_label, slug, config).""" + raw_spec = raw_spec.strip() + if raw_spec.startswith("{"): + spec = json.loads(raw_spec) + else: + spec = raw_spec + + label = _build_quantization_label(spec) + slug = _sanitize_label(label) + return label, slug, spec + + +def _get_gpu_memory_gib(device_index: int = 0) -> float: + """Return current GPU memory used (GiB) across all processes on the device.""" + try: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) + info = pynvml.nvmlDeviceGetMemoryInfo(handle) + return info.used / (1024**3) + except Exception: + return torch.cuda.memory_allocated(device_index) / (1024**3) + + def _generate_image(omni, args, prompt, seed): """Generate a single image and return (PIL.Image, time_seconds, memory_gib).""" from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.platforms import current_omni_platform generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed) - torch.cuda.reset_peak_memory_stats() start = time.perf_counter() outputs = omni.generate( {"prompt": prompt}, @@ -163,14 +227,51 @@ def _generate_image(omni, args, prompt, seed): ), ) elapsed = time.perf_counter() - start - peak_mem = torch.cuda.max_memory_allocated() / (1024**3) + peak_mem = _get_gpu_memory_gib() first = outputs[0] - req_out = first.request_output[0] if hasattr(first, "request_output") else first + req_out = first.request_output if hasattr(first, "request_output") else first + if isinstance(req_out, (list, tuple)): + req_out = req_out[0] img = req_out.images[0] return img, elapsed, peak_mem +def _generate_image_edit(omni, args, prompt, seed): + """Generate an edited image and return (PIL.Image, time_seconds, memory_gib).""" + import PIL.Image + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + from vllm_omni.platforms import current_omni_platform + + if args.image: + image = PIL.Image.open(args.image).convert("RGB") + else: + from vllm.assets.image import ImageAsset + image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB") + generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed) + start = time.perf_counter() + outputs = omni.generate( + { + "prompt": prompt, + "multi_modal_data": {"image": image}, + }, + OmniDiffusionSamplingParams( + height=args.height, + width=args.width, + generator=generator, + num_inference_steps=args.num_inference_steps, + ), + ) + elapsed = time.perf_counter() - start + peak_mem = _get_gpu_memory_gib() + + first = outputs[0] + req_out = first.request_output if hasattr(first, "request_output") else first + if isinstance(req_out, (list, tuple)): + req_out = req_out[0] + return req_out.images[0], elapsed, peak_mem + + def _generate_video(omni, args, prompt, seed): """Generate a video and return (np.ndarray [F,H,W,C], time_seconds, memory_gib).""" from vllm_omni.inputs.data import OmniDiffusionSamplingParams @@ -178,7 +279,6 @@ def _generate_video(omni, args, prompt, seed): from vllm_omni.platforms import current_omni_platform generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed) - torch.cuda.reset_peak_memory_stats() start = time.perf_counter() outputs = omni.generate( {"prompt": prompt, "negative_prompt": ""}, @@ -192,7 +292,7 @@ def _generate_video(omni, args, prompt, seed): ), ) elapsed = time.perf_counter() - start - peak_mem = torch.cuda.max_memory_allocated() / (1024**3) + peak_mem = _get_gpu_memory_gib() first = outputs[0] if hasattr(first, "request_output") and isinstance(first.request_output, list): @@ -225,6 +325,8 @@ def _generate_video(omni, args, prompt, seed): def _unload_omni(omni): """Delete Omni instance and free GPU memory.""" + if hasattr(omni, "close"): + omni.close() del omni gc.collect() if torch.cuda.is_available(): @@ -239,13 +341,13 @@ def run_benchmark(args): output_dir.mkdir(parents=True, exist_ok=True) is_video = args.task == "t2v" + is_edit = args.task == "image_edit" + prompts = args.prompts seed = args.seed - # Determine configs to benchmark - configs = [] # list of (label, quantization_method) - for method in args.quantization: - configs.append((method, method)) + # Determine config to benchmark + config_label, config_slug, quant_spec = _parse_quantization_spec(args.quantization) # --- Baseline run --- print("\n" + "=" * 60) @@ -259,6 +361,8 @@ def run_benchmark(args): print(f" Generating: {prompt[:60]}...") if is_video: out, t, mem = _generate_video(omni_bl, args, prompt, seed) + elif is_edit: + out, t, mem = _generate_image_edit(omni_bl, args, prompt, seed) else: out, t, mem = _generate_image(omni_bl, args, prompt, seed) baseline_outputs[prompt] = (out, t, mem) @@ -283,68 +387,65 @@ def run_benchmark(args): else: out.save(bl_dir / f"prompt_{i}.png") - # --- Quantized runs --- - all_results = [] # list of dicts + # --- Quantized run --- + print(f"\n{'=' * 60}") + print(f"Running: {config_label}...") + print("=" * 60) - for config_label, quant_method in configs: - print(f"\n{'=' * 60}") - print(f"Running: {config_label}...") - print("=" * 60) + qt_kwargs = _build_omni_kwargs(args, quantization=quant_spec) + omni_qt = Omni(**qt_kwargs) - qt_kwargs = _build_omni_kwargs(args, quantization=quant_method) - omni_qt = Omni(**qt_kwargs) + qt_outputs = {} + for prompt in prompts: + print(f" Generating: {prompt[:60]}...") + if is_video: + out, t, mem = _generate_video(omni_qt, args, prompt, seed) + elif is_edit: + out, t, mem = _generate_image_edit(omni_qt, args, prompt, seed) + else: + out, t, mem = _generate_image(omni_qt, args, prompt, seed) + qt_outputs[prompt] = (out, t, mem) - qt_outputs = {} - for prompt in prompts: - print(f" Generating: {prompt[:60]}...") - if is_video: - out, t, mem = _generate_video(omni_qt, args, prompt, seed) - else: - out, t, mem = _generate_image(omni_qt, args, prompt, seed) - qt_outputs[prompt] = (out, t, mem) + qt_avg_time = np.mean([v[1] for v in qt_outputs.values()]) + qt_mem = qt_outputs[prompts[0]][2] + _unload_omni(omni_qt) - qt_avg_time = np.mean([v[1] for v in qt_outputs.values()]) - qt_mem = qt_outputs[prompts[0]][2] - _unload_omni(omni_qt) + # Save quantized outputs + qt_dir = output_dir / config_slug + qt_dir.mkdir(parents=True, exist_ok=True) - # Save quantized outputs - qt_dir = output_dir / config_label.replace(" ", "_") - qt_dir.mkdir(parents=True, exist_ok=True) + # Compute LPIPS per prompt + per_prompt = [] + for i, prompt in enumerate(prompts): + bl_out = baseline_outputs[prompt][0] + qt_out = qt_outputs[prompt][0] + if is_video: + lpips_score = compute_lpips_video(bl_out, qt_out, net=args.lpips_net) + try: + from diffusers.utils import export_to_video - # Compute LPIPS per prompt - per_prompt = [] - for i, prompt in enumerate(prompts): - bl_out = baseline_outputs[prompt][0] - qt_out = qt_outputs[prompt][0] - if is_video: - lpips_score = compute_lpips_video(bl_out, qt_out, net=args.lpips_net) - try: - from diffusers.utils import export_to_video - - frames_list = list(qt_out) if isinstance(qt_out, np.ndarray) and qt_out.ndim == 4 else qt_out - export_to_video(frames_list, str(qt_dir / f"prompt_{i}.mp4"), fps=args.fps) - except ImportError: - np.save(qt_dir / f"prompt_{i}.npy", qt_out) - else: - lpips_score = compute_lpips_images([bl_out], [qt_out], net=args.lpips_net)[0] - qt_out.save(qt_dir / f"prompt_{i}.png") - per_prompt.append({"prompt": prompt, "lpips": lpips_score}) - - mean_lpips = np.mean([p["lpips"] for p in per_prompt]) - speedup = bl_avg_time / qt_avg_time if qt_avg_time > 0 else float("inf") - mem_reduction = (bl_mem - qt_mem) / bl_mem * 100 - - all_results.append( - { - "config": config_label, - "avg_time": qt_avg_time, - "speedup": speedup, - "memory_gib": qt_mem, - "mem_reduction_pct": mem_reduction, - "mean_lpips": mean_lpips, - "per_prompt": per_prompt, - } - ) + frames_list = list(qt_out) if isinstance(qt_out, np.ndarray) and qt_out.ndim == 4 else qt_out + export_to_video(frames_list, str(qt_dir / f"prompt_{i}.mp4"), fps=args.fps) + except ImportError: + np.save(qt_dir / f"prompt_{i}.npy", qt_out) + else: + lpips_score = compute_lpips_images([bl_out], [qt_out], net=args.lpips_net)[0] + qt_out.save(qt_dir / f"prompt_{i}.png") + per_prompt.append({"prompt": prompt, "lpips": lpips_score}) + + mean_lpips = np.mean([p["lpips"] for p in per_prompt]) + speedup = bl_avg_time / qt_avg_time if qt_avg_time > 0 else float("inf") + mem_reduction = (bl_mem - qt_mem) / bl_mem * 100 + + result = { + "config": config_label, + "avg_time": qt_avg_time, + "speedup": speedup, + "memory_gib": qt_mem, + "mem_reduction_pct": mem_reduction, + "mean_lpips": mean_lpips, + "per_prompt": per_prompt, + } # --- Print results --- print("\n\n") @@ -367,12 +468,11 @@ def run_benchmark(args): lines.append("| Config | Avg Time | Speedup | Memory (GiB) | Mem Reduction | Mean LPIPS |") lines.append("|--------|----------|---------|--------------|---------------|------------|") lines.append(f"| BF16 baseline | {bl_avg_time:.2f}s | 1.00x | {bl_mem:.2f} | — | (ref) |") - for r in all_results: - lines.append( - f"| {r['config']} | {r['avg_time']:.2f}s | {r['speedup']:.2f}x " - f"| {r['memory_gib']:.2f} | {r['mem_reduction_pct']:.0f}% " - f"| {r['mean_lpips']:.4f} |" - ) + lines.append( + f"| {result['config']} | {result['avg_time']:.2f}s | {result['speedup']:.2f}x " + f"| {result['memory_gib']:.2f} | {result['mem_reduction_pct']:.0f}% " + f"| {result['mean_lpips']:.4f} |" + ) lines.append("") lines.append("> LPIPS < 0.01 = imperceptible, > 0.1 = clearly noticeable.") lines.append("") @@ -381,19 +481,11 @@ def run_benchmark(args): if len(prompts) > 1: lines.append("### Per-Prompt LPIPS") lines.append("") - header = "| Prompt |" - sep = "|--------|" - for r in all_results: - header += f" {r['config']} |" - sep += "--------|" - lines.append(header) - lines.append(sep) + lines.append(f"| Prompt | {result['config']} |") + lines.append("|--------|--------|") for i, prompt in enumerate(prompts): short = prompt[:50] + "..." if len(prompt) > 50 else prompt - row = f"| {short} |" - for r in all_results: - row += f" {r['per_prompt'][i]['lpips']:.4f} |" - lines.append(row) + lines.append(f"| {short} | {result['per_prompt'][i]['lpips']:.4f} |") lines.append("") md = "\n".join(lines) @@ -404,9 +496,7 @@ def run_benchmark(args): results_path.write_text(md, encoding="utf-8") print(f"\nResults saved to {results_path}") print(f"Baseline outputs in {bl_dir}") - for r in all_results: - qt_dir = output_dir / r["config"].replace(" ", "_") - print(f"Quantized outputs in {qt_dir}") + print(f"Quantized outputs in {qt_dir}") def parse_args(): @@ -418,14 +508,17 @@ def parse_args(): parser.add_argument( "--task", default="t2i", - choices=["t2i", "t2v"], - help="Task type: t2i (text-to-image) or t2v (text-to-video).", + choices=["t2i", "t2v", "image_edit"], + help="Task type: t2i (text-to-image), t2v (text-to-video), or image_edit (image editing).", ) parser.add_argument( "--quantization", - nargs="+", required=True, - help="One or more quantization methods to benchmark (e.g. fp8 int8 bitsandbytes).", + help=( + "Quantization spec to benchmark. " + "Can be a method name (e.g. fp8) or a JSON object string " + '(e.g. \'{"method":"fp8","ignored_layers":["proj_out"]}\').' + ), ) parser.add_argument( "--prompts", @@ -433,6 +526,12 @@ def parse_args(): default=["a cup of coffee on the table"], help="One or more prompts to generate.", ) + parser.add_argument( + "--image", + type=str, + default=None, + help="Path to input image for image_edit task. Defaults to a standard vllm test image.", + ) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--height", type=int, default=1024) parser.add_argument("--width", type=int, default=1024) diff --git a/tests/diffusion/quantization/test_quantization_quality.py b/tests/diffusion/quantization/test_quantization_quality.py index a937a648616..aa1cf1c3e42 100644 --- a/tests/diffusion/quantization/test_quantization_quality.py +++ b/tests/diffusion/quantization/test_quantization_quality.py @@ -46,7 +46,7 @@ class QualityTestConfig: id: str # pytest ID, e.g. "fp8_z_image" model: str # HF model name quantization: str # quantization method, e.g. "fp8" - task: str # "t2i" or "t2v" + task: str # "t2i", "t2v", or "image_edit" prompt: str # generation prompt max_lpips: float # fail threshold — higher = more lenient height: int = 1024 @@ -55,6 +55,7 @@ class QualityTestConfig: num_frames: int = 5 # only for t2v seed: int = 42 gpu: str = "H100" # minimum GPU requirement + image: str | None = None # vllm ImageAsset name for image_edit task # Add new quantization methods / models here. @@ -88,6 +89,27 @@ class QualityTestConfig: seed=142, num_inference_steps=20, ), + QualityTestConfig( + id="fp8_longcat_image", + model="meituan-longcat/LongCat-Image", + quantization="fp8", + task="t2i", + prompt="a cup of coffee on a wooden table, morning light", + max_lpips=0.15, + seed=42, + num_inference_steps=20, + ), + QualityTestConfig( + id="fp8_longcat_image_edit", + model="meituan-longcat/LongCat-Image-Edit", + quantization="fp8", + task="image_edit", + prompt="Transform this modern image into a cinematic animation style with vibrant colors.", + image="2560px-Gfp-wisconsin-madison-the-nature-boardwalk", + max_lpips=0.20, + seed=42, + num_inference_steps=20, + ), ] @@ -118,7 +140,42 @@ def _generate_image(omni, config: QualityTestConfig): peak_mem = torch.cuda.max_memory_allocated() / (1024**3) first = outputs[0] - req_out = first.request_output[0] if hasattr(first, "request_output") else first + req_out = first.request_output if hasattr(first, "request_output") else first + if isinstance(req_out, (list, tuple)): + req_out = req_out[0] + return req_out.images[0], peak_mem + + +def _generate_image_edit(omni, config: QualityTestConfig): + """Generate an edited image, return (PIL.Image, peak_mem_gib).""" + from vllm.assets.image import ImageAsset + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + from vllm_omni.platforms import current_omni_platform + + image = ImageAsset(config.image).pil_image.convert("RGB") + generator = torch.Generator( + device=current_omni_platform.device_type, + ).manual_seed(config.seed) + torch.cuda.reset_peak_memory_stats() + + outputs = omni.generate( + { + "prompt": config.prompt, + "multi_modal_data": {"image": image}, + }, + OmniDiffusionSamplingParams( + height=config.height, + width=config.width, + generator=generator, + num_inference_steps=config.num_inference_steps, + ), + ) + + peak_mem = torch.cuda.max_memory_allocated() / (1024**3) + first = outputs[0] + req_out = first.request_output if hasattr(first, "request_output") else first + if isinstance(req_out, (list, tuple)): + req_out = req_out[0] return req_out.images[0], peak_mem @@ -177,12 +234,14 @@ def _compute_lpips(baseline, quantized, task: str) -> float: compute_lpips_video, ) - if task == "t2i": + if task in ("t2i", "image_edit"): return compute_lpips_images([baseline], [quantized])[0] return compute_lpips_video(baseline, quantized) def _unload(omni): + if hasattr(omni, "close"): + omni.close() del omni gc.collect() if torch.cuda.is_available(): @@ -207,7 +266,12 @@ def test_quantization_quality(config: QualityTestConfig): """Validate that quantized output stays within LPIPS threshold of BF16.""" from vllm_omni.entrypoints.omni import Omni - generate_fn = _generate_video if config.task == "t2v" else _generate_image + if config.task == "t2v": + generate_fn = _generate_video + elif config.task == "image_edit": + generate_fn = _generate_image_edit + else: + generate_fn = _generate_image # --- BF16 baseline --- omni_bl = Omni(model=config.model) From 991a11e6f3af9d9c21924c3978ca545d88fc6500 Mon Sep 17 00:00:00 2001 From: lcukyfuture Date: Thu, 9 Apr 2026 18:03:45 +0800 Subject: [PATCH 4/6] [Style] Fix ruff import sort and format in benchmarks and test Signed-off-by: lcukyfuture --- benchmarks/diffusion/quantization_quality.py | 2 ++ .../quantization/test_quantization_quality.py | 36 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/benchmarks/diffusion/quantization_quality.py b/benchmarks/diffusion/quantization_quality.py index 2025caf8a1c..48d5fc60b4f 100644 --- a/benchmarks/diffusion/quantization_quality.py +++ b/benchmarks/diffusion/quantization_quality.py @@ -240,6 +240,7 @@ def _generate_image(omni, args, prompt, seed): def _generate_image_edit(omni, args, prompt, seed): """Generate an edited image and return (PIL.Image, time_seconds, memory_gib).""" import PIL.Image + from vllm_omni.inputs.data import OmniDiffusionSamplingParams from vllm_omni.platforms import current_omni_platform @@ -247,6 +248,7 @@ def _generate_image_edit(omni, args, prompt, seed): image = PIL.Image.open(args.image).convert("RGB") else: from vllm.assets.image import ImageAsset + image = ImageAsset("2560px-Gfp-wisconsin-madison-the-nature-boardwalk").pil_image.convert("RGB") generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(seed) start = time.perf_counter() diff --git a/tests/diffusion/quantization/test_quantization_quality.py b/tests/diffusion/quantization/test_quantization_quality.py index 7cef1d7240b..f2ada1b2bf8 100644 --- a/tests/diffusion/quantization/test_quantization_quality.py +++ b/tests/diffusion/quantization/test_quantization_quality.py @@ -148,6 +148,42 @@ def _generate_image(omni, config: QualityTestConfig): raise ValueError("Could not extract image from output.") +def _generate_image_edit(omni, config: QualityTestConfig): + """Generate an edited image, return (PIL.Image, peak_mem_gib).""" + from vllm.assets.image import ImageAsset + + from vllm_omni.inputs.data import OmniDiffusionSamplingParams + from vllm_omni.platforms import current_omni_platform + + image = ImageAsset(config.image).pil_image.convert("RGB") + generator = torch.Generator( + device=current_omni_platform.device_type, + ).manual_seed(config.seed) + torch.cuda.reset_peak_memory_stats() + + outputs = omni.generate( + { + "prompt": config.prompt, + "multi_modal_data": {"image": image}, + }, + OmniDiffusionSamplingParams( + height=config.height, + width=config.width, + generator=generator, + num_inference_steps=config.num_inference_steps, + ), + ) + + peak_mem = torch.cuda.max_memory_allocated() / (1024**3) + first = outputs[0] + if hasattr(first, "images") and first.images: + return first.images[0], peak_mem + inner = first.request_output + if inner is not None and hasattr(inner, "images") and inner.images: + return inner.images[0], peak_mem + raise ValueError("Could not extract image from output.") + + def _generate_video(omni, config: QualityTestConfig): """Generate a video, return (np.ndarray [F,H,W,C], peak_mem_gib).""" from vllm_omni.inputs.data import OmniDiffusionSamplingParams From 085067024b7499d6c665839fc42b049c9f046bfa Mon Sep 17 00:00:00 2001 From: lcukyfuture Date: Mon, 27 Apr 2026 10:52:05 +0800 Subject: [PATCH 5/6] Address LongCat FP8 quantization review Signed-off-by: lcukyfuture --- .../longcat_image_transformer.py | 42 ++++++++++++------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py index a227b23bcff..e1af10b9330 100644 --- a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py +++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py @@ -85,6 +85,7 @@ def __init__( context_pre_only: bool | None = None, pre_only: bool = False, quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ): super().__init__() self.parallel_config = parallel_config @@ -109,12 +110,16 @@ def __init__( total_num_heads=self.heads, bias=bias, quant_config=quant_config, - prefix="to_qkv", + prefix=f"{prefix}.to_qkv", ) if not self.pre_only: self.to_out = RowParallelLinear( - self.inner_dim, self.out_dim, bias=out_bias, quant_config=quant_config, prefix="to_out" + self.inner_dim, + self.out_dim, + bias=out_bias, + quant_config=quant_config, + prefix=f"{prefix}.to_out", ) if self.added_kv_proj_dim is not None: @@ -127,11 +132,15 @@ def __init__( total_num_heads=self.heads, bias=added_proj_bias, quant_config=quant_config, - prefix="add_kv_proj", + prefix=f"{prefix}.add_kv_proj", ) self.to_add_out = RowParallelLinear( - self.inner_dim, query_dim, bias=out_bias, quant_config=quant_config, prefix="to_add_out" + self.inner_dim, + query_dim, + bias=out_bias, + quant_config=quant_config, + prefix=f"{prefix}.to_add_out", ) self.attn = Attention( @@ -349,6 +358,7 @@ def __init__( qk_norm: str = "rms_norm", eps: float = 1e-6, quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ): super().__init__() @@ -367,13 +377,16 @@ def __init__( bias=True, eps=eps, quant_config=quant_config, + prefix=f"{prefix}.attn", ) self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="ff") + self.ff = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix=f"{prefix}.ff") self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix="ff_context") + self.ff_context = FeedForward( + dim=dim, dim_out=dim, quant_config=quant_config, prefix=f"{prefix}.ff_context" + ) def forward( self, @@ -546,6 +559,7 @@ def __init__( attention_head_dim: int, mlp_ratio: float = 4.0, quant_config: "QuantizationConfig | None" = None, + prefix: str = "", ): super().__init__() self.mlp_hidden_dim = int(dim * mlp_ratio) @@ -557,7 +571,7 @@ def __init__( bias=True, return_bias=False, quant_config=quant_config, - prefix="proj_mlp", + prefix=f"{prefix}.proj_mlp", ) self.act_mlp = nn.GELU(approximate="tanh") self.proj_out = ReplicatedLinear( @@ -566,7 +580,7 @@ def __init__( bias=True, return_bias=False, quant_config=quant_config, - prefix="proj_out", + prefix=f"{prefix}.proj_out", ) # SP handling is delegated to LongCatImageAttention via text_seq_len kwarg @@ -580,6 +594,7 @@ def __init__( eps=1e-6, quant_config=quant_config, pre_only=True, + prefix=f"{prefix}.attn", ) def forward( @@ -707,6 +722,7 @@ def __init__( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, quant_config=quant_config, + prefix=f"transformer_blocks.{i}", ) for i in range(num_layers) ] @@ -720,20 +736,14 @@ def __init__( num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, quant_config=quant_config, + prefix=f"single_transformer_blocks.{i}", ) for i in range(num_single_layers) ] ) self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = ReplicatedLinear( - self.inner_dim, - patch_size * patch_size * self.out_channels, - bias=True, - return_bias=False, - quant_config=quant_config, - prefix="proj_out", - ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False From 325ca3b054b2a7df04cd91fec7618458b460aa35 Mon Sep 17 00:00:00 2001 From: lcukyfuture Date: Mon, 27 Apr 2026 10:57:27 +0800 Subject: [PATCH 6/6] Format LongCat FP8 review fix Signed-off-by: lcukyfuture --- .../models/longcat_image/longcat_image_transformer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py index e1af10b9330..8f16e986e60 100644 --- a/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py +++ b/vllm_omni/diffusion/models/longcat_image/longcat_image_transformer.py @@ -384,9 +384,7 @@ def __init__( self.ff = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix=f"{prefix}.ff") self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward( - dim=dim, dim_out=dim, quant_config=quant_config, prefix=f"{prefix}.ff_context" - ) + self.ff_context = FeedForward(dim=dim, dim_out=dim, quant_config=quant_config, prefix=f"{prefix}.ff_context") def forward( self,