Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
3524dcb
[Quantization] Add FP8 KV quantization for diffusion attention layers
lishunyang12 Feb 20, 2026
beb07ba
Merge branch 'main' into fp8-kv-quantization
lishunyang12 Apr 2, 2026
57f00a7
[Quantization] Align FP8 attention with design doc: Q quantization, C…
lishunyang12 Apr 2, 2026
bd9931c
Add --kv-cache-dtype flag to text_to_video.py
lishunyang12 Apr 2, 2026
1ec3236
Fix FP8 attention not activating: lazy-resolve config from forward co…
lishunyang12 Apr 2, 2026
5a5ed10
Add debug logging for FP8 attention resolution
lishunyang12 Apr 2, 2026
841ad38
Wire kv_cache_dtype through default stage config
lishunyang12 Apr 2, 2026
aaf3930
Use vLLM fused CUDA kernel for FP8 quantization
lishunyang12 Apr 2, 2026
2c4e9ea
Fix dequant dtype: cast result to output_dtype after f32 scale multiply
lishunyang12 Apr 2, 2026
2934ed0
Skip FP8 quantization when padding mask is present
lishunyang12 Apr 2, 2026
ad37fdc
Add debug scripts for FA3 FP8 capability check and mask stats
lishunyang12 Apr 2, 2026
2d5906a
Fix FA3 FP8 descale param names and enable FP8 varlen path
lishunyang12 Apr 2, 2026
502f5aa
Expand debug script: micro-benchmarks, varlen FP8, layer breakdown
lishunyang12 Apr 2, 2026
4ef08e9
Fix total_mem -> total_memory for PyTorch 2.10+
lishunyang12 Apr 2, 2026
8f28879
Fix descale shape: FA3 requires (batch, num_kv_heads) not scalar
lishunyang12 Apr 2, 2026
6ca66ea
Add debug_shapes.py: log actual Q/K/V shapes during inference
lishunyang12 Apr 2, 2026
d6c0346
Fix debug_shapes.py: add __main__ guard for multiprocessing spawn
lishunyang12 Apr 2, 2026
6ef694b
Update benchmarks with actual HunyuanVideo shapes: 50345 tokens, 16 h…
lishunyang12 Apr 2, 2026
7773551
Implement delayed scaling: cache scales to skip amax on subsequent calls
lishunyang12 Apr 2, 2026
99bb651
Disable delayed scaling (green output bug), use dynamic only
lishunyang12 Apr 2, 2026
0a2b3e9
Test KV-only FP8: keep Q in BF16, only quantize K/V
lishunyang12 Apr 2, 2026
55256a1
Skip varlen for FP8: use regular FA3 path to avoid varlen descale bug
lishunyang12 Apr 2, 2026
8fa2ffa
Revert to QKV FP8: FA3 requires same dtype for Q/K/V
lishunyang12 Apr 2, 2026
8495dfa
Zero out padding K/V before FP8 quantization
lishunyang12 Apr 2, 2026
fef2814
Auto-fallback to BF16 for sequences >16K tokens
lishunyang12 Apr 2, 2026
68dd46c
Use num_splits for FP8 accuracy at long sequences instead of seqlen c…
lishunyang12 Apr 2, 2026
30180bb
Add debug_fa3_version.py: check FA3 builds and two-level accumulation
lishunyang12 Apr 2, 2026
fb83d59
Add debug_vllm_fa.py: check vLLM's bundled flash-attention for FP8 fix
lishunyang12 Apr 2, 2026
aab93f4
Switch FP8 path to vLLM's bundled FA3 (has two-level accumulation fix)
lishunyang12 Apr 2, 2026
27198f1
Merge branch 'main' into fp8-kv-quantization
lishunyang12 Apr 5, 2026
1838fc1
Optimize FP8 attention: zero-overhead quantization, direct FA3 path
lishunyang12 Apr 6, 2026
529b7fb
Move FP8 quantization into attention backends with per-platform frame…
lishunyang12 Apr 7, 2026
385edbf
Merge main, resolve async_omni_engine conflict
lishunyang12 Apr 7, 2026
a275bea
Use table-driven platform dispatch for attention forward
lishunyang12 Apr 7, 2026
fc095e0
Scope FP8 KV support to CUDA only, comment placeholders for other pla…
lishunyang12 Apr 7, 2026
88bd707
Keep dispatch table complete, silence CUDA kernel warning for non-CUD…
lishunyang12 Apr 7, 2026
a9a5037
Fix: skip CUDA fused quant kernel on non-CUDA tensors
lishunyang12 Apr 7, 2026
68bc96e
Add tests for fast quant, backend support, and per-platform dtype guard
lishunyang12 Apr 7, 2026
54e98d9
Update vllm_omni/diffusion/attention/backends/abstract.py
lishunyang12 Apr 8, 2026
b45e179
Add SageAttention vs FlashAttention benchmark script
lishunyang12 Apr 9, 2026
087dc09
Fix model name to HunyuanVideo-1.5-Diffusers-480p_t2v
lishunyang12 Apr 9, 2026
4da4f94
Add attention kernel benchmark (FA vs Sage vs SDPA)
lishunyang12 Apr 9, 2026
dd18ba3
Support fa3_fwd_interface in kernel benchmark
lishunyang12 Apr 9, 2026
19c7df6
Rewrite kernel bench to match SageAttention official style
lishunyang12 Apr 9, 2026
c20feec
Rewrite bench to exactly follow SageAttention official bench
lishunyang12 Apr 9, 2026
ebf0771
Inline benchmark_forward to remove flash_attn dependency
lishunyang12 Apr 9, 2026
b3fe3f6
Add sageattn method, --dtype flag, fix FA3 BF16 support
lishunyang12 Apr 9, 2026
8d5bca8
Default to SageAttention when available on CUDA
lishunyang12 Apr 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 352 additions & 0 deletions benchmarks/diffusion/bench_attn_kernel.py

