Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
fbd422d
support online FP8 quantization for FA on NPU
lyj-jjj Apr 24, 2026
859d3f3
cleancode
lyj-jjj Apr 24, 2026
55c1314
cleancode
lyj-jjj Apr 24, 2026
f402653
resolve conflicts
lyj-jjj Apr 25, 2026
1620dd2
crossattention layout="BSDN"
lyj-jjj Apr 25, 2026
3340725
adapt step and layer fallback
lyj-jjj Apr 25, 2026
b636616
resolve conflicts
lyj-jjj Apr 25, 2026
ffc84ef
cleancode
lyj-jjj Apr 25, 2026
a2eda19
support other case
lyj-jjj Apr 25, 2026
0fac348
support wan-vace
lyj-jjj Apr 27, 2026
3d2ebf9
add UT
lyj-jjj May 7, 2026
2dd169d
cleancode
lyj-jjj May 8, 2026
8d81581
fa-fp8-quant-npu
lyj-jjj May 9, 2026
47a46e0
fa-fp8-quant-npu
lyj-jjj May 9, 2026
8a7512e
resolve conflicts
lyj-jjj May 11, 2026
74747f6
Merge upstream/main: resolve flash_attn, data.py, async_omni_engine c…
lyj-jjj May 11, 2026
592e877
fa-fp8-npu
lyj-jjj May 11, 2026
6356d0b
fix-bug
lyj-jjj May 11, 2026
41ab49d
fix-bug
lyj-jjj May 11, 2026
8ef8d19
fix-bug
lyj-jjj May 12, 2026
933491f
fa-fp8-quant-npu-ut
lyj-jjj May 12, 2026
036c251
fa-fp8-quant-npu-ut
lyj-jjj May 12, 2026
5f863f1
cleancode
lyj-jjj May 12, 2026
25654e3
fix ci error
lyj-jjj May 12, 2026
44bfc45
fix ci error
lyj-jjj May 12, 2026
7cbaae6
add doc
lyj-jjj May 12, 2026
9f707f6
fix ci error
lyj-jjj May 12, 2026
63b4122
fix ci error
lyj-jjj May 12, 2026
2e7d00c
fix ci error
lyj-jjj May 12, 2026
c5ff07c
fix ci error
lyj-jjj May 12, 2026
c8b9293
Add disable_kv_quant field
gcanlin May 12, 2026
b58151b
trying reverting test
gcanlin May 12, 2026
98d6ced
Refactor and cleanup
gcanlin May 12, 2026
d291ba4
Merge branch 'main' into pr-2640
gcanlin May 12, 2026
abe7bb6
Merge branch 'main' into main
gcanlin May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/.nav.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ nav:
- Quantization:
- Overview: user_guide/quantization/overview.md
- Online Quantization: user_guide/quantization/online.md
- Quantized KV Cache: user_guide/quantization/quantized_kvcache.md
- FP8 W8A8: user_guide/quantization/fp8.md
- Int8 W8A8: user_guide/quantization/int8.md
- ModelOpt: user_guide/quantization/modelopt.md
Expand Down
3 changes: 2 additions & 1 deletion docs/user_guide/quantization/overview.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Quantization
# Quantization

vLLM-Omni exposes quantization through the unified `quantization_config`
path. The same configuration entrypoint is used across diffusion-only models,
Expand All @@ -10,6 +10,7 @@ type has a different quantization scope.
| Mode | Guide | Description | Methods |
|------|-------|-------------|---------|
| Online quantization | [Online Quantization](online.md) | vLLM-Omni computes quantized weights and scales while loading the model. | FP8 W8A8, Int8 W8A8 |
| Runtime attention quantization | [Quantized KV Cache](quantized_kvcache.md) | vLLM-Omni dynamically quantizes eligible diffusion Flash Attention tensors during inference. | FP8 FA |
| Pre-quantized checkpoints | Method-specific guides | The checkpoint or an offline quantizer provides quantized weights and scales before serving. | ModelOpt, GGUF, AutoRound, msModelSlim, serialized Int8 |

## Hardware Support
Expand Down
115 changes: 115 additions & 0 deletions docs/user_guide/quantization/quantized_kvcache.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Quantized KV Cache

## Overview

In DiT-based image and video generation, Flash Attention can take a large share
of denoising time, especially for high-resolution or long-frame workloads.
vLLM-Omni supports online FP8 quantization for eligible diffusion Flash
Attention (FA) to reduce FA latency while keeping model weights in their
original dtype.

