diff --git a/docs/.nav.yml b/docs/.nav.yml index 0228c713993..55283f0e8b1 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -76,6 +76,7 @@ nav: - Quantization: - Overview: user_guide/quantization/overview.md - Online Quantization: user_guide/quantization/online.md + - Quantized KV Cache: user_guide/quantization/quantized_kvcache.md - FP8 W8A8: user_guide/quantization/fp8.md - Int8 W8A8: user_guide/quantization/int8.md - ModelOpt: user_guide/quantization/modelopt.md diff --git a/docs/user_guide/quantization/overview.md b/docs/user_guide/quantization/overview.md index ed7a1b8cfb4..d4e3ce377e9 100644 --- a/docs/user_guide/quantization/overview.md +++ b/docs/user_guide/quantization/overview.md @@ -1,4 +1,4 @@ -# Quantization +# Quantization vLLM-Omni exposes quantization through the unified `quantization_config` path. The same configuration entrypoint is used across diffusion-only models, @@ -10,6 +10,7 @@ type has a different quantization scope. | Mode | Guide | Description | Methods | |------|-------|-------------|---------| | Online quantization | [Online Quantization](online.md) | vLLM-Omni computes quantized weights and scales while loading the model. | FP8 W8A8, Int8 W8A8 | +| Runtime attention quantization | [Quantized KV Cache](quantized_kvcache.md) | vLLM-Omni dynamically quantizes eligible diffusion Flash Attention tensors during inference. | FP8 FA | | Pre-quantized checkpoints | Method-specific guides | The checkpoint or an offline quantizer provides quantized weights and scales before serving. 
| ModelOpt, GGUF, AutoRound, msModelSlim, serialized Int8 | ## Hardware Support diff --git a/docs/user_guide/quantization/quantized_kvcache.md b/docs/user_guide/quantization/quantized_kvcache.md new file mode 100644 index 00000000000..62815eb4b38 --- /dev/null +++ b/docs/user_guide/quantization/quantized_kvcache.md @@ -0,0 +1,115 @@ +# Quantized KV Cache + +## Overview + +In DiT-based image and video generation, Flash Attention can take a large share +of denoising time, especially for high-resolution or long-frame workloads. +vLLM-Omni supports online FP8 quantization for eligible diffusion Flash +Attention (FA) to reduce FA latency while keeping model weights in their +original dtype. + +This feature is configured through `kv_cache_dtype`, matching the option name +used by vLLM's language-model KV-cache quantization. In vLLM-Omni diffusion +pipelines, however, it is a runtime FA path: Q/K/V tensors are dynamically +quantized before the attention operator. It does not quantize model weights and +is separate from [FP8 W8A8](fp8.md), [Int8 W8A8](int8.md), or pre-quantized +checkpoint formats. + +If `kv_cache_dtype` is not set, behavior is unchanged and attention runs in the +native dtype. + +## Hardware Support + +| Device | FP8 FA | +|--------|--------| +| Ascend NPU | ✅ | +| NVIDIA GPU | ❌ | +| AMD ROCm | ❌ | +| Intel XPU | ❌ | + +Legend: `✅` supported, `❌` unsupported. + +FP8 FA is currently implemented only for the NPU Flash Attention backend. Other +backends do not support `kv_cache_dtype="fp8"` for diffusion attention and fall +back to native dtype execution. 
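Reviewer note on the fallback described above: a minimal, self-contained sketch of how a per-platform support table can decide whether `kv_cache_dtype="fp8"` is honored or dropped. The class names `AttentionImplSketch` and `FlashAttentionSketch` and the helper `resolve_kv_cache_dtype` are hypothetical stand-ins for illustration; only the `supports_kv_cache_dtype` lookup mirrors the classmethod added in this change, and the real layer additionally logs a warning before disabling quantization.

```python
from __future__ import annotations


class AttentionImplSketch:
    """Hypothetical stand-in for the attention implementation base class."""

    # Maps a platform key (e.g. "npu", "cuda") to the quantized dtypes it handles.
    _supported_kv_cache_dtypes: dict[str, set[str]] = {}

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: str | None, platform_key: str) -> bool:
        # None means no quantization was requested, which every backend accepts.
        if kv_cache_dtype is None:
            return True
        return kv_cache_dtype in cls._supported_kv_cache_dtypes.get(platform_key, set())


class FlashAttentionSketch(AttentionImplSketch):
    # Assumption taken from the hardware table: only the NPU path handles FP8 FA.
    _supported_kv_cache_dtypes = {"npu": {"fp8"}}


def resolve_kv_cache_dtype(
    impl: type[AttentionImplSketch],
    requested: str | None,
    platform_key: str,
) -> str | None:
    """Return the dtype to apply, or None to run attention in the native dtype."""
    if requested and not impl.supports_kv_cache_dtype(requested, platform_key):
        return None
    return requested
```

With this table, a request for `"fp8"` survives only on the `"npu"` key; any other platform key, or any other dtype string, resolves to `None` and attention stays in the native dtype.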
+ +## Model Type Support + +### Diffusion Model + +| Model | Scope | Status | Notes | +|-------|-------|--------|-------| +| Wan2.2 | Eligible DiT full-attention FA on Ascend NPU | Tested | Compare quality and latency against a BF16 baseline before production use | +| Other diffusion models | Eligible DiT full-attention FA on Ascend NPU | Not tested | You can try `kv_cache_dtype="fp8"`; tune `kv_cache_skip_steps` and `kv_cache_skip_layers` when higher precision is needed | + +### Multi-Stage Omni/TTS Model (Qwen3-Omni, Qwen3-TTS) + +Not tested for FP8 FA. Treat any use as experimental unless a model-specific +guide documents support. + +### Multi-Stage Diffusion Model (BAGEL, GLM-Image) + +Not tested. If the diffusion stage uses the same NPU Flash Attention backend, +`kv_cache_dtype` may apply in theory; validate quality and latency for each +stage and model. + +## Configuration + +Offline diffusion example: + +```bash +python examples/offline_inference/image_to_video/image_to_video.py \ + --model \ + --prompt "A cat sitting on a surfboard at the beach" \ + --height 1280 \ + --width 720 \ + --num-frames 61 \ + --num-inference-steps 4 \ + --ulysses-degree 4 \ + --vae-patch-parallel-size 4 \ + --kv-cache-dtype fp8 \ + --kv-cache-skip-steps "0,1" \ + --kv-cache-skip-layers "0-2" +``` + +Online serving: + +```bash +vllm serve --omni --kv-cache-dtype fp8 +``` + +Stage config: + +```yaml +stage_args: + - stage_id: 0 + stage_type: diffusion + engine_args: + model_stage: dit + kv_cache_dtype: "fp8" + kv_cache_skip_steps: "0,1" + kv_cache_skip_layers: "0-2" +``` + +## Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `kv_cache_dtype` | str \| None | `None` | Set to `"fp8"` to enable dynamic FP8 FA on supported attention backends | +| `kv_cache_skip_steps` | str \| None | `None` | Denoising step selector to keep in native dtype, for example `"0,1,4-6"` | +| `kv_cache_skip_layers` | str \| None | `None` | Transformer 
layer selector to keep in native dtype, for example `"0-2,10"` | + +Selectors use comma-separated integers and inclusive ranges. Listed steps or +layers skip FP8 FA; all other eligible full-attention forwards use the FP8 path. + +## Validation and Notes + +1. Compare generated images or videos against a BF16 baseline with the same + seed, prompt, resolution, frame count, and denoising steps. +2. Use `kv_cache_skip_steps` for denoising steps where quality is more + sensitive. +3. Use `kv_cache_skip_layers` for transformer layers that show visible quality + regressions. +4. Report both latency and quality results when enabling this option for a new + model. For image or video models, include visual comparison and quantitative + metrics when available, such as PSNR or SSIM. diff --git a/examples/offline_inference/image_to_video/image_to_video.py b/examples/offline_inference/image_to_video/image_to_video.py index e46db4de456..386afea3149 100644 --- a/examples/offline_inference/image_to_video/image_to_video.py +++ b/examples/offline_inference/image_to_video/image_to_video.py @@ -102,6 +102,24 @@ def parse_args() -> argparse.Namespace: choices=["unipc", "euler"], help="Sampling solver for Wan2.2 pipelines. Use 'euler' for Lightning/Distill setups.", ) + parser.add_argument( + "--kv-cache-dtype", + type=str, + default=None, + help="Config-level KV cache dtype (e.g. fp8).", + ) + parser.add_argument( + "--kv-cache-skip-steps", + type=str, + default=None, + help="Config-level KV-cache quantization skip-step selector, e.g. '0-9,20,25-30'.", + ) + parser.add_argument( + "--kv-cache-skip-layers", + type=str, + default=None, + help="Config-level KV-cache quantization skip-layer selector, e.g. 
'0,1,4-8'.", + ) parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).") parser.add_argument("--fps", type=int, default=None, help="Frames per second for the output video.") parser.add_argument( @@ -309,6 +327,9 @@ def main(): vae_use_tiling=args.vae_use_tiling, boundary_ratio=args.boundary_ratio, flow_shift=args.flow_shift, + kv_cache_dtype=args.kv_cache_dtype, + kv_cache_skip_steps=args.kv_cache_skip_steps, + kv_cache_skip_layers=args.kv_cache_skip_layers, enable_cpu_offload=args.enable_cpu_offload, parallel_config=parallel_config, enforce_eager=args.enforce_eager, @@ -330,6 +351,9 @@ def main(): print(f" Inference steps: {args.num_inference_steps}") print(f" Frames: {args.num_frames}") print(f" Solver: {args.sample_solver}") + print(f" kv_cache_dtype(config): {args.kv_cache_dtype}") + print(f" kv_cache_skip_steps(config): {args.kv_cache_skip_steps}") + print(f" kv_cache_skip_layers(config): {args.kv_cache_skip_layers}") print( f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size}," f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}" diff --git a/tests/platforms/npu/quant/test_kv_quant_npu.py b/tests/platforms/npu/quant/test_kv_quant_npu.py new file mode 100644 index 00000000000..0c828a6a3bd --- /dev/null +++ b/tests/platforms/npu/quant/test_kv_quant_npu.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for NPU FP8 KV quantization helpers. + +These tests load ``kv_quant_npu`` from its source file via ``importlib`` so +the test module itself does not ``import vllm_omni`` (which would pull +``patch`` → ``aenum``, vLLM, etc.). 
+""" + +from __future__ import annotations + +import importlib.util +import math +from pathlib import Path +from types import ModuleType, SimpleNamespace +from typing import Any + +import pytest +import torch + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu] + + +def _repo_root() -> Path: + """Resolve checkout root (parent of ``vllm_omni/``), not ``tests/``.""" + here = Path(__file__).resolve() + marker = Path("vllm_omni") / "platforms" / "npu" / "quant" / "kv_quant_npu.py" + for parent in here.parents: + if (parent / marker).is_file(): + return parent + msg = f"could not locate repo root (no {marker}) starting from {here}" + raise FileNotFoundError(msg) + + +def _load_kv_quant_npu() -> ModuleType: + path = _repo_root() / "vllm_omni" / "platforms" / "npu" / "quant" / "kv_quant_npu.py" + if not path.is_file(): + msg = f"kv_quant_npu source not found: {path}" + raise FileNotFoundError(msg) + name = "vllm_omni_test_kv_quant_npu_standalone" + spec = importlib.util.spec_from_file_location(name, path) + if spec is None or spec.loader is None: + msg = f"cannot load import spec for {path}" + raise RuntimeError(msg) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +kv_quant_npu = _load_kv_quant_npu() + + +def _npu_smoke_available() -> bool: + try: + import torch_npu # noqa: F401 + except ImportError: + return False + return bool(hasattr(torch, "npu") and torch.npu.is_available()) + + +npu_smoke = pytest.mark.skipif(not _npu_smoke_available(), reason="NPU device or torch_npu not available.") + + +def test_is_quantized_kv_cache() -> None: + assert kv_quant_npu.is_quantized_kv_cache("fp8") + assert not kv_quant_npu.is_quantized_kv_cache(None) + assert not kv_quant_npu.is_quantized_kv_cache("int8") + + +class TestKVQuantNPUUnit: + @pytest.fixture(autouse=True) + def clear_rot_cache(self): + kv_quant_npu._ROT_MATRIXS.clear() + + def test_get_rot_matrix_caches_by_device_dtype_and_head_dim(self) -> None: + 
calls = {"count": 0} + + class FakeQuaRotMode: + HADAMARD = "hadamard" + + def fake_create_rot(mode, head_dim, seed): + calls["count"] += 1 + assert mode == FakeQuaRotMode.HADAMARD + assert seed == 425500 + return torch.eye(head_dim, dtype=torch.float32) + + device = torch.device("cpu") + rot_1 = kv_quant_npu._get_rot_matrix(device, torch.float16, 8, FakeQuaRotMode, fake_create_rot) + rot_2 = kv_quant_npu._get_rot_matrix(device, torch.float16, 8, FakeQuaRotMode, fake_create_rot) + rot_3 = kv_quant_npu._get_rot_matrix(device, torch.bfloat16, 8, FakeQuaRotMode, fake_create_rot) + rot_4 = kv_quant_npu._get_rot_matrix(device, torch.float16, 16, FakeQuaRotMode, fake_create_rot) + + assert calls["count"] == 3 + assert rot_1 is rot_2 + assert rot_3.dtype == torch.bfloat16 + assert rot_4.shape == (16, 16) + + @pytest.fixture + def fake_quant_ops(self, monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]: + captured: dict[str, Any] = { + "fa_calls": [], + "npu_kwargs": None, + "out_shape": None, + } + + class FakeTorchNPU: + float8_e4m3fn = "fp8_marker" + + @staticmethod + def npu_fused_infer_attention_score_v2(q, k, v, **kwargs): + del q, k, v + captured["npu_kwargs"] = kwargs + out_shape = captured["out_shape"] + return (torch.ones(out_shape, dtype=torch.float32),) + + def fake_fa_block_quant_preprocess(x, block_size, dst_type, layout): + captured["fa_calls"].append( + { + "block_size": block_size, + "layout": layout, + "dst_type": dst_type, + "shape": tuple(x.shape), + } + ) + scale = torch.full((1,), float(block_size), dtype=torch.float32) + return x, scale + + fake_qua_rot_mode = SimpleNamespace(HADAMARD="hadamard") + + def fake_create_rot(mode, head_dim, seed): + assert mode == "hadamard" + assert seed == 425500 + return torch.eye(head_dim, dtype=torch.float32) + + monkeypatch.setattr( + kv_quant_npu, + "_load_quant_ops", + lambda: (FakeTorchNPU, fake_fa_block_quant_preprocess, fake_qua_rot_mode, fake_create_rot), + ) + + return captured + + @staticmethod + def 
_make_qkv(shape: tuple[int, int, int, int]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + query = torch.randn(*shape, dtype=torch.float32) + key = torch.randn(*shape, dtype=torch.float32) + value = torch.randn(*shape, dtype=torch.float32) + return query, key, value + + @pytest.mark.parametrize( + "layout,input_shape,out_shape,softmax_scale,expected_scale", + [ + ("BNSD", (2, 3, 4, 8), (2, 3, 6, 8), None, 1.0 / math.sqrt(8)), + ("BSND", (2, 4, 3, 8), (2, 6, 3, 8), 0.125, 0.125), + ], + ) + def test_fp8_rotate_quant_fa_layouts_scale_and_crop( + self, + fake_quant_ops: dict[str, Any], + layout: str, + input_shape: tuple[int, int, int, int], + out_shape: tuple[int, int, int, int], + softmax_scale: float | None, + expected_scale: float, + ) -> None: + query, key, value = self._make_qkv(input_shape) + fake_quant_ops["out_shape"] = out_shape + + out = kv_quant_npu.fp8_rotate_quant_fa(query, key, value, layout=layout, softmax_scale=softmax_scale) + + assert out.shape == query.shape + assert out.dtype == query.dtype + assert fake_quant_ops["npu_kwargs"]["input_layout"] == layout + # BNSD: shape[1]==heads, BSND: shape[2]==heads. 
+ expected_heads = input_shape[1] if layout == "BNSD" else input_shape[2] + assert fake_quant_ops["npu_kwargs"]["num_query_heads"] == expected_heads + assert fake_quant_ops["npu_kwargs"]["softmax_scale"] == pytest.approx(expected_scale) + assert [call["block_size"] for call in fake_quant_ops["fa_calls"]] == [128, 256, 256] + + def test_fp8_rotate_quant_fa_invalid_layout_raises(self, fake_quant_ops) -> None: + query = torch.randn(1, 2, 3, 4, dtype=torch.float32) + key = torch.randn(1, 2, 3, 4, dtype=torch.float32) + value = torch.randn(1, 2, 3, 4, dtype=torch.float32) + fake_quant_ops["out_shape"] = (1, 2, 3, 4) + + with pytest.raises(ValueError, match="unsupported layout"): + kv_quant_npu.fp8_rotate_quant_fa(query, key, value, layout="INVALID") + + +@npu_smoke +class TestKVQuantNPUSmoke: + """Smoke tests using real torch_npu/mindiesd stack, only on NPU.""" + + def test_fp8_rotate_quant_fa_real_npu_shape_contract(self): + try: + kv_quant_npu._load_quant_ops.cache_clear() + kv_quant_npu._load_quant_ops() + except ImportError: + pytest.skip("NPU quant dependencies are not fully installed.") + + query = torch.randn(1, 2, 4, 64, dtype=torch.float16, device="npu") + key = torch.randn(1, 2, 4, 64, dtype=torch.float16, device="npu") + value = torch.randn(1, 2, 4, 64, dtype=torch.float16, device="npu") + + out = kv_quant_npu.fp8_rotate_quant_fa(query, key, value, layout="BNSD") + assert out.shape == query.shape + assert out.dtype == query.dtype diff --git a/vllm_omni/diffusion/attention/backends/abstract.py b/vllm_omni/diffusion/attention/backends/abstract.py index f702dc65028..cd0408f3ce8 100644 --- a/vllm_omni/diffusion/attention/backends/abstract.py +++ b/vllm_omni/diffusion/attention/backends/abstract.py @@ -67,6 +67,10 @@ class AttentionMetadata: extra: dict[str, Any] = field(default_factory=dict) # Opaque backend-specific per-forward parameters (e.g. block masks, KV indices). # Backends MUST silently ignore unknown keys. 
+ # + # Well-known optional keys (convention, not required on all forwards): + # "kv_cache_dtype": str | None — quantized KV dtype (e.g. "fp8"); backends + # decide whether/how to apply. # Piecewise attention metadata (mixed causal/full masks). # full_attn_spans: per-sample [start, end) spans in global coordinates using full attention. @@ -77,6 +81,14 @@ class AttentionMetadata: class AttentionImpl(ABC, Generic[T]): + # Per-platform kv_cache_dtype support. Maps OmniPlatformEnum value + # (e.g. "cuda", "npu") to the set of quantized dtypes that platform + # handles. + # + # To add FP8 support for a new platform in a subclass: + # _supported_kv_cache_dtypes = {"cuda": {"fp8"}, "npu": {"fp8"}} + _supported_kv_cache_dtypes: dict[str, set[str]] = {} + @abstractmethod def __init__( self, @@ -92,6 +104,12 @@ def __init__( ) -> None: raise NotImplementedError + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: str | None, platform_key: str) -> bool: + if kv_cache_dtype is None: + return True + return kv_cache_dtype in cls._supported_kv_cache_dtypes.get(platform_key, set()) + def forward( self, query: torch.Tensor, diff --git a/vllm_omni/diffusion/attention/backends/flash_attn.py b/vllm_omni/diffusion/attention/backends/flash_attn.py index a612546942a..cf8d4224fc8 100644 --- a/vllm_omni/diffusion/attention/backends/flash_attn.py +++ b/vllm_omni/diffusion/attention/backends/flash_attn.py @@ -39,6 +39,20 @@ def get_impl_cls() -> type["FlashAttentionImpl"]: class FlashAttentionImpl(AttentionImpl): + # Per-platform FP8 KV quantization support. + # To enable FP8 on a new platform, add its OmniPlatformEnum value here + # and handle kv_cache_dtype in the corresponding forward_{platform}(). + # + # TODO(quant-backend): The FP8 quant path currently lives inside + # FlashAttentionImpl gated by ``attn_metadata.extra["kv_cache_dtype"]``. + # Eventually extract it into a dedicated FlashAttentionQuantBackend so + # backend selection (not metadata) decides quant. 
Until then, model + # authors can opt a specific Attention layer out via + # ``Attention(disable_kv_quant=True)``. + _supported_kv_cache_dtypes = { + "npu": {"fp8"}, + } + def __init__( self, num_heads: int, @@ -250,6 +264,39 @@ def forward_npu( attn_metadata: AttentionMetadata = None, ) -> torch.Tensor: """NPU attention implementation using mindiesd.""" + + kv_cache_dtype = attn_metadata.extra.get("kv_cache_dtype") if attn_metadata else None + if kv_cache_dtype is not None: + return self.forward_fa_quant_npu(query, key, value, attn_metadata) + return self.forward_fa_npu(query, key, value, attn_metadata) + + def forward_fa_quant_npu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ) -> torch.Tensor: + from vllm_omni.platforms.npu.quant.kv_quant_npu import fp8_rotate_quant_fa + + layout = self.qkv_layout or "BNSD" + # Models pass (B, S, H, D); NPU fused op expects (B, N, S, D). + out = fp8_rotate_quant_fa( + query.transpose(1, 2), + key.transpose(1, 2), + value.transpose(1, 2), + layout=layout, + softmax_scale=self.softmax_scale, + ) + return out.transpose(1, 2) + + def forward_fa_npu( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata = None, + ) -> torch.Tensor: try: from mindiesd import attention_forward except ImportError: @@ -259,10 +306,9 @@ def forward_npu( "For installation details, see https://gitcode.com/Ascend/MindIE-SD" "Otherwise, use SDPA backend by setting DIFFUSION_ATTENTION_BACKEND=TORCH_SDPA" ) - attention_mask = attn_metadata.attn_mask if attn_metadata else None layout = self.qkv_layout or "BNSD" - output = attention_forward( + return attention_forward( query, key, value, @@ -271,4 +317,3 @@ def forward_npu( op_type="fused_attn_score", layout=layout, ) - return output diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index bbf3481fe6a..4ab139d4bc2 100644 --- 
a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -7,9 +7,12 @@ # https://github.com/feifeibear/long-context-attention/blob/main/yunchang/attention/layer.py +from dataclasses import replace + import torch import torch.nn as nn from vllm.logger import init_logger +from vllm.model_executor.models.utils import extract_layer_index from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend @@ -20,10 +23,20 @@ from vllm_omni.diffusion.config import get_current_diffusion_config_or_none from vllm_omni.diffusion.distributed.parallel_state import get_sp_group from vllm_omni.diffusion.forward_context import get_forward_context, is_forward_context_available +from vllm_omni.platforms import current_omni_platform logger = init_logger(__name__) +def _try_extract_layer_index(prefix: str) -> int | None: + if not prefix: + return None + try: + return extract_layer_index(prefix) + except (AssertionError, ValueError): + return None + + class Attention(nn.Module): def __init__( self, @@ -43,6 +56,11 @@ def __init__( gather_idx: int = 1, use_sync: bool = False, skip_sequence_parallel: bool = False, + # Opt-out for KV-cache quantization at this specific attention layer. + # Set by the model author when quant is known to degrade quality or + # perf for this layer (e.g. Wan2.2 cross-attn has short sequences and + # block-FP8 quant offers no win). Default False = follow global config. 
+ disable_kv_quant: bool = False, ): super().__init__() @@ -126,6 +144,15 @@ def __init__( # Fallback strategy when SP is not active (outside sharded regions) self._no_parallel_strategy = NoParallelAttention() + self.layer_idx: int | None = _try_extract_layer_index(prefix) + + self._kv_cache_dtype: str | None = None + self._kv_cache_skip_steps: set[int] | None = None + self._kv_cache_skip_layers: set[int] | None = None + # Per-layer opt-out from KV-cache quantization (set by model author). + self._disable_kv_quant: bool = disable_kv_quant + self._init_kv_cache_quantization(config) + def _get_active_parallel_strategy(self): """Get the parallel strategy based on current SP active state. @@ -141,12 +168,64 @@ def _get_active_parallel_strategy(self): return self._no_parallel_strategy return self.parallel_strategy + def _init_kv_cache_quantization(self, config) -> None: + if config is None: + return + dtype = config.kv_cache_dtype + if dtype: + if config.parallel_config.ring_degree > 1: + raise ValueError( + "KV quantization is not compatible with ring attention " + "(ring_degree > 1). Ring kernels do not propagate quantization descale " + "factors. Use Ulysses SP instead." + ) + platform_key = current_omni_platform.device_name + if not self.attention.supports_kv_cache_dtype(dtype, platform_key): + logger.warning_once( + "Attention backend %s does not support kv_cache_dtype='%s' on %s. 
" + "KV quantization will be disabled.", + self.attn_backend.get_name(), + dtype, + platform_key, + ) + dtype = None + self._kv_cache_dtype = dtype + self._kv_cache_skip_steps = config.kv_cache_skip_step_indices + self._kv_cache_skip_layers = config.kv_cache_skip_layer_indices + + def _should_apply_kv_cache_quant(self) -> bool: + skip_steps = self._kv_cache_skip_steps + skip_layers = self._kv_cache_skip_layers + if skip_steps is not None: + step_idx = get_forward_context().denoise_step_idx if is_forward_context_available() else None + if step_idx is not None and step_idx in skip_steps: + return False + if skip_layers is not None: + if self.layer_idx is not None and self.layer_idx in skip_layers: + return False + return True + + def _with_kv_cache_dtype(self, attn_metadata: AttentionMetadata | None) -> AttentionMetadata | None: + kv_cache_dtype = self._kv_cache_dtype + if kv_cache_dtype is None or self._disable_kv_quant or not self._should_apply_kv_cache_quant(): + if attn_metadata is None or "kv_cache_dtype" not in attn_metadata.extra: + return attn_metadata + extra = dict(attn_metadata.extra) + extra.pop("kv_cache_dtype", None) + return replace(attn_metadata, extra=extra) + + if attn_metadata is None: + return AttentionMetadata(extra={"kv_cache_dtype": kv_cache_dtype}) + extra = dict(attn_metadata.extra) + extra["kv_cache_dtype"] = kv_cache_dtype + return replace(attn_metadata, extra=extra) + def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - attn_metadata: AttentionMetadata = None, + attn_metadata: AttentionMetadata | None = None, ) -> torch.Tensor: # Get the appropriate parallel strategy based on SP active state strategy = self._get_active_parallel_strategy() @@ -156,6 +235,8 @@ def forward( # For Ring: Concat joint_q query, key, value, attn_metadata, ctx = strategy.pre_attention(query, key, value, attn_metadata) + attn_metadata = self._with_kv_cache_dtype(attn_metadata) + # 2. 
Kernel Execution (Computation) if self.use_ring and strategy is not self._no_parallel_strategy: out = self._run_ring_attention(query, key, value, attn_metadata) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index bea0b72957e..72776dcc30e 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -32,6 +32,54 @@ logger = init_logger(__name__) +def parse_kv_cache_skip_selector( + selector: str | list[int] | tuple[int, ...] | set[int] | None, +) -> set[int] | None: + """Parse a non-negative index selector such as "0-9,20,25-30".""" + if selector is None: + return None + if isinstance(selector, set): + values = selector + elif isinstance(selector, (list, tuple)): + values = set(selector) + elif isinstance(selector, str): + text = selector.strip() + if not text: + return None + values: set[int] = set() + for chunk in text.split(","): + token = chunk.strip() + if not token: + continue + if "-" in token: + start_str, end_str = token.split("-", 1) + try: + start = int(start_str.strip()) + end = int(end_str.strip()) + except ValueError as exc: + raise ValueError(f"Invalid range token '{token}' in selector '{selector}'.") from exc + if start < 0 or end < 0 or start > end: + raise ValueError(f"Invalid range token '{token}' in selector '{selector}'.") + values.update(range(start, end + 1)) + else: + try: + index = int(token) + except ValueError as exc: + raise ValueError(f"Invalid index token '{token}' in selector '{selector}'.") from exc + if index < 0: + raise ValueError(f"Negative index '{index}' is not allowed in selector '{selector}'.") + values.add(index) + else: + raise TypeError(f"Unsupported selector type: {type(selector)!r}") + + for idx in values: + if not isinstance(idx, int): + raise TypeError(f"Selector index must be int, got {type(idx)!r}") + if idx < 0: + raise ValueError("Selector indices must be non-negative.") + return values + + @config @dataclass class DiffusionParallelConfig: @@ -519,6 +567,19 @@ class 
OmniDiffusionConfig: # has already resolved to vLLM's ModelOpt FP8 linear method. force_cutlass_fp8: bool = False + # KV cache dtype for attention. Aligned with upstream vLLM's --kv-cache-dtype. + # None = native dtype (no quantization). + # "fp8" = dynamic FP8 (float8_e4m3fn) quantization per forward pass. + # Currently implemented only for the NPU Flash Attention backend. + # Other backends do not support it and fall back to native dtype. + kv_cache_dtype: str | None = None + # Optional skip selectors for KV-cache quantization. Format: "0-9,20,25-30". + # Listed steps/layers skip quantization; others keep quantized execution. + kv_cache_skip_steps: str | None = None + kv_cache_skip_layers: str | None = None + kv_cache_skip_step_indices: set[int] | None = None + kv_cache_skip_layer_indices: set[int] | None = None + # Diffusion pipeline Profiling config enable_diffusion_pipeline_profiler: bool = False @@ -527,6 +588,7 @@ class OmniDiffusionConfig: # sleep mode enable_sleep_mode: bool = False + # Maximum number of sequences to generate in a batch max_num_seqs: int = 1 @@ -656,6 +718,8 @@ def __post_init__(self): # Match vLLM's config flow: parse entrypoint shorthands before the # config object is built, and keep a single runtime truth source. 
self.diffusion_attention_config = build_attention_config(self.diffusion_attention_config) + self.kv_cache_skip_step_indices = parse_kv_cache_skip_selector(self.kv_cache_skip_steps) + self.kv_cache_skip_layer_indices = parse_kv_cache_skip_selector(self.kv_cache_skip_layers) if self.max_cpu_loras is None: self.max_cpu_loras = 1 diff --git a/vllm_omni/diffusion/forward_context.py b/vllm_omni/diffusion/forward_context.py index f6df4730aaa..35c93c79933 100644 --- a/vllm_omni/diffusion/forward_context.py +++ b/vllm_omni/diffusion/forward_context.py @@ -22,6 +22,7 @@ class ForwardContext: omni_diffusion_config: OmniDiffusionConfig | None = None attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None = None split_text_embed_in_sp: bool = False # whether to split the text embed in sequence parallel, if True, the text embed will be split in sequence parallel + denoise_step_idx: int | None = None # Sequence Parallel padding support @@ -103,12 +104,14 @@ def create_forward_context( omni_diffusion_config: OmniDiffusionConfig | None = None, attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None = None, split_text_embed_in_sp: bool = False, + denoise_step_idx: int | None = None, ): return ForwardContext( vllm_config=vllm_config, omni_diffusion_config=omni_diffusion_config, attn_metadata=attn_metadata, split_text_embed_in_sp=split_text_embed_in_sp, + denoise_step_idx=denoise_step_idx, ) @@ -133,6 +136,7 @@ def set_forward_context( omni_diffusion_config: OmniDiffusionConfig | None = None, attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] | None = None, split_text_embed_in_sp: bool = False, + denoise_step_idx: int | None = None, ): """A context manager that stores the current forward context, can be attention metadata, split_text_embed_in_sp, etc. 
@@ -143,6 +147,7 @@ def set_forward_context( omni_diffusion_config=omni_diffusion_config, attn_metadata=attn_metadata, split_text_embed_in_sp=split_text_embed_in_sp, + denoise_step_idx=denoise_step_idx, ) # vLLM CustomOp dispatch (e.g. QKVParallelLinear) requires a global # vLLM config set via set_current_vllm_config(). @@ -160,3 +165,9 @@ def set_forward_context( vllm.ir.enable_torch_wrap(vllm_config.compilation_config.ir_enable_torch_wrap), ): yield + + +def set_forward_context_denoise_step_idx(step_idx: int | None) -> None: + """Set the current diffusion denoise step on the active ForwardContext.""" + if _forward_context is not None: + _forward_context.denoise_step_idx = step_idx diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py index da9424e6365..ff21997c9c4 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2.py @@ -21,6 +21,7 @@ from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import DistributedAutoencoderKLWan from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.forward_context import set_forward_context_denoise_step_idx from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.model_loader.hub_prefetch import prefetch_subfolders from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin @@ -403,9 +404,12 @@ def diffuse( latent_condition: torch.Tensor | None = None, first_frame_mask: torch.Tensor | None = None, ) -> torch.Tensor: + if attention_kwargs is None: + attention_kwargs = {} with self.progress_bar(total=len(timesteps)) as pbar: - for t in timesteps: + for step_idx, t in enumerate(timesteps): self._current_timestep = t + set_forward_context_denoise_step_idx(step_idx) # Select model based on timestep and boundary_ratio # High 
noise stage (t >= boundary_timestep): use transformer diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py index d74c328206f..67fbdd3f039 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_i2v.py @@ -22,6 +22,7 @@ from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import DistributedAutoencoderKLWan from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.forward_context import set_forward_context_denoise_step_idx from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.model_loader.hub_prefetch import prefetch_subfolders from vllm_omni.diffusion.models.dmd2 import DMD2PipelineMixin @@ -282,8 +283,10 @@ def diffuse( condition: torch.Tensor, first_frame_mask: torch.Tensor, ) -> torch.Tensor: + if attention_kwargs is None: + attention_kwargs = {} with self.progress_bar(total=len(timesteps)) as pbar: - for t in timesteps: + for step_idx, t in enumerate(timesteps): self._current_timestep = t # Select model and guidance scale based on timestep @@ -293,6 +296,8 @@ def diffuse( current_model = self.transformer_2 current_guidance_scale = guidance_high + set_forward_context_denoise_step_idx(step_idx) + # Prepare latent input if self.expand_timesteps: # TI2V-5B style: blend condition with latents using mask @@ -308,6 +313,7 @@ def diffuse( timestep = t.expand(latents.shape[0]) do_true_cfg = current_guidance_scale > 1.0 and negative_prompt_embeds is not None + # Prepare kwargs for positive and negative predictions positive_kwargs = { "hidden_states": latent_model_input, "timestep": timestep, @@ -330,6 +336,7 @@ def diffuse( else: negative_kwargs = None + # Predict noise with automatic CFG parallel handling noise_pred = 
self.predict_noise_maybe_with_cfg( do_true_cfg=do_true_cfg, true_cfg_scale=current_guidance_scale, @@ -338,7 +345,9 @@ def diffuse( cfg_normalize=False, ) + # Compute the previous noisy sample x_t -> x_t-1 with automatic CFG sync latents = self.scheduler_step_maybe_with_cfg(noise_pred, t, latents, do_true_cfg) + pbar.update() return latents diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py index d1650cd9581..eec5c9cbfce 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_ti2v.py @@ -33,6 +33,7 @@ from vllm_omni.diffusion.distributed.autoencoders.autoencoder_kl_wan import OmniAutoencoderKLWan from vllm_omni.diffusion.distributed.cfg_parallel import CFGParallelMixin from vllm_omni.diffusion.distributed.utils import get_local_device +from vllm_omni.diffusion.forward_context import set_forward_context_denoise_step_idx from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader from vllm_omni.diffusion.model_loader.hub_prefetch import prefetch_subfolders from vllm_omni.diffusion.models.interface import SupportImageInput @@ -246,9 +247,12 @@ def diffuse( latent_condition: torch.Tensor | None = None, first_frame_mask: torch.Tensor | None = None, ) -> torch.Tensor: + if attention_kwargs is None: + attention_kwargs = {} with self.progress_bar(total=len(timesteps)) as pbar: - for t in timesteps: + for step_idx, t in enumerate(timesteps): self._current_timestep = t + set_forward_context_denoise_step_idx(step_idx) # Prepare latent input if latent_condition is not None: diff --git a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py index 0458f88597e..9de620400a5 100644 --- a/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py +++ b/vllm_omni/diffusion/models/wan2_2/pipeline_wan2_2_vace.py @@ -24,6 +24,7 @@ from 
vllm.model_executor.models.utils import AutoWeightsLoader from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig +from vllm_omni.diffusion.forward_context import set_forward_context_denoise_step_idx from vllm_omni.diffusion.models.interface import SupportImageInput from vllm_omni.diffusion.models.wan2_2.pipeline_wan2_2 import ( Wan22Pipeline, @@ -188,9 +189,13 @@ def diffuse( vace_context: torch.Tensor | None, vace_context_scale: float, ) -> torch.Tensor: + if attention_kwargs is None: + attention_kwargs = {} with self.progress_bar(total=len(timesteps)) as pbar: - for t in timesteps: + for step_idx, t in enumerate(timesteps): self._current_timestep = t + set_forward_context_denoise_step_idx(step_idx) + latent_model_input = latents.to(dtype) timestep = t.expand(latents.shape[0]) diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py index e80b116ab24..2015d3272bb 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_transformer.py @@ -329,6 +329,7 @@ def __init__( head_dim: int, eps: float = 1e-5, dropout: float = 0.0, + prefix: str = "", ): super().__init__() @@ -374,13 +375,14 @@ def __init__( softmax_scale=1.0 / (head_dim**0.5), causal=False, role="self", + prefix=prefix, ) def forward( self, hidden_states: torch.Tensor, rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None, - attn_mask: torch.Tensor | None = None, + attn_metadata: AttentionMetadata | None = None, ) -> torch.Tensor: # Fused QKV projection qkv, _ = self.to_qkv(hidden_states) @@ -405,11 +407,6 @@ def forward( query = self.rotary_embedding(query, freqs_cos, freqs_sin) key = self.rotary_embedding(key, freqs_cos, freqs_sin) - # Create attention metadata if mask is provided - attn_metadata = None - if attn_mask is not None: - attn_metadata = AttentionMetadata(attn_mask=attn_mask) - # Compute attention using unified attention layer hidden_states = 
self.attn(query, key, value, attn_metadata) hidden_states = hidden_states.flatten(2, 3) @@ -436,6 +433,7 @@ def __init__( eps: float = 1e-5, dropout: float = 0.0, added_kv_proj_dim: int | None = None, + prefix: str = "", ): super().__init__() @@ -528,13 +526,19 @@ def __init__( causal=False, role="cross", qkv_layout="BSND", + prefix=prefix, skip_sequence_parallel=True, + # Wan2.2 cross-attn operates on short text-encoder sequences; per-block + # FP8 quant offers no perf win and degrades quality. Opt out until a + # dedicated quant backend handles this case. + disable_kv_quant=True, ) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata | None = None, ) -> torch.Tensor: # Handle I2V case where encoder_hidden_states contains both image and text encoder_hidden_states_img = None @@ -568,12 +572,12 @@ def forward( key_img = key_img.unflatten(2, (self.num_heads, self.head_dim)) value_img = value_img.unflatten(2, (self.num_heads, self.head_dim)) - hidden_states_img = self.attn(query, key_img, value_img) + hidden_states_img = self.attn(query, key_img, value_img, attn_metadata) hidden_states_img = hidden_states_img.flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) # Main cross-attention using unified attention layer - hidden_states = self.attn(query, key, value) + hidden_states = self.attn(query, key, value, attn_metadata) hidden_states = hidden_states.flatten(2, 3) hidden_states = hidden_states.type_as(query) @@ -602,6 +606,7 @@ def __init__( eps: float = 1e-6, added_kv_proj_dim: int | None = None, cross_attn_norm: bool = False, + prefix: str = "", ): super().__init__() @@ -614,6 +619,7 @@ def __init__( num_heads=num_heads, head_dim=head_dim, eps=eps, + prefix=f"{prefix}.attn1", ) # 2. 
Cross-attention @@ -623,6 +629,7 @@ def __init__( head_dim=head_dim, eps=eps, added_kv_proj_dim=added_kv_proj_dim, + prefix=f"{prefix}.attn2", ) self.norm2 = LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() @@ -660,12 +667,13 @@ def forward( # 1. Self-attention norm_hidden_states = self.norm1(hidden_states, scale_msa, shift_msa).type_as(hidden_states) - attn_output = self.attn1(norm_hidden_states, rotary_emb, hidden_states_mask) + self_attn_metadata = AttentionMetadata(attn_mask=hidden_states_mask) + attn_output = self.attn1(norm_hidden_states, rotary_emb, self_attn_metadata) hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states) # 2. Cross-attention norm_hidden_states = self.norm2(hidden_states).type_as(hidden_states) - attn_output = self.attn2(norm_hidden_states, encoder_hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None) hidden_states = hidden_states + attn_output # 3. Feed-forward @@ -834,8 +842,16 @@ def __init__( # 3. 
Transformer blocks self.blocks = nn.ModuleList( [ - WanTransformerBlock(inner_dim, ffn_dim, num_attention_heads, eps, added_kv_proj_dim, cross_attn_norm) - for _ in range(num_layers) + WanTransformerBlock( + inner_dim, + ffn_dim, + num_attention_heads, + eps, + added_kv_proj_dim, + cross_attn_norm, + prefix=f"blocks.{layer_idx}", + ) + for layer_idx in range(num_layers) ] ) @@ -929,7 +945,13 @@ def forward( # Transformer blocks for block in self.blocks: - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, hidden_states_mask) + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + hidden_states_mask, + ) # Output norm, projection & unpatchify shift, scale = self.output_scale_shift_prepare(temb) diff --git a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py index 5060f1904f2..995b16c4850 100644 --- a/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py +++ b/vllm_omni/diffusion/models/wan2_2/wan2_2_vace_transformer.py @@ -33,8 +33,17 @@ def __init__( added_kv_proj_dim: int | None = None, cross_attn_norm: bool = False, block_id: int = 0, + prefix: str = "", ): - super().__init__(dim, ffn_dim, num_heads, eps, added_kv_proj_dim, cross_attn_norm) + super().__init__( + dim, + ffn_dim, + num_heads, + eps, + added_kv_proj_dim, + cross_attn_norm, + prefix=prefix, + ) self.proj_in = nn.Linear(dim, dim) if block_id == 0 else None self.proj_out = nn.Linear(dim, dim) @@ -118,6 +127,7 @@ def __init__( self.config.added_kv_proj_dim, self.config.cross_attn_norm, block_id=i, + prefix=f"vace_blocks.{i}", ) for i in range(len(vace_layers)) ] @@ -220,7 +230,7 @@ def forward( full_seq_len = hidden_states.shape[1] * sp_size control_hidden_states = self.embed_vace_context(vace_context.to(hidden_states.dtype), full_seq_len, sp_size) vace_hints = [] - for block in self.vace_blocks: + for i, block in enumerate(self.vace_blocks): 
conditioning_states, control_hidden_states = block( hidden_states, encoder_hidden_states, @@ -237,7 +247,13 @@ def forward( # Transformer blocks with VACE hint application for i, block in enumerate(self.blocks): - hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, hidden_states_mask) + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + hidden_states_mask, + ) if vace_hints is not None and self.vace_layers_mapping is not None and i in self.vace_layers_mapping: vace_idx = self.vace_layers_mapping[i] hidden_states = hidden_states + vace_hints[vace_idx] * vace_context_scale[vace_idx] diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index b2dc839c976..bcb33e417a1 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -1477,6 +1477,9 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True), "num_weight_load_threads": kwargs.get("num_weight_load_threads", 4), "quantization": kwargs.get("quantization", None), + "kv_cache_dtype": kwargs.get("kv_cache_dtype", None), + "kv_cache_skip_steps": kwargs.get("kv_cache_skip_steps", None), + "kv_cache_skip_layers": kwargs.get("kv_cache_skip_layers", None), **({"diffusion_attention_config": attention_config} if attention_config is not None else {}), "force_cutlass_fp8": bool(kwargs.get("force_cutlass_fp8", False)), "enable_diffusion_pipeline_profiler": kwargs.get("enable_diffusion_pipeline_profiler", False), @@ -1651,6 +1654,24 @@ def _resolve_stage_configs(self, model: str, kwargs: dict[str, Any]) -> tuple[st if quantization is not None: if not hasattr(cfg.engine_args, "quantization") or cfg.engine_args.quantization is None: cfg.engine_args.quantization = quantization + kv_cache_dtype = kwargs.get("kv_cache_dtype") + if kv_cache_dtype is not None: + if not 
hasattr(cfg.engine_args, "kv_cache_dtype") or cfg.engine_args.kv_cache_dtype is None: + cfg.engine_args.kv_cache_dtype = kv_cache_dtype + kv_cache_skip_steps = kwargs.get("kv_cache_skip_steps") + if kv_cache_skip_steps is not None: + if ( + not hasattr(cfg.engine_args, "kv_cache_skip_steps") + or cfg.engine_args.kv_cache_skip_steps is None + ): + cfg.engine_args.kv_cache_skip_steps = kv_cache_skip_steps + kv_cache_skip_layers = kwargs.get("kv_cache_skip_layers") + if kv_cache_skip_layers is not None: + if ( + not hasattr(cfg.engine_args, "kv_cache_skip_layers") + or cfg.engine_args.kv_cache_skip_layers is None + ): + cfg.engine_args.kv_cache_skip_layers = kv_cache_skip_layers except Exception as e: logger.warning("Failed to inject LoRA config for stage: %s", e) diff --git a/vllm_omni/entrypoints/cli/serve.py b/vllm_omni/entrypoints/cli/serve.py index b4293d59fd7..d714d1f53ff 100644 --- a/vllm_omni/entrypoints/cli/serve.py +++ b/vllm_omni/entrypoints/cli/serve.py @@ -472,6 +472,27 @@ def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgu default=None, help="Scheduler flow_shift for video models (e.g., 5.0 for 720p, 12.0 for 480p).", ) + # vLLM already registers --kv-cache-dtype for the serve parser. Keep + # this fallback only for older vLLM versions where the option is absent. + if "--kv-cache-dtype" not in serve_parser._option_string_actions: + omni_config_group.add_argument( + "--kv-cache-dtype", + type=str, + default=None, + help="Config-level KV cache dtype (e.g. fp8).", + ) + omni_config_group.add_argument( + "--kv-cache-skip-steps", + type=str, + default=None, + help="Config-level KV-cache quantization skip-step selector, e.g. '0-9,20,25-30'.", + ) + omni_config_group.add_argument( + "--kv-cache-skip-layers", + type=str, + default=None, + help="Config-level KV-cache quantization skip-layer selector, e.g. 
'0,1,4-8'.", + ) omni_config_group.add_argument( "--cfg-parallel-size", type=int, diff --git a/vllm_omni/platforms/npu/quant/__init__.py b/vllm_omni/platforms/npu/quant/__init__.py new file mode 100644 index 00000000000..208f01a7cb5 --- /dev/null +++ b/vllm_omni/platforms/npu/quant/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project diff --git a/vllm_omni/platforms/npu/quant/kv_quant_npu.py b/vllm_omni/platforms/npu/quant/kv_quant_npu.py new file mode 100644 index 00000000000..cb76aab5144 --- /dev/null +++ b/vllm_omni/platforms/npu/quant/kv_quant_npu.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""FP8 quantization utilities for diffusion attention tensors. + +Provides per-tensor dynamic quantization of Q/K/V tensors to +float8_e4m3fn format. Designed for diffusion models where Q/K/V are +computed fresh each forward pass (no persistent KV cache). +""" + +from __future__ import annotations + +import math +import threading +from functools import lru_cache + +import torch + +# Hadamard rotation matrix for QuaRot-style preprocessing +# keyed by (device, dtype, head_dim) to avoid matmul dtype mismatch. 
+_ROT_MATRIXS: dict[tuple[torch.device, torch.dtype, int], torch.Tensor] = {} +_ROT_MATRIX_LOCK = threading.Lock() + +_FP8_KV_LABELS = frozenset({"fp8"}) + + +def is_quantized_kv_cache(kv_cache_dtype: str | None) -> bool: + """True if config requests FP8-style KV / QKV quantization for the NPU FA path.""" + return kv_cache_dtype in _FP8_KV_LABELS + + +@lru_cache(maxsize=1) +def _load_quant_ops(): + try: + import torch_npu + from mindiesd.layers.quant.block_quant import fa_block_quant_preprocess + from msmodelslim.processor.quarot.common.quarot_utils import QuaRotMode, create_rot + except ImportError as e: + raise ImportError( + "fp8_rotate_quant_fa requires torch_npu, MindIE-SD (mindiesd), and MSModelSlim. " + "See https://gitcode.com/Ascend/MindIE-SD and https://gitcode.com/Ascend/msmodelslim" + ) from e + return torch_npu, fa_block_quant_preprocess, QuaRotMode, create_rot + + +def _get_rot_matrix( + device: torch.device, + dtype: torch.dtype, + head_dim: int, + qua_rot_mode, + create_rot, +) -> torch.Tensor: + key = (device, dtype, head_dim) + with _ROT_MATRIX_LOCK: + rot = _ROT_MATRIXS.get(key) + if rot is None: + rot = create_rot(qua_rot_mode.HADAMARD, head_dim, seed=425500).to(device=device, dtype=dtype) + _ROT_MATRIXS[key] = rot + return rot + + +def fp8_rotate_quant_fa( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + *, + layout: str = "BNSD", + softmax_scale: float | None = None, +) -> torch.Tensor: + """Run NPU fused attention with dynamic FP8 Q/K/V and optional QuaRot preprocess. + + Args: + query: Query tensor in ``layout`` order (default BNSD: batch, heads, seq, dim). + key: Key tensor in ``layout`` order (default BNSD: batch, heads, seq, dim). + value: Value tensor in ``layout`` order (default BNSD: batch, heads, seq, dim). + layout: ``BNSD`` or ``BSND`` for ``npu_fused_infer_attention_score_v2``. + softmax_scale: If None, uses ``1 / sqrt(head_dim)``. + + Returns: + Attention output in the same layout as inputs. 
+ """ + torch_npu, fa_block_quant_preprocess, qua_rot_mode, create_rot = _load_quant_ops() + + out_dtype = query.dtype + device = query.device + + if layout == "BNSD": + _, n, s, d = query.shape + elif layout == "BSND": + _, s, n, d = query.shape + else: + raise ValueError(f"fp8_rotate_quant_fa: unsupported layout {layout!r}, expected BNSD or BSND") + + rot = _get_rot_matrix(device, query.dtype, d, qua_rot_mode, create_rot) + q_f = torch.matmul(query, rot) + k_f = torch.matmul(key, rot) + + q, q_scale = fa_block_quant_preprocess(q_f, block_size=128, dst_type=torch_npu.float8_e4m3fn, layout=layout) + k, k_scale = fa_block_quant_preprocess(k_f, block_size=256, dst_type=torch_npu.float8_e4m3fn, layout=layout) + v, v_scale = fa_block_quant_preprocess(value, block_size=256, dst_type=torch_npu.float8_e4m3fn, layout=layout) + + scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(d) + + out = torch_npu.npu_fused_infer_attention_score_v2( + q, + k, + v, + input_layout=layout, + num_query_heads=n, + softmax_scale=scale, + pre_tokens=2147483647, # INT32_MAX: no left-context truncation. + next_tokens=2147483647, # INT32_MAX: no right-context truncation. + query_quant_mode=7, # NPU mode id for block FP8 dequant path. + key_quant_mode=7, # Same quant mode as query branch. + value_quant_mode=7, # Same quant mode as key/query branches. + dequant_scale_query=q_scale, + dequant_scale_key=k_scale, + dequant_scale_value=v_scale, + out_dtype=out_dtype, + )[0] + + if out.shape[2] != s: + if layout == "BNSD": + out = out[:, :, :s, :] + elif layout == "BSND": + out = out[:, :s, :, :] + + return out
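Note on the rotation step in `fp8_rotate_quant_fa`: only `query` and `key` are multiplied by the rotation matrix, while `value` is quantized directly. This works if the matrix is orthogonal, because then `(Q R)(K R)^T = Q K^T` and the attention scores are unchanged, while large outlier channels are spread across the head dimension before FP8 quantization. The sketch below checks that property in pure Python, under stated assumptions: it uses a plain normalized Sylvester Hadamard matrix, which may differ from what `create_rot(..., seed=425500)` actually returns, and the helper names `hadamard`, `matvec_right`, and `dot` are hypothetical and not part of this PR.

```python
import math


def hadamard(n):
    """Normalized Sylvester Hadamard matrix (assumption: n is a power of two)."""
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    h = [[1.0]]
    while len(h) < n:
        # Doubling step: [[H, H], [H, -H]]
        h = [row + row for row in h] + [row + [-x for x in row] for row in h]
    scale = 1.0 / math.sqrt(n)  # scaling makes the matrix orthogonal
    return [[x * scale for x in row] for row in h]


def matvec_right(vec, mat):
    """Row vector times matrix, i.e. vec @ mat."""
    return [sum(v * mat[i][j] for i, v in enumerate(vec)) for j in range(len(mat[0]))]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


head_dim = 8
rot = hadamard(head_dim)

# Illustrative single-token query/key; the last query channel is an outlier.
q = [0.5, -1.0, 2.0, 0.0, 3.5, -0.25, 1.0, 40.0]
k = [1.0, 0.5, -0.5, 2.0, 0.0, 1.5, -1.0, 0.1]

q_rot = matvec_right(q, rot)
k_rot = matvec_right(k, rot)

score = dot(q, k)              # attention logit before rotation
score_rot = dot(q_rot, k_rot)  # attention logit after rotating both operands

print(f"score={score:.6f} score_rot={score_rot:.6f}")
print(f"max|q|={max(abs(x) for x in q):.3f} max|q_rot|={max(abs(x) for x in q_rot):.3f}")
```

The two logits agree up to floating-point rounding, and the peak magnitude of the rotated query is smaller than the original outlier, which is the usual motivation for rotating before low-precision quantization. Rotating `value` as well would rotate the attention output, so it would need an inverse rotation afterwards; that is presumably why the PR leaves `value` unrotated.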