Large diffs are not rendered by default.

35 changes: 35 additions & 0 deletions benchmarks/diffusion/bench_sage_comparison.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/bin/bash
# Benchmark: HunyuanVideo 1.5 480p — BF16 baseline vs SageAttention
# Resolution: 480×832, 33 frames
set -e

MODEL="hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v"
PROMPT="A serene lakeside sunrise with mist over the water."
SCRIPT="examples/offline_inference/text_to_video/text_to_video.py"
OUTPUT_DIR="${OUTPUT_DIR:-/workspace}"

COMMON_ARGS="--model $MODEL \
--height 480 --width 832 --num-frames 33 \
--num-inference-steps 50 \
--guidance-scale 6.0 \
--seed 42 \
--vae-use-tiling \
--enforce-eager"

echo "============================================"
echo "=== 1/2: BF16 + FlashAttention (baseline)==="
echo "============================================"
DIFFUSION_ATTENTION_BACKEND=FLASH_ATTN \
python $SCRIPT $COMMON_ARGS \
--output "$OUTPUT_DIR/output_flash_attn.mp4"

echo ""
echo "============================================"
echo "=== 2/2: BF16 + SageAttention ==="
echo "============================================"
DIFFUSION_ATTENTION_BACKEND=SAGE_ATTN \
python $SCRIPT $COMMON_ARGS \
--output "$OUTPUT_DIR/output_sage_attn.mp4"