This feature is configured through `kv_cache_dtype`, matching the option name
used by vLLM's language-model KV-cache quantization. In vLLM-Omni diffusion
pipelines, however, it is a runtime FA path: Q/K/V tensors are dynamically
quantized before the attention operator. It does not quantize model weights and
is separate from [FP8 W8A8](fp8.md), [Int8 W8A8](int8.md), or pre-quantized
checkpoint formats.

If `kv_cache_dtype` is not set, behavior is unchanged and attention runs in the
native dtype.

## Hardware Support

| Device | FP8 FA |
|--------|--------|
| Ascend NPU | ✅ |
| NVIDIA GPU | ❌ |
| AMD ROCm | ❌ |
| Intel XPU | ❌ |

Legend: `✅` supported, `❌` unsupported.

FP8 FA is currently implemented only for the NPU Flash Attention backend. Other
backends do not support `kv_cache_dtype="fp8"` for diffusion attention and fall
back to native dtype execution.

## Model Type Support

### Diffusion Model

| Model | Scope | Status | Notes |
|-------|-------|--------|-------|
| Wan2.2 | Eligible DiT full-attention FA on Ascend NPU | Tested | Compare quality and latency against a BF16 baseline before production use |
| Other diffusion models | Eligible DiT full-attention FA on Ascend NPU | Not tested | You can try `kv_cache_dtype="fp8"`; tune `kv_cache_skip_steps` and `kv_cache_skip_layers` when higher precision is needed |

### Multi-Stage Omni/TTS Model (Qwen3-Omni, Qwen3-TTS)

Not tested for FP8 FA. Treat any use as experimental unless a model-specific
guide documents support.

### Multi-Stage Diffusion Model (BAGEL, GLM-Image)

Not tested. If the diffusion stage uses the same NPU Flash Attention backend,
`kv_cache_dtype` may apply in theory; validate quality and latency for each
stage and model.

## Configuration

Offline diffusion example:

```bash
python examples/offline_inference/image_to_video/image_to_video.py \
--model <your-wan2.2-model> \
--prompt "A cat sitting on a surfboard at the beach" \
--height 1280 \
--width 720 \
--num-frames 61 \
--num-inference-steps 4 \
--ulysses-degree 4 \
--vae-patch-parallel-size 4 \
--kv-cache-dtype fp8 \
--kv-cache-skip-steps "0,1" \
--kv-cache-skip-layers "0-2"
```

Online serving:

```bash
vllm serve <your-model> --omni --kv-cache-dtype fp8
```

Stage config:

```yaml
stage_args:
- stage_id: 0
stage_type: diffusion
engine_args:
model_stage: dit
kv_cache_dtype: "fp8"
kv_cache_skip_steps: "0,1"
kv_cache_skip_layers: "0-2"
```

## Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `kv_cache_dtype` | str \| None | `None` | Set to `"fp8"` to enable dynamic FP8 FA on supported attention backends |
| `kv_cache_skip_steps` | str \| None | `None` | Denoising step selector to keep in native dtype, for example `"0,1,4-6"` |
| `kv_cache_skip_layers` | str \| None | `None` | Transformer layer selector to keep in native dtype, for example `"0-2,10"` |

Selectors use comma-separated integers and inclusive ranges. Listed steps or
layers skip FP8 FA; all other eligible full-attention forwards use the FP8 path.

## Validation and Notes