echo ""
echo "=== Done. Compare: output_flash_attn.mp4 vs output_sage_attn.mp4 ==="
18 changes: 16 additions & 2 deletions examples/offline_inference/text_to_image/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ def parse_args() -> argparse.Namespace:
"Available layers: to_qkv, to_out, add_kv_proj, to_add_out, img_mlp, txt_mlp, proj_out. "
"Example: --ignored-layers 'add_kv_proj,to_add_out'",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
default="auto",
choices=["auto", "fp8"],
help="Data type for attention Q/K/V quantization. "
"'fp8': dynamically quantize to float8_e4m3fn each forward pass. "
"On Hopper GPUs with FA3, enables native FP8 attention compute. "
"On other backends (FA2/SDPA), tensors are dequantized before the kernel. "
"'auto': no quantization (default).",
)
parser.add_argument(
"--vae-use-slicing",
action="store_true",
Expand Down Expand Up @@ -313,10 +324,10 @@ def main():
lora_args["lora_path"] = args.lora_path
print(f"Using LoRA from: {args.lora_path}")

# Build quantization kwargs: use quantization_config dict when
# ignored_layers is specified so the list flows through OmniDiffusionConfig
# Build quantization kwargs
quant_kwargs: dict[str, Any] = {}
ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None
kv_cache_dtype = args.kv_cache_dtype if args.kv_cache_dtype != "auto" else None
if args.quantization == "gguf":
if not args.gguf_model:
raise ValueError("--gguf-model is required when --quantization gguf is set.")
Expand Down Expand Up @@ -346,6 +357,7 @@ def main():
"mode": "text-to-image",
"log_stats": args.log_stats,
"enable_diffusion_pipeline_profiler": args.enable_diffusion_pipeline_profiler,
"kv_cache_dtype": kv_cache_dtype,
**lora_args,
**quant_kwargs,
}
Expand All @@ -367,6 +379,8 @@ def main():
print(f" Inference steps: {args.num_inference_steps}")
print(f" Cache backend: {cache_backend if cache_backend else 'None (no acceleration)'}")
print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}")
if kv_cache_dtype:
print(f" KV cache dtype: {kv_cache_dtype}")
if ignored_layers:
print(f" Ignored layers: {ignored_layers}")
print(
Expand Down
13 changes: 13 additions & 0 deletions examples/offline_inference/text_to_video/text_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,16 @@ def parse_args() -> argparse.Namespace:
choices=["fp8", "gguf"],
help="Quantization method for the transformer (fp8 for online FP8 quantization).",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
default="auto",
choices=["auto", "fp8"],
help="Data type for attention Q/K/V quantization. "
"'fp8': dynamically quantize to float8_e4m3fn each forward pass. "
"On Hopper GPUs with FA3, enables native FP8 attention compute. "
"'auto': no quantization (default).",
)
return parser.parse_args()


Expand Down Expand Up @@ -227,6 +237,8 @@ def main():
# Check if profiling is requested via environment variable
profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR"))

kv_cache_dtype = args.kv_cache_dtype if args.kv_cache_dtype != "auto" else None

omni_kwargs = dict(
model=args.model,
enable_layerwise_offload=args.enable_layerwise_offload,
Expand All @@ -239,6 +251,7 @@ def main():
cache_backend=args.cache_backend,
cache_config=cache_config,
enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler,
kv_cache_dtype=kv_cache_dtype,
)
if args.boundary_ratio is not None:
omni_kwargs["boundary_ratio"] = args.boundary_ratio
Expand Down
233 changes: 233 additions & 0 deletions tests/diffusion/quantization/test_kv_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for FP8 Q/K/V quantization utilities."""

import pytest
import torch

pytestmark = [pytest.mark.core_model, pytest.mark.diffusion]


def test_qkv_roundtrip_preserves_values():
"""quantize_qkv_fp8 -> dequantize_fp8 should preserve values within FP8 tolerance."""
from vllm_omni.quantization.kv_quant import (
dequantize_fp8,
quantize_qkv_fp8,
)

torch.manual_seed(42)
query = torch.randn(2, 128, 8, 64, dtype=torch.bfloat16)
key = torch.randn(2, 128, 8, 64, dtype=torch.bfloat16)
value = torch.randn(2, 128, 8, 64, dtype=torch.bfloat16)

fp8_q, fp8_k, fp8_v, q_scale, k_scale, v_scale = quantize_qkv_fp8(
query, key, value
)

assert fp8_q.dtype == torch.float8_e4m3fn
assert fp8_k.dtype == torch.float8_e4m3fn
assert fp8_v.dtype == torch.float8_e4m3fn
assert q_scale.numel() == 1
assert k_scale.numel() == 1
assert v_scale.numel() == 1

query_rt = dequantize_fp8(fp8_q, q_scale, torch.bfloat16)
key_rt = dequantize_fp8(fp8_k, k_scale, torch.bfloat16)
value_rt = dequantize_fp8(fp8_v, v_scale, torch.bfloat16)

# FP8 e4m3 has ~0.1% relative error for typical values
torch.testing.assert_close(query_rt, query, rtol=0.05, atol=0.05)
torch.testing.assert_close(key_rt, key, rtol=0.05, atol=0.05)
torch.testing.assert_close(value_rt, value, rtol=0.05, atol=0.05)


def test_kv_only_roundtrip():
"""quantize_kv_fp8 for joint attention path."""
from vllm_omni.quantization.kv_quant import (
dequantize_fp8,
quantize_kv_fp8,
)

torch.manual_seed(42)
key = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16)
value = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16)

fp8_k, fp8_v, k_scale, v_scale = quantize_kv_fp8(key, value)

assert fp8_k.dtype == torch.float8_e4m3fn
assert k_scale > 0
assert v_scale > 0

key_rt = dequantize_fp8(fp8_k, k_scale, torch.bfloat16)
torch.testing.assert_close(key_rt, key, rtol=0.05, atol=0.05)


def test_scales_are_positive():
from vllm_omni.quantization.kv_quant import quantize_qkv_fp8

q = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16)
k = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16)
v = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16)

_, _, _, q_scale, k_scale, v_scale = quantize_qkv_fp8(q, k, v)
assert q_scale > 0
assert k_scale > 0
assert v_scale > 0


def test_zero_tensor():
"""All-zero input should not produce NaN or Inf."""
from vllm_omni.quantization.kv_quant import (
dequantize_fp8,
quantize_qkv_fp8,
)

q = torch.zeros(1, 16, 4, 32, dtype=torch.bfloat16)
k = torch.zeros(1, 16, 4, 32, dtype=torch.bfloat16)
v = torch.zeros(1, 16, 4, 32, dtype=torch.bfloat16)

fp8_q, fp8_k, fp8_v, q_s, k_s, v_s = quantize_qkv_fp8(q, k, v)
q_rt = dequantize_fp8(fp8_q, q_s, torch.bfloat16)
k_rt = dequantize_fp8(fp8_k, k_s, torch.bfloat16)

assert not torch.isnan(q_rt).any()
assert not torch.isnan(k_rt).any()
assert torch.allclose(q_rt, q)
assert torch.allclose(k_rt, k)


def test_fp16_input():
"""Should work with float16 input as well."""
from vllm_omni.quantization.kv_quant import quantize_qkv_fp8

q = torch.randn(1, 32, 4, 64, dtype=torch.float16)
k = torch.randn(1, 32, 4, 64, dtype=torch.float16)
v = torch.randn(1, 32, 4, 64, dtype=torch.float16)

fp8_q, fp8_k, fp8_v, _, _, _ = quantize_qkv_fp8(q, k, v)
assert fp8_q.dtype == torch.float8_e4m3fn
assert fp8_k.dtype == torch.float8_e4m3fn
assert fp8_v.dtype == torch.float8_e4m3fn


def test_kv_cache_dtype_config_field():
"""OmniDiffusionConfig should accept kv_cache_dtype field."""
from vllm_omni.diffusion.data import OmniDiffusionConfig

config = OmniDiffusionConfig(model="test", kv_cache_dtype="fp8")
assert config.kv_cache_dtype == "fp8"

config_default = OmniDiffusionConfig(model="test")
assert config_default.kv_cache_dtype is None


def test_is_quantized_kv_cache():
"""is_quantized_kv_cache should detect FP8 dtype strings."""
from vllm_omni.quantization.kv_quant import is_quantized_kv_cache

assert is_quantized_kv_cache("fp8") is True
assert is_quantized_kv_cache("fp8_e4m3") is True
assert is_quantized_kv_cache(None) is False
assert is_quantized_kv_cache("auto") is False
assert is_quantized_kv_cache("bfloat16") is False


def test_attention_metadata_kv_cache_dtype():
"""AttentionMetadata should have kv_cache_dtype field."""
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata

meta = AttentionMetadata()
assert meta.kv_cache_dtype is None

meta.kv_cache_dtype = "fp8"
assert meta.kv_cache_dtype == "fp8"


def test_fast_qkv_quantization():
"""quantize_qkv_fp8_fast should use scale=1.0 (direct cast)."""
from vllm_omni.quantization.kv_quant import quantize_qkv_fp8_fast

q = torch.randn(1, 32, 4, 64, dtype=torch.bfloat16)
k = torch.randn(1, 32, 4, 64, dtype=torch.bfloat16)
v = torch.randn(1, 32, 4, 64, dtype=torch.bfloat16)

fp8_q, fp8_k, fp8_v, q_s, k_s, v_s = quantize_qkv_fp8_fast(q, k, v)

assert fp8_q.dtype == torch.float8_e4m3fn
assert fp8_k.dtype == torch.float8_e4m3fn
assert fp8_v.dtype == torch.float8_e4m3fn
# Fast path uses scale=1.0
assert q_s.item() == 1.0
assert k_s.item() == 1.0
assert v_s.item() == 1.0


def test_fast_kv_quantization():
"""quantize_kv_fp8_fast for joint attention path."""
from vllm_omni.quantization.kv_quant import quantize_kv_fp8_fast

k = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16)
v = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16)

fp8_k, fp8_v, k_s, v_s = quantize_kv_fp8_fast(k, v)

assert fp8_k.dtype == torch.float8_e4m3fn
assert fp8_v.dtype == torch.float8_e4m3fn
assert k_s.item() == 1.0
assert v_s.item() == 1.0


def test_flash_backend_supports_kv_cache_dtype():
"""FlashAttentionBackend should declare FP8 support."""
from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionBackend

assert FlashAttentionBackend.supports_kv_cache_dtype(None) is True
assert FlashAttentionBackend.supports_kv_cache_dtype("fp8") is True
assert FlashAttentionBackend.supports_kv_cache_dtype("fp8_e4m3") is True
assert FlashAttentionBackend.supports_kv_cache_dtype("mxfp8") is False


def test_sdpa_backend_does_not_support_fp8():
"""SDPABackend should not declare FP8 support."""
from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend

assert SDPABackend.supports_kv_cache_dtype(None) is True
assert SDPABackend.supports_kv_cache_dtype("fp8") is False


def test_handle_kv_cache_dtype_clears_unsupported():
"""_handle_kv_cache_dtype should clear unsupported dtype to None."""
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.backends.sdpa import SDPAImpl

impl = SDPAImpl(num_heads=4, head_size=64, softmax_scale=0.125)
meta = AttentionMetadata(kv_cache_dtype="fp8")

# SDPA has empty _supported_kv_cache_dtypes, should clear fp8
impl._handle_kv_cache_dtype(meta, "cuda")
assert meta.kv_cache_dtype is None


def test_handle_kv_cache_dtype_preserves_supported():
"""_handle_kv_cache_dtype should preserve supported dtype."""
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl

impl = FlashAttentionImpl(num_heads=4, head_size=64, softmax_scale=0.125)
meta = AttentionMetadata(kv_cache_dtype="fp8")

impl._handle_kv_cache_dtype(meta, "cuda")
assert meta.kv_cache_dtype == "fp8"


def test_handle_kv_cache_dtype_clears_unsupported_platform():
"""FP8 on FlashAttention should be cleared for non-CUDA platforms."""
from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata
from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl

impl = FlashAttentionImpl(num_heads=4, head_size=64, softmax_scale=0.125)
meta = AttentionMetadata(kv_cache_dtype="fp8")

# NPU not in FlashAttentionImpl._supported_kv_cache_dtypes
impl._handle_kv_cache_dtype(meta, "npu")
assert meta.kv_cache_dtype is None
Loading