1. Compare generated images or videos against a BF16 baseline with the same
seed, prompt, resolution, frame count, and denoising steps.
2. Use `kv_cache_skip_steps` for denoising steps where quality is more
sensitive.
3. Use `kv_cache_skip_layers` for transformer layers that show visible quality
regressions.
4. Report both latency and quality results when enabling this option for a new
model. For image or video models, include visual comparison and quantitative
metrics when available, such as PSNR or SSIM.
24 changes: 24 additions & 0 deletions examples/offline_inference/image_to_video/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,24 @@ def parse_args() -> argparse.Namespace:
choices=["unipc", "euler"],
help="Sampling solver for Wan2.2 pipelines. Use 'euler' for Lightning/Distill setups.",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
default=None,
help="Config-level KV cache dtype (e.g. float8_e4m3fn).",
)
parser.add_argument(
"--kv-cache-skip-steps",
type=str,
default=None,
help="Config-level KV-cache quantization skip-step selector, e.g. '0-9,20,25-30'.",
)
parser.add_argument(
"--kv-cache-skip-layers",
type=str,
default=None,
help="Config-level KV-cache quantization skip-layer selector, e.g. '0,1,4-8'.",
)
parser.add_argument("--output", type=str, default="i2v_output.mp4", help="Path to save the video (mp4).")
parser.add_argument("--fps", type=int, default=None, help="Frames per second for the output video.")
parser.add_argument(
Expand Down Expand Up @@ -309,6 +327,9 @@ def main():
vae_use_tiling=args.vae_use_tiling,
boundary_ratio=args.boundary_ratio,
flow_shift=args.flow_shift,
kv_cache_dtype=args.kv_cache_dtype,
kv_cache_skip_steps=args.kv_cache_skip_steps,
kv_cache_skip_layers=args.kv_cache_skip_layers,
enable_cpu_offload=args.enable_cpu_offload,
parallel_config=parallel_config,
enforce_eager=args.enforce_eager,
Expand All @@ -330,6 +351,9 @@ def main():
print(f" Inference steps: {args.num_inference_steps}")
print(f" Frames: {args.num_frames}")
print(f" Solver: {args.sample_solver}")
print(f" kv_cache_dtype(config): {args.kv_cache_dtype}")
print(f" kv_cache_skip_steps(config): {args.kv_cache_skip_steps}")
print(f" kv_cache_skip_layers(config): {args.kv_cache_skip_layers}")
print(
f" Parallel configuration: cfg_parallel_size={args.cfg_parallel_size},"
f" tensor_parallel_size={args.tensor_parallel_size}, vae_patch_parallel_size={args.vae_patch_parallel_size}"
Expand Down
207 changes: 207 additions & 0 deletions tests/platforms/npu/quant/test_kv_quant_npu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for NPU FP8 KV quantization helpers.

These tests load ``kv_quant_npu`` from its source file via ``importlib`` so
the test module itself does not ``import vllm_omni`` (which would pull
``patch`` → ``aenum``, vLLM, etc.).
"""

from __future__ import annotations

import importlib.util
import math
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any

import pytest
import torch

pytestmark = [pytest.mark.core_model, pytest.mark.diffusion, pytest.mark.cpu]


def _repo_root() -> Path:
"""Resolve checkout root (parent of ``vllm_omni/``), not ``tests/``."""
here = Path(__file__).resolve()
marker = Path("vllm_omni") / "platforms" / "npu" / "quant" / "kv_quant_npu.py"
for parent in here.parents:
if (parent / marker).is_file():
return parent
msg = f"could not locate repo root (no {marker}) starting from {here}"
raise FileNotFoundError(msg)


def _load_kv_quant_npu() -> ModuleType:
path = _repo_root() / "vllm_omni" / "platforms" / "npu" / "quant" / "kv_quant_npu.py"
if not path.is_file():
msg = f"kv_quant_npu source not found: {path}"
raise FileNotFoundError(msg)
name = "vllm_omni_test_kv_quant_npu_standalone"
spec = importlib.util.spec_from_file_location(name, path)
if spec is None or spec.loader is None:
msg = f"cannot load import spec for {path}"
raise RuntimeError(msg)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod


kv_quant_npu = _load_kv_quant_npu()


def _npu_smoke_available() -> bool:
try:
import torch_npu # noqa: F401
except ImportError:
return False
return bool(hasattr(torch, "npu") and torch.npu.is_available())


npu_smoke = pytest.mark.skipif(not _npu_smoke_available(), reason="NPU device or torch_npu not available.")


def test_is_quantized_kv_cache() -> None:
assert kv_quant_npu.is_quantized_kv_cache("fp8")
assert not kv_quant_npu.is_quantized_kv_cache(None)
assert not kv_quant_npu.is_quantized_kv_cache("int8")


class TestKVQuantNPUUnit:
@pytest.fixture(autouse=True)
def clear_rot_cache(self):
kv_quant_npu._ROT_MATRIXS.clear()

def test_get_rot_matrix_caches_by_device_dtype_and_head_dim(self) -> None:
calls = {"count": 0}

class FakeQuaRotMode:
HADAMARD = "hadamard"

def fake_create_rot(mode, head_dim, seed):
calls["count"] += 1
assert mode == FakeQuaRotMode.HADAMARD
assert seed == 425500
return torch.eye(head_dim, dtype=torch.float32)

device = torch.device("cpu")
rot_1 = kv_quant_npu._get_rot_matrix(device, torch.float16, 8, FakeQuaRotMode, fake_create_rot)
rot_2 = kv_quant_npu._get_rot_matrix(device, torch.float16, 8, FakeQuaRotMode, fake_create_rot)
rot_3 = kv_quant_npu._get_rot_matrix(device, torch.bfloat16, 8, FakeQuaRotMode, fake_create_rot)
rot_4 = kv_quant_npu._get_rot_matrix(device, torch.float16, 16, FakeQuaRotMode, fake_create_rot)

assert calls["count"] == 3
assert rot_1 is rot_2
assert rot_3.dtype == torch.bfloat16
assert rot_4.shape == (16, 16)

@pytest.fixture
def fake_quant_ops(self, monkeypatch: pytest.MonkeyPatch) -> dict[str, Any]:
captured: dict[str, Any] = {
"fa_calls": [],
"npu_kwargs": None,
"out_shape": None,
}

class FakeTorchNPU:
float8_e4m3fn = "fp8_marker"

@staticmethod
def npu_fused_infer_attention_score_v2(q, k, v, **kwargs):
del q, k, v
captured["npu_kwargs"] = kwargs
out_shape = captured["out_shape"]
return (torch.ones(out_shape, dtype=torch.float32),)

def fake_fa_block_quant_preprocess(x, block_size, dst_type, layout):
captured["fa_calls"].append(
{
"block_size": block_size,
"layout": layout,
"dst_type": dst_type,
"shape": tuple(x.shape),
}
)
scale = torch.full((1,), float(block_size), dtype=torch.float32)
return x, scale

fake_qua_rot_mode = SimpleNamespace(HADAMARD="hadamard")

def fake_create_rot(mode, head_dim, seed):
assert mode == "hadamard"
assert seed == 425500
return torch.eye(head_dim, dtype=torch.float32)

monkeypatch.setattr(
kv_quant_npu,
"_load_quant_ops",
lambda: (FakeTorchNPU, fake_fa_block_quant_preprocess, fake_qua_rot_mode, fake_create_rot),
)

return captured

@staticmethod
def _make_qkv(shape: tuple[int, int, int, int]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
query = torch.randn(*shape, dtype=torch.float32)
key = torch.randn(*shape, dtype=torch.float32)
value = torch.randn(*shape, dtype=torch.float32)
return query, key, value

@pytest.mark.parametrize(
"layout,input_shape,out_shape,softmax_scale,expected_scale",
[
("BNSD", (2, 3, 4, 8), (2, 3, 6, 8), None, 1.0 / math.sqrt(8)),
("BSND", (2, 4, 3, 8), (2, 6, 3, 8), 0.125, 0.125),
],
)
def test_fp8_rotate_quant_fa_layouts_scale_and_crop(
self,
fake_quant_ops: dict[str, Any],
layout: str,
input_shape: tuple[int, int, int, int],
out_shape: tuple[int, int, int, int],
softmax_scale: float | None,
expected_scale: float,
) -> None:
query, key, value = self._make_qkv(input_shape)
fake_quant_ops["out_shape"] = out_shape

out = kv_quant_npu.fp8_rotate_quant_fa(query, key, value, layout=layout, softmax_scale=softmax_scale)

assert out.shape == query.shape
assert out.dtype == query.dtype
assert fake_quant_ops["npu_kwargs"]["input_layout"] == layout
# BNSD: shape[1]==heads, BSND: shape[2]==heads.
expected_heads = input_shape[1] if layout == "BNSD" else input_shape[2]
assert fake_quant_ops["npu_kwargs"]["num_query_heads"] == expected_heads
assert fake_quant_ops["npu_kwargs"]["softmax_scale"] == pytest.approx(expected_scale)
assert [call["block_size"] for call in fake_quant_ops["fa_calls"]] == [128, 256, 256]

def test_fp8_rotate_quant_fa_invalid_layout_raises(self, fake_quant_ops) -> None:
query = torch.randn(1, 2, 3, 4, dtype=torch.float32)
key = torch.randn(1, 2, 3, 4, dtype=torch.float32)
value = torch.randn(1, 2, 3, 4, dtype=torch.float32)
fake_quant_ops["out_shape"] = (1, 2, 3, 4)

with pytest.raises(ValueError, match="unsupported layout"):
kv_quant_npu.fp8_rotate_quant_fa(query, key, value, layout="INVALID")


@npu_smoke
class TestKVQuantNPUSmoke:
"""Smoke tests using real torch_npu/mindiesd stack, only on NPU."""

def test_fp8_rotate_quant_fa_real_npu_shape_contract(self):
try:
kv_quant_npu._load_quant_ops.cache_clear()
kv_quant_npu._load_quant_ops()
except ImportError:
pytest.skip("NPU quant dependencies are not fully installed.")

query = torch.randn(1, 2, 4, 64, dtype=torch.float16, device="npu")
key = torch.randn(1, 2, 4, 64, dtype=torch.float16, device="npu")
value = torch.randn(1, 2, 4, 64, dtype=torch.float16, device="npu")

out = kv_quant_npu.fp8_rotate_quant_fa(query, key, value, layout="BNSD")
assert out.shape == query.shape
assert out.dtype == query.dtype
Loading
Loading