diff --git a/benchmarks/diffusion/bench_attn_kernel.py b/benchmarks/diffusion/bench_attn_kernel.py new file mode 100644 index 00000000000..8d3ec502fad --- /dev/null +++ b/benchmarks/diffusion/bench_attn_kernel.py @@ -0,0 +1,352 @@ +""" +Reimplemented from SageAttention official bench: +https://github.com/thu-ml/SageAttention/tree/main/bench + +Scripts: + bench_baseline.py -> --method fa2/torch/xformers + bench_fa3.py -> --method fa3 + bench_qk_int8_pv_fp16_cuda.py -> --method sage_int8_fp16_cuda + bench_qk_int8_pv_fp16_triton.py -> --method sage_int8_fp16_triton + bench_qk_int8_pv_fp8_cuda.py -> --method sage_int8_fp8_cuda (SM89, RTX 4090) + bench_qk_int8_pv_fp8_cuda_sm90.py -> --method sage_int8_fp8_cuda_sm90 (H100) + +Usage: + python bench_attn_kernel.py --method fa3 --dtype bfloat16 + python bench_attn_kernel.py --method sageattn --dtype bfloat16 + python bench_attn_kernel.py --method fa2 + python bench_attn_kernel.py --method torch + python bench_attn_kernel.py --method sage_int8_fp16_cuda + python bench_attn_kernel.py --method sage_int8_fp16_triton + python bench_attn_kernel.py --method sage_int8_fp8_cuda + python bench_attn_kernel.py --method sage_int8_fp8_cuda_sm90 +""" + +import argparse +import re +import subprocess + +import torch +import torch.utils.benchmark as benchmark + + +def benchmark_forward(fn, *inputs, repeats=100, desc="", verbose=False, **kwinputs): + """Reimplemented from flash_attn.utils.benchmark.benchmark_forward + so we don't need flash_attn installed just for the timer.""" + t = benchmark.Timer( + stmt="fn(*inputs, **kwinputs)", + globals={"fn": fn, "inputs": inputs, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(desc, "- Forward pass") + print(m) + return t, m + + +def get_cuda_version(): + try: + output = subprocess.check_output(['nvcc', '--version']).decode() + match = re.search(r'release (\d+)\.(\d+)', output) + if match: + major, minor = int(match.group(1)), 
int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + return None, None + + +parser = argparse.ArgumentParser(description='Attention Kernel Benchmark (SageAttention official style)') +parser.add_argument('--method', type=str, default='fa3', + choices=['fa2', 'torch', 'xformers', 'fa3', 'sageattn', + 'sage_int8_fp16_cuda', 'sage_int8_fp16_triton', + 'sage_int8_fp8_cuda', 'sage_int8_fp8_cuda_sm90']) +parser.add_argument('--batch_size', type=int, default=4, help='Batch size') +parser.add_argument('--num_heads', type=int, default=32, help='Number of heads') +parser.add_argument('--head_dim', type=int, default=128, help='Head dimension') +parser.add_argument('--quant_gran', type=str, default='per_warp', choices=['per_warp', 'per_thread'], + help='Quantization granularity (sage kernels only)') +parser.add_argument('--pv_accum_dtype', type=str, default=None, + help='PV accumulation dtype (sage kernels only)') +parser.add_argument('--dtype', type=str, default='float16', choices=['float16', 'bfloat16'], + help='Data type for FA3/sageattn/baseline (default: float16)') +args = parser.parse_args() + +head = args.num_heads +batch = args.batch_size +headdim = args.head_dim +dtype = getattr(torch, args.dtype) + +# ============================================================ +# bench_baseline: fa2 / torch / xformers +# ============================================================ +if args.method in ('fa2', 'torch', 'xformers'): + from torch.nn.functional import scaled_dot_product_attention as sdpa + + torch.backends.cuda.enable_flash_sdp(args.method == 'fa2') + torch.backends.cuda.enable_math_sdp(args.method == 'torch') + torch.backends.cuda.enable_mem_efficient_sdp(args.method == 'xformers') + + print(f"Baseline: {args.method}") + print(f"batch: {batch}, head: {head}, headdim: {headdim}, dtype: {args.dtype}") + + for is_causal in [False, True]: + print(f"is_causal: {is_causal}") + for seq_len in sorted({1024, 2048, 4096, 
8192, 16384, 32768}): + flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1) + q = torch.randn(batch, head, seq_len, headdim, dtype=dtype, device="cuda") + k = torch.randn(batch, head, seq_len, headdim, dtype=dtype, device="cuda") + v = torch.randn(batch, head, seq_len, headdim, dtype=dtype, device="cuda") + for i in range(5): sdpa(q, k, v, is_causal=is_causal) + torch.cuda.synchronize() + _, time = benchmark_forward(sdpa, q, k, v, is_causal=is_causal, repeats=100, verbose=False, desc='Triton') + print(f'{seq_len} flops:{flops/time.mean*1e-12}') + +# ============================================================ +# bench_fa3 +# ============================================================ +elif args.method == 'fa3': + # Try fa3_fwd_interface first (vllm-omni custom build), then flash_attn_interface + flash_attn_func_v3 = None + fa3_source = None + for mod_name in ['fa3_fwd_interface', 'flash_attn_interface']: + try: + mod = __import__(mod_name, fromlist=['flash_attn_func']) + flash_attn_func_v3 = getattr(mod, 'flash_attn_func') + fa3_source = mod_name + break + except (ImportError, AttributeError): + continue + + if flash_attn_func_v3 is None: + raise ImportError("Neither fa3_fwd_interface nor flash_attn_interface found. 
Install FA3.") + + print(f"FlashAttention3 Benchmark (source: {fa3_source})") + print(f"batch: {batch}, head: {head}, headdim: {headdim}, dtype: {args.dtype}") + + for is_causal in [False, True]: + print(f"is_causal: {is_causal}") + for seq_len in sorted({1024, 2048, 4096, 8192, 16384, 32768}): + flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1) + q = torch.randn(batch, seq_len, head, headdim, dtype=dtype, device="cuda") + k = torch.randn(batch, seq_len, head, headdim, dtype=dtype, device="cuda") + v = torch.randn(batch, seq_len, head, headdim, dtype=dtype, device="cuda") + for i in range(5): flash_attn_func_v3(q, k, v, causal=is_causal) + torch.cuda.synchronize() + _, time = benchmark_forward(flash_attn_func_v3, q, k, v, causal=is_causal, repeats=100, verbose=False, desc='Triton') + print(f'{seq_len} flops:{flops/time.mean*1e-12}') + +# ============================================================ +# bench sageattn high-level API (what vllm-omni actually calls) +# ============================================================ +elif args.method == 'sageattn': + from sageattention import sageattn + + print(f"SageAttention (sageattn high-level API) Benchmark") + print(f"batch: {batch}, head: {head}, headdim: {headdim}, dtype: {args.dtype}") + + for is_causal in [False, True]: + print(f"is_causal: {is_causal}") + for seq_len in sorted({1024, 2048, 4096, 8192, 16384, 32768}): + flops = 4 * head * batch * headdim * seq_len * seq_len // (2 if is_causal else 1) + q = torch.randn(batch, seq_len, head, headdim, dtype=dtype, device="cuda") + k = torch.randn(batch, seq_len, head, headdim, dtype=dtype, device="cuda") + v = torch.randn(batch, seq_len, head, headdim, dtype=dtype, device="cuda") + for i in range(5): sageattn(q, k, v, tensor_layout="NHD", is_causal=is_causal) + torch.cuda.synchronize() + _, time = benchmark_forward(sageattn, q, k, v, tensor_layout="NHD", is_causal=is_causal, repeats=100, verbose=False, desc='SageAttn') + 
print(f'{seq_len} flops:{flops/time.mean*1e-12}') + +# ============================================================ +# bench_qk_int8_pv_fp16_cuda +# ============================================================ +elif args.method == 'sage_int8_fp16_cuda': + import sageattention._qattn_sm80 as qattn + + pv_accum = args.pv_accum_dtype or 'fp16' + assert pv_accum in ('fp16', 'fp16+fp32', 'fp32') + + WARP_Q = 16 if (headdim == 128 and pv_accum == "fp16+fp32") else 32 + WARP_K = 64 + + if pv_accum == 'fp32': + kernel = qattn.qk_int8_sv_f16_accum_f32_attn + elif pv_accum == 'fp16+fp32': + kernel = qattn.qk_int8_sv_f16_accum_f16_attn_inst_buf + elif pv_accum == 'fp16': + kernel = qattn.qk_int8_sv_f16_accum_f16_attn + + _qk_quant_gran = 3 if args.quant_gran == 'per_thread' else 2 + + print(f"CUDA QK Int8 PV FP16 Benchmark") + print(f"batch: {batch}, head: {head}, headdim: {headdim}, pv_accum_dtype: {pv_accum}") + + for is_causal in [False, True]: + _is_causal = 1 if is_causal else 0 + print(f"is_causal: {is_causal}") + for seq_len in sorted({1024, 2048, 4096, 8192, 16384, 32768}): + flops = 4 * head * batch * headdim * seq_len * seq_len / (2 if is_causal else 1) + + q = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8, device="cuda") + k = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8, device="cuda") + + if args.quant_gran == 'per_warp': + q_scale = torch.randn(batch, head, seq_len // WARP_Q, dtype=torch.float, device="cuda") + k_scale = torch.randn(batch, head, seq_len // WARP_K, dtype=torch.float, device="cuda") + elif args.quant_gran == 'per_thread': + q_scale = torch.randn(batch, head, seq_len // WARP_Q * 8, dtype=torch.float, device="cuda") + k_scale = torch.randn(batch, head, seq_len // WARP_K * 4, dtype=torch.float, device="cuda") + + v = torch.randn(batch, seq_len, head, headdim, dtype=torch.float16, device="cuda") + o = torch.empty(batch, seq_len, head, headdim, dtype=torch.float16, device="cuda") + sm_scale = 1 / 
(headdim ** 0.5) + for i in range(5): kernel(q, k, v, o, q_scale, k_scale, 0, _is_causal, _qk_quant_gran, sm_scale, 0) + torch.cuda.synchronize() + _, time = benchmark_forward(kernel, q, k, v, o, q_scale, k_scale, 0, _is_causal, _qk_quant_gran, sm_scale, 0, repeats=100, verbose=False, desc='Triton') + print(f'{seq_len} flops:{flops/time.mean*1e-12}') + +# ============================================================ +# bench_qk_int8_pv_fp16_triton +# ============================================================ +elif args.method == 'sage_int8_fp16_triton': + from sageattention.triton.attn_qk_int8_per_block import forward + from sageattention.triton.attn_qk_int8_per_block_causal import forward as forward_causal + + print(f"Triton QK Int8 PV FP16 Benchmark") + print(f"batch_size: {batch}, num_heads: {head}, head_dim: {headdim}") + + # non-causal + print("is_causal: False") + for seq_len in sorted({1024, 2048, 4096, 8192, 16384, 32768}): + flops = 4 * head * batch * headdim * seq_len * seq_len + + q = torch.randint(-100, 100, (batch, head, seq_len, headdim), dtype=torch.int8, device='cuda') + k = torch.randint(-100, 100, (batch, head, seq_len, headdim), dtype=torch.int8, device='cuda') + v = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device='cuda') + + q_scale = torch.randn(batch, head, (seq_len // 128), 1, dtype=torch.float16, device='cuda') + k_scale = torch.randn(batch, head, (seq_len // 64), 1, dtype=torch.float16, device='cuda') + + for i in range(5): forward(q, k, v, q_scale, k_scale, output_dtype=torch.bfloat16) + torch.cuda.synchronize() + _, time = benchmark_forward(forward, q, k, v, q_scale, k_scale, output_dtype=torch.bfloat16, repeats=100, verbose=False, desc='Triton') + print(f'{seq_len} flops:{flops/time.mean*1e-12}') + + # causal + print("is_causal: True") + for seq_len in sorted({1024, 2048, 4096, 8192, 16384, 32768}): + flops = 4 * head * batch * headdim * seq_len * seq_len // 2 + + q = torch.randint(-100, 100, (batch, head, 
seq_len, headdim), dtype=torch.int8, device='cuda') + k = torch.randint(-100, 100, (batch, head, seq_len, headdim), dtype=torch.int8, device='cuda') + v = torch.randn(batch, head, seq_len, headdim, dtype=torch.float16, device='cuda') + + q_scale = torch.randn(batch, head, (seq_len // 128), 1, dtype=torch.float16, device='cuda') + k_scale = torch.randn(batch, head, (seq_len // 64), 1, dtype=torch.float16, device='cuda') + + for i in range(5): forward_causal(q, k, v, q_scale, k_scale, output_dtype=torch.bfloat16) + torch.cuda.synchronize() + _, time = benchmark_forward(forward_causal, q, k, v, q_scale, k_scale, output_dtype=torch.bfloat16, repeats=100, verbose=False, desc='Triton') + print(f'{seq_len} flops:{flops/time.mean*1e-12}') + +# ============================================================ +# bench_qk_int8_pv_fp8_cuda (SM89 / RTX 4090) +# ============================================================ +elif args.method == 'sage_int8_fp8_cuda': + import sageattention._qattn_sm89 as qattn + + pv_accum = args.pv_accum_dtype or 'fp32+fp16' + assert pv_accum in ('fp32', 'fp32+fp32', 'fp32+fp16') + + cuda_major, cuda_minor = get_cuda_version() + if (cuda_major is None or (cuda_major, cuda_minor) < (12, 8)) and pv_accum == 'fp32+fp16': # get_cuda_version() yields (None, None) when nvcc is unavailable; treat that as < 12.8 instead of crashing on None < int + print("=============\n NOTE: cuda version < 12.8 (or not detected), not support pv_accum_dtype fp32+fp16.") + print(" Switch to 'fp32+fp32' automatically\n=============") + pv_accum = 'fp32+fp32' + + WARP_Q = 32 + WARP_K = 64 + + if pv_accum == 'fp32': + kernel = qattn.qk_int8_sv_f8_accum_f32_attn + elif pv_accum == 'fp32+fp32': + kernel = qattn.qk_int8_sv_f8_accum_f32_attn_inst_buf + elif pv_accum == 'fp32+fp16': + kernel = qattn.qk_int8_sv_f8_accum_f16_attn_inst_buf + + _qk_quant_gran = 3 if args.quant_gran == 'per_thread' else 2 + + print(f"CUDA QK Int8 PV FP8 Benchmark (SM89)") + print(f"batch: {batch}, head: {head}, headdim: {headdim}, pv_accum_dtype: {pv_accum}") + + for is_causal in [False, True]: + _is_causal = 1 if is_causal else 0 + print(f"is_causal: {is_causal}") +
for seq_len in sorted({1024, 2048, 4096, 8192, 16384, 32768}): + flops = 4 * head * batch * headdim * seq_len * seq_len / (2 if is_causal else 1) + + q = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8, device="cuda") + k = torch.randint(-95, 95, (batch, seq_len, head, headdim), dtype=torch.int8, device="cuda") + o = torch.empty(batch, seq_len, head, headdim, dtype=torch.float16, device="cuda") + + vm = torch.randn(batch, head, headdim, dtype=torch.float, device="cuda") + v_scale = torch.randn(batch, head, headdim, dtype=torch.float, device="cuda") + + if args.quant_gran == 'per_warp': + q_scale = torch.randn(batch, head, seq_len // WARP_Q, dtype=torch.float, device="cuda") + k_scale = torch.randn(batch, head, seq_len // WARP_K, dtype=torch.float, device="cuda") + elif args.quant_gran == 'per_thread': + q_scale = torch.randn(batch, head, seq_len // WARP_Q * 8, dtype=torch.float, device="cuda") + k_scale = torch.randn(batch, head, seq_len // WARP_K * 4, dtype=torch.float, device="cuda") + + v = torch.randn(batch, headdim, head, seq_len, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn) + sm_scale = 1 / (headdim ** 0.5) + for i in range(5): kernel(q, k, v, o, q_scale, k_scale, 0, _is_causal, _qk_quant_gran, sm_scale, 0) + torch.cuda.synchronize() + _, time = benchmark_forward(kernel, q, k, v, o, q_scale, k_scale, 0, _is_causal, _qk_quant_gran, sm_scale, 0, repeats=100, verbose=False, desc='Triton') + print(f'{seq_len} flops:{flops/time.mean*1e-12}') + +# ============================================================ +# bench_qk_int8_pv_fp8_cuda_sm90 (H100) +# ============================================================ +elif args.method == 'sage_int8_fp8_cuda_sm90': + import sageattention._qattn_sm90 as qattn + + pv_accum = args.pv_accum_dtype or 'fp32+fp32' + assert pv_accum == 'fp32+fp32', "pure fp32 accumulator is not supported for now" + + WARP_Q = 32 + WARP_K = 64 + + kernel = qattn.qk_int8_sv_f8_accum_f32_attn_inst_buf + + 
_qk_quant_gran = 3 if args.quant_gran == 'per_thread' else 2 + + print(f"CUDA QK Int8 PV FP8 SM90 Benchmark") + print(f"batch: {batch}, head: {head}, headdim: {headdim}") + + for is_causal in [False, True]: + _is_causal = 1 if is_causal else 0 + print(f"is_causal: {is_causal}") + for seq_len in sorted({1024, 2048, 4096, 8192, 16384, 32768}): + flops = 4 * head * batch * headdim * seq_len * seq_len / (2 if is_causal else 1) + + q = torch.randint(-95, 95, (batch, head, seq_len, headdim), dtype=torch.int8, device="cuda") + k = torch.randint(-95, 95, (batch, head, seq_len, headdim), dtype=torch.int8, device="cuda") + o = torch.empty(batch, head, seq_len, headdim, dtype=torch.float16, device="cuda") + + v_scale = torch.randn(batch, head, headdim, dtype=torch.float, device="cuda") + + if args.quant_gran == 'per_warp': + q_scale = torch.randn(batch, head, seq_len // 64 * 4, dtype=torch.float, device="cuda") + k_scale = torch.randn(batch, head, seq_len // 128, dtype=torch.float, device="cuda") + elif args.quant_gran == 'per_thread': + q_scale = torch.randn(batch, head, seq_len // 64 * 4 * 8, dtype=torch.float, device="cuda") + k_scale = torch.randn(batch, head, seq_len // 128 * 4, dtype=torch.float, device="cuda") + + v = torch.randn(batch, head, headdim, seq_len, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn) + sm_scale = 1 / (headdim ** 0.5) + for i in range(5): kernel(q, k, v, o, q_scale, k_scale, 1, _is_causal, _qk_quant_gran, sm_scale, 0) + torch.cuda.synchronize() + _, time = benchmark_forward(kernel, q, k, v, o, q_scale, k_scale, 1, _is_causal, _qk_quant_gran, sm_scale, 0, repeats=100, verbose=False, desc='Triton') + print(f'{seq_len} flops:{flops/time.mean*1e-12}') diff --git a/benchmarks/diffusion/bench_sage_comparison.sh b/benchmarks/diffusion/bench_sage_comparison.sh new file mode 100644 index 00000000000..0c5b20fc4c9 --- /dev/null +++ b/benchmarks/diffusion/bench_sage_comparison.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Benchmark: HunyuanVideo 1.5 480p — 
BF16 baseline vs SageAttention +# Resolution: 480×832, 33 frames +set -e + +MODEL="hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v" +PROMPT="A serene lakeside sunrise with mist over the water." +SCRIPT="examples/offline_inference/text_to_video/text_to_video.py" +OUTPUT_DIR="${OUTPUT_DIR:-/workspace}" + +COMMON_ARGS="--model $MODEL \ + --height 480 --width 832 --num-frames 33 \ + --num-inference-steps 50 \ + --guidance-scale 6.0 \ + --seed 42 \ + --vae-use-tiling \ + --enforce-eager" + +echo "============================================" +echo "=== 1/2: BF16 + FlashAttention (baseline)===" +echo "============================================" +DIFFUSION_ATTENTION_BACKEND=FLASH_ATTN \ + python $SCRIPT $COMMON_ARGS \ + --output "$OUTPUT_DIR/output_flash_attn.mp4" + +echo "" +echo "============================================" +echo "=== 2/2: BF16 + SageAttention ===" +echo "============================================" +DIFFUSION_ATTENTION_BACKEND=SAGE_ATTN \ + python $SCRIPT $COMMON_ARGS \ + --output "$OUTPUT_DIR/output_sage_attn.mp4" + +echo "" +echo "=== Done. Compare: output_flash_attn.mp4 vs output_sage_attn.mp4 ===" diff --git a/examples/offline_inference/text_to_image/text_to_image.py b/examples/offline_inference/text_to_image/text_to_image.py index 42e44abb890..3c6f715486e 100644 --- a/examples/offline_inference/text_to_image/text_to_image.py +++ b/examples/offline_inference/text_to_image/text_to_image.py @@ -168,6 +168,17 @@ def parse_args() -> argparse.Namespace: "Available layers: to_qkv, to_out, add_kv_proj, to_add_out, img_mlp, txt_mlp, proj_out. " "Example: --ignored-layers 'add_kv_proj,to_add_out'", ) + parser.add_argument( + "--kv-cache-dtype", + type=str, + default="auto", + choices=["auto", "fp8"], + help="Data type for attention Q/K/V quantization. " + "'fp8': dynamically quantize to float8_e4m3fn each forward pass. " + "On Hopper GPUs with FA3, enables native FP8 attention compute. 
" + "On other backends (FA2/SDPA), tensors are dequantized before the kernel. " + "'auto': no quantization (default).", + ) parser.add_argument( "--vae-use-slicing", action="store_true", @@ -313,10 +324,10 @@ def main(): lora_args["lora_path"] = args.lora_path print(f"Using LoRA from: {args.lora_path}") - # Build quantization kwargs: use quantization_config dict when - # ignored_layers is specified so the list flows through OmniDiffusionConfig + # Build quantization kwargs quant_kwargs: dict[str, Any] = {} ignored_layers = [s.strip() for s in args.ignored_layers.split(",") if s.strip()] if args.ignored_layers else None + kv_cache_dtype = args.kv_cache_dtype if args.kv_cache_dtype != "auto" else None if args.quantization == "gguf": if not args.gguf_model: raise ValueError("--gguf-model is required when --quantization gguf is set.") @@ -346,6 +357,7 @@ def main(): "mode": "text-to-image", "log_stats": args.log_stats, "enable_diffusion_pipeline_profiler": args.enable_diffusion_pipeline_profiler, + "kv_cache_dtype": kv_cache_dtype, **lora_args, **quant_kwargs, } @@ -367,6 +379,8 @@ def main(): print(f" Inference steps: {args.num_inference_steps}") print(f" Cache backend: {cache_backend if cache_backend else 'None (no acceleration)'}") print(f" Quantization: {args.quantization if args.quantization else 'None (BF16)'}") + if kv_cache_dtype: + print(f" KV cache dtype: {kv_cache_dtype}") if ignored_layers: print(f" Ignored layers: {ignored_layers}") print( diff --git a/examples/offline_inference/text_to_video/text_to_video.py b/examples/offline_inference/text_to_video/text_to_video.py index 322911c993d..a3e251331d7 100644 --- a/examples/offline_inference/text_to_video/text_to_video.py +++ b/examples/offline_inference/text_to_video/text_to_video.py @@ -185,6 +185,16 @@ def parse_args() -> argparse.Namespace: choices=["fp8", "gguf"], help="Quantization method for the transformer (fp8 for online FP8 quantization).", ) + parser.add_argument( + "--kv-cache-dtype", + type=str, + 
default="auto", + choices=["auto", "fp8"], + help="Data type for attention Q/K/V quantization. " + "'fp8': dynamically quantize to float8_e4m3fn each forward pass. " + "On Hopper GPUs with FA3, enables native FP8 attention compute. " + "'auto': no quantization (default).", + ) return parser.parse_args() @@ -227,6 +237,8 @@ def main(): # Check if profiling is requested via environment variable profiler_enabled = bool(os.getenv("VLLM_TORCH_PROFILER_DIR")) + kv_cache_dtype = args.kv_cache_dtype if args.kv_cache_dtype != "auto" else None + omni_kwargs = dict( model=args.model, enable_layerwise_offload=args.enable_layerwise_offload, @@ -239,6 +251,7 @@ def main(): cache_backend=args.cache_backend, cache_config=cache_config, enable_diffusion_pipeline_profiler=args.enable_diffusion_pipeline_profiler, + kv_cache_dtype=kv_cache_dtype, ) if args.boundary_ratio is not None: omni_kwargs["boundary_ratio"] = args.boundary_ratio diff --git a/tests/diffusion/quantization/test_kv_quant.py b/tests/diffusion/quantization/test_kv_quant.py new file mode 100644 index 00000000000..f793e656a15 --- /dev/null +++ b/tests/diffusion/quantization/test_kv_quant.py @@ -0,0 +1,233 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for FP8 Q/K/V quantization utilities.""" + +import pytest +import torch + +pytestmark = [pytest.mark.core_model, pytest.mark.diffusion] + + +def test_qkv_roundtrip_preserves_values(): + """quantize_qkv_fp8 -> dequantize_fp8 should preserve values within FP8 tolerance.""" + from vllm_omni.quantization.kv_quant import ( + dequantize_fp8, + quantize_qkv_fp8, + ) + + torch.manual_seed(42) + query = torch.randn(2, 128, 8, 64, dtype=torch.bfloat16) + key = torch.randn(2, 128, 8, 64, dtype=torch.bfloat16) + value = torch.randn(2, 128, 8, 64, dtype=torch.bfloat16) + + fp8_q, fp8_k, fp8_v, q_scale, k_scale, v_scale = quantize_qkv_fp8( + query, key, value + ) + + assert fp8_q.dtype == torch.float8_e4m3fn + 
assert fp8_k.dtype == torch.float8_e4m3fn + assert fp8_v.dtype == torch.float8_e4m3fn + assert q_scale.numel() == 1 + assert k_scale.numel() == 1 + assert v_scale.numel() == 1 + + query_rt = dequantize_fp8(fp8_q, q_scale, torch.bfloat16) + key_rt = dequantize_fp8(fp8_k, k_scale, torch.bfloat16) + value_rt = dequantize_fp8(fp8_v, v_scale, torch.bfloat16) + + # FP8 e4m3 keeps only 3 mantissa bits, so rounding error can reach ~6% relative; hence the loose tolerances + torch.testing.assert_close(query_rt, query, rtol=0.05, atol=0.05) + torch.testing.assert_close(key_rt, key, rtol=0.05, atol=0.05) + torch.testing.assert_close(value_rt, value, rtol=0.05, atol=0.05) + + +def test_kv_only_roundtrip(): + """quantize_kv_fp8 for joint attention path.""" + from vllm_omni.quantization.kv_quant import ( + dequantize_fp8, + quantize_kv_fp8, + ) + + torch.manual_seed(42) + key = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16) + value = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16) + + fp8_k, fp8_v, k_scale, v_scale = quantize_kv_fp8(key, value) + + assert fp8_k.dtype == torch.float8_e4m3fn + assert k_scale > 0 + assert v_scale > 0 + + key_rt = dequantize_fp8(fp8_k, k_scale, torch.bfloat16) + torch.testing.assert_close(key_rt, key, rtol=0.05, atol=0.05) + + +def test_scales_are_positive(): + from vllm_omni.quantization.kv_quant import quantize_qkv_fp8 + + q = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16) + k = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16) + v = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16) + + _, _, _, q_scale, k_scale, v_scale = quantize_qkv_fp8(q, k, v) + assert q_scale > 0 + assert k_scale > 0 + assert v_scale > 0 + + +def test_zero_tensor(): + """All-zero input should not produce NaN or Inf.""" + from vllm_omni.quantization.kv_quant import ( + dequantize_fp8, + quantize_qkv_fp8, + ) + + q = torch.zeros(1, 16, 4, 32, dtype=torch.bfloat16) + k = torch.zeros(1, 16, 4, 32, dtype=torch.bfloat16) + v = torch.zeros(1, 16, 4, 32, dtype=torch.bfloat16) + + fp8_q, fp8_k, fp8_v, q_s, k_s, v_s = quantize_qkv_fp8(q,
k, v) + q_rt = dequantize_fp8(fp8_q, q_s, torch.bfloat16) + k_rt = dequantize_fp8(fp8_k, k_s, torch.bfloat16) + + assert not torch.isnan(q_rt).any() + assert not torch.isnan(k_rt).any() + assert torch.allclose(q_rt, q) + assert torch.allclose(k_rt, k) + + +def test_fp16_input(): + """Should work with float16 input as well.""" + from vllm_omni.quantization.kv_quant import quantize_qkv_fp8 + + q = torch.randn(1, 32, 4, 64, dtype=torch.float16) + k = torch.randn(1, 32, 4, 64, dtype=torch.float16) + v = torch.randn(1, 32, 4, 64, dtype=torch.float16) + + fp8_q, fp8_k, fp8_v, _, _, _ = quantize_qkv_fp8(q, k, v) + assert fp8_q.dtype == torch.float8_e4m3fn + assert fp8_k.dtype == torch.float8_e4m3fn + assert fp8_v.dtype == torch.float8_e4m3fn + + +def test_kv_cache_dtype_config_field(): + """OmniDiffusionConfig should accept kv_cache_dtype field.""" + from vllm_omni.diffusion.data import OmniDiffusionConfig + + config = OmniDiffusionConfig(model="test", kv_cache_dtype="fp8") + assert config.kv_cache_dtype == "fp8" + + config_default = OmniDiffusionConfig(model="test") + assert config_default.kv_cache_dtype is None + + +def test_is_quantized_kv_cache(): + """is_quantized_kv_cache should detect FP8 dtype strings.""" + from vllm_omni.quantization.kv_quant import is_quantized_kv_cache + + assert is_quantized_kv_cache("fp8") is True + assert is_quantized_kv_cache("fp8_e4m3") is True + assert is_quantized_kv_cache(None) is False + assert is_quantized_kv_cache("auto") is False + assert is_quantized_kv_cache("bfloat16") is False + + +def test_attention_metadata_kv_cache_dtype(): + """AttentionMetadata should have kv_cache_dtype field.""" + from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata + + meta = AttentionMetadata() + assert meta.kv_cache_dtype is None + + meta.kv_cache_dtype = "fp8" + assert meta.kv_cache_dtype == "fp8" + + +def test_fast_qkv_quantization(): + """quantize_qkv_fp8_fast should use scale=1.0 (direct cast).""" + from 
vllm_omni.quantization.kv_quant import quantize_qkv_fp8_fast + + q = torch.randn(1, 32, 4, 64, dtype=torch.bfloat16) + k = torch.randn(1, 32, 4, 64, dtype=torch.bfloat16) + v = torch.randn(1, 32, 4, 64, dtype=torch.bfloat16) + + fp8_q, fp8_k, fp8_v, q_s, k_s, v_s = quantize_qkv_fp8_fast(q, k, v) + + assert fp8_q.dtype == torch.float8_e4m3fn + assert fp8_k.dtype == torch.float8_e4m3fn + assert fp8_v.dtype == torch.float8_e4m3fn + # Fast path uses scale=1.0 + assert q_s.item() == 1.0 + assert k_s.item() == 1.0 + assert v_s.item() == 1.0 + + +def test_fast_kv_quantization(): + """quantize_kv_fp8_fast for joint attention path.""" + from vllm_omni.quantization.kv_quant import quantize_kv_fp8_fast + + k = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16) + v = torch.randn(1, 64, 4, 32, dtype=torch.bfloat16) + + fp8_k, fp8_v, k_s, v_s = quantize_kv_fp8_fast(k, v) + + assert fp8_k.dtype == torch.float8_e4m3fn + assert fp8_v.dtype == torch.float8_e4m3fn + assert k_s.item() == 1.0 + assert v_s.item() == 1.0 + + +def test_flash_backend_supports_kv_cache_dtype(): + """FlashAttentionBackend should declare FP8 support.""" + from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionBackend + + assert FlashAttentionBackend.supports_kv_cache_dtype(None) is True + assert FlashAttentionBackend.supports_kv_cache_dtype("fp8") is True + assert FlashAttentionBackend.supports_kv_cache_dtype("fp8_e4m3") is True + assert FlashAttentionBackend.supports_kv_cache_dtype("mxfp8") is False + + +def test_sdpa_backend_does_not_support_fp8(): + """SDPABackend should not declare FP8 support.""" + from vllm_omni.diffusion.attention.backends.sdpa import SDPABackend + + assert SDPABackend.supports_kv_cache_dtype(None) is True + assert SDPABackend.supports_kv_cache_dtype("fp8") is False + + +def test_handle_kv_cache_dtype_clears_unsupported(): + """_handle_kv_cache_dtype should clear unsupported dtype to None.""" + from vllm_omni.diffusion.attention.backends.abstract import 
AttentionMetadata + from vllm_omni.diffusion.attention.backends.sdpa import SDPAImpl + + impl = SDPAImpl(num_heads=4, head_size=64, softmax_scale=0.125) + meta = AttentionMetadata(kv_cache_dtype="fp8") + + # SDPA has empty _supported_kv_cache_dtypes, should clear fp8 + impl._handle_kv_cache_dtype(meta, "cuda") + assert meta.kv_cache_dtype is None + + +def test_handle_kv_cache_dtype_preserves_supported(): + """_handle_kv_cache_dtype should preserve supported dtype.""" + from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata + from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl + + impl = FlashAttentionImpl(num_heads=4, head_size=64, softmax_scale=0.125) + meta = AttentionMetadata(kv_cache_dtype="fp8") + + impl._handle_kv_cache_dtype(meta, "cuda") + assert meta.kv_cache_dtype == "fp8" + + +def test_handle_kv_cache_dtype_clears_unsupported_platform(): + """FP8 on FlashAttention should be cleared for non-CUDA platforms.""" + from vllm_omni.diffusion.attention.backends.abstract import AttentionMetadata + from vllm_omni.diffusion.attention.backends.flash_attn import FlashAttentionImpl + + impl = FlashAttentionImpl(num_heads=4, head_size=64, softmax_scale=0.125) + meta = AttentionMetadata(kv_cache_dtype="fp8") + + # NPU not in FlashAttentionImpl._supported_kv_cache_dtypes + impl._handle_kv_cache_dtype(meta, "npu") + assert meta.kv_cache_dtype is None diff --git a/vllm_omni/diffusion/attention/backends/abstract.py b/vllm_omni/diffusion/attention/backends/abstract.py index 472fde422d5..b4a269d1660 100644 --- a/vllm_omni/diffusion/attention/backends/abstract.py +++ b/vllm_omni/diffusion/attention/backends/abstract.py @@ -6,9 +6,12 @@ from typing import Generic, TypeVar import torch +from vllm.logger import init_logger from vllm_omni.platforms import current_omni_platform +logger = init_logger(__name__) + class AttentionBackend(ABC): """Abstract class for diffusion attention backends.""" @@ -19,6 +22,15 @@ class 
AttentionBackend(ABC): def supports_attention_mask(cls) -> bool: return False + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: str | None) -> bool: + """Whether this backend supports the given KV cache quantization dtype. + + Override in subclasses that support quantized KV (e.g. FP8). + Default: only None (no quantization) is supported. + """ + return kv_cache_dtype is None + @staticmethod @abstractmethod def get_name() -> str: @@ -65,11 +77,25 @@ class AttentionMetadata: joint_strategy: str = "front" # the strategy to joint the query, key, and value, can be "front" or "rear" + # KV cache dtype for quantization (e.g. "fp8"). Each backend decides + # whether and how to quantize Q/K/V based on this field. + kv_cache_dtype: str | None = None + T = TypeVar("T", bound=AttentionMetadata) class AttentionImpl(ABC, Generic[T]): + + # Per-platform kv_cache_dtype support. Maps OmniPlatformEnum value + # (e.g. "cuda", "npu") to the set of quantized dtypes that platform + # handles. The base forward() checks this before dispatching and + # clears unsupported dtypes with a warning. + # + # To add FP8 support for a new platform in a subclass: + # _supported_kv_cache_dtypes = {"cuda": {"fp8"}, "npu": {"fp8"}} + _supported_kv_cache_dtypes: dict[str, set[str]] = {} + @abstractmethod def __init__( self, @@ -83,6 +109,42 @@ def __init__( ) -> None: raise NotImplementedError + # Platform enum value → forward method name. To add a new platform, + # implement forward_{name}() and register it here. + _PLATFORM_DISPATCH: dict[str, str] = { + "cuda": "forward_cuda", + "rocm": "forward_hip", + "npu": "forward_npu", + "xpu": "forward_xpu", + "musa": "forward_musa", + } + + def _handle_kv_cache_dtype( + self, + attn_metadata: T | None, + platform_key: str, + ) -> None: + """Check kv_cache_dtype compatibility for this platform. + + If the requested kv_cache_dtype is not in _supported_kv_cache_dtypes + for the current platform, it is cleared to None with a warning. 
+ """ + if attn_metadata is None: + return + kv_cache_dtype = attn_metadata.kv_cache_dtype + if kv_cache_dtype is None: + return + supported = self._supported_kv_cache_dtypes.get(platform_key, set()) + if kv_cache_dtype not in supported: + logger.warning_once( + "kv_cache_dtype='%s' requested but %s on %s does not support " + "it. Running in native dtype.", + kv_cache_dtype, + type(self).__name__, + platform_key, + ) + attn_metadata.kv_cache_dtype = None + def forward( self, query: torch.Tensor, @@ -91,18 +153,16 @@ def forward( attn_metadata: T | None = None, ) -> torch.Tensor: """Dispatch to platform-specific forward implementation.""" - if current_omni_platform.is_rocm(): - return self.forward_hip(query, key, value, attn_metadata) - elif current_omni_platform.is_cuda(): - return self.forward_cuda(query, key, value, attn_metadata) - elif current_omni_platform.is_npu(): - return self.forward_npu(query, key, value, attn_metadata) - elif current_omni_platform.is_xpu(): - return self.forward_xpu(query, key, value, attn_metadata) - elif current_omni_platform.is_musa(): - return self.forward_musa(query, key, value, attn_metadata) - else: - raise NotImplementedError(f"No forward implementation for platform: {current_omni_platform}") + platform_key = current_omni_platform.device_name + method_name = self._PLATFORM_DISPATCH.get(platform_key) + if method_name is None: + raise NotImplementedError( + f"No forward implementation for platform: {platform_key}. " + f"Register it in AttentionImpl._PLATFORM_DISPATCH." 
+ ) + self._handle_kv_cache_dtype(attn_metadata, platform_key) + method = getattr(self, method_name) + return method(query, key, value, attn_metadata) def forward_cuda( self, diff --git a/vllm_omni/diffusion/attention/backends/flash_attn.py b/vllm_omni/diffusion/attention/backends/flash_attn.py index 5c586c0631e..8495862b85c 100644 --- a/vllm_omni/diffusion/attention/backends/flash_attn.py +++ b/vllm_omni/diffusion/attention/backends/flash_attn.py @@ -20,6 +20,14 @@ class FlashAttentionBackend(AttentionBackend): def supports_attention_mask(cls) -> bool: return True + @classmethod + def supports_kv_cache_dtype(cls, kv_cache_dtype: str | None) -> bool: + if kv_cache_dtype is None: + return True + from vllm_omni.quantization.kv_quant import is_quantized_kv_cache + + return is_quantized_kv_cache(kv_cache_dtype) + @staticmethod def get_supported_head_sizes() -> list[int]: return [64, 96, 128, 192, 256] @@ -34,6 +42,15 @@ def get_impl_cls() -> type["FlashAttentionImpl"]: class FlashAttentionImpl(AttentionImpl): + # Per-platform FP8 KV quantization support. + # To enable FP8 on a new platform, add its OmniPlatformEnum value here + # and handle kv_cache_dtype in the corresponding forward_{platform}(). 
+ _supported_kv_cache_dtypes = { + "cuda": {"fp8", "fp8_e4m3"}, + # "rocm": {"fp8", "fp8_e4m3"}, + # "npu": {"fp8"}, + } + def __init__( self, num_heads: int, @@ -59,6 +76,9 @@ def _forward_varlen_masked( key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor, + q_descale: torch.Tensor | None = None, + k_descale: torch.Tensor | None = None, + v_descale: torch.Tensor | None = None, ) -> torch.Tensor: from vllm_omni.diffusion.attention.backends.utils.fa import ( _pad_input, @@ -73,6 +93,15 @@ def _forward_varlen_masked( query, key, value, attention_mask, query_length, _unpad_input ) + varlen_kwargs: dict = { + "causal": self.causal, + "softmax_scale": self.softmax_scale, + } + if q_descale is not None: + varlen_kwargs["q_descale"] = q_descale + varlen_kwargs["k_descale"] = k_descale + varlen_kwargs["v_descale"] = v_descale + out_unpad = flash_attn_varlen_func( q, k, @@ -81,10 +110,7 @@ def _forward_varlen_masked( cu_seqlens_k=cu_seq_lens_k, max_seqlen_q=max_length_q, max_seqlen_k=max_length_k, - **{ - "causal": self.causal, - "softmax_scale": self.softmax_scale, - }, + **varlen_kwargs, ) out_unpad = self._unwrap_flash_output(out_unpad) return _pad_input(out_unpad, indices_q, query.size(0), query_length) @@ -97,6 +123,11 @@ def forward_cuda( attn_metadata: AttentionMetadata = None, ) -> torch.Tensor: """CUDA/ROCm flash attention implementation.""" + from vllm_omni.quantization.kv_quant import is_quantized_kv_cache + + kv_cache_dtype = attn_metadata.kv_cache_dtype if attn_metadata else None + if is_quantized_kv_cache(kv_cache_dtype): + return self._forward_fp8(query, key, value, attn_metadata) from vllm_omni.diffusion.attention.backends.utils.fa import ( HAS_FLASH_ATTN, flash_attn_func, @@ -209,3 +240,68 @@ def forward_npu( layout="BNSD", ) return output + + @staticmethod + def _reshape_descale(scale: torch.Tensor, batch: int, num_heads_k: int) -> torch.Tensor: + """Reshape per-tensor scale to FA3's expected (batch, num_heads_k) shape.""" + return 
scale.view(1, 1).expand(batch, num_heads_k).contiguous() + + def _forward_fp8( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + """FP8 attention: quantize Q/K/V here, then use FA3 native or BF16 fallback. + + Quantization is owned by the backend so that: + 1. Non-FP8 backends (SDPA) never pay the quant/dequant cost. + 2. Each platform can plug in its own FP8 conversion logic. + """ + from vllm_omni.quantization.kv_quant import ( + quantize_kv_fp8_fast, + quantize_qkv_fp8_fast, + ) + + # Quantize Q/K/V using fast saturating cast (scale=1.0) + fp8_q, fp8_k, fp8_v, q_scale, k_scale, v_scale = quantize_qkv_fp8_fast( + query, key, value + ) + + # Also quantize joint K/V if present + if attn_metadata.joint_key is not None and attn_metadata.joint_value is not None: + jk, jv, _, _ = quantize_kv_fp8_fast( + attn_metadata.joint_key, attn_metadata.joint_value + ) + attn_metadata.joint_key = jk + attn_metadata.joint_value = jv + + B, S, H, D = key.shape + q_descale = self._reshape_descale(q_scale, B, H) + k_descale = self._reshape_descale(k_scale, B, H) + v_descale = self._reshape_descale(v_scale, B, H) + + # Primary path: FA3 native FP8 + from vllm_omni.diffusion.attention.backends.ring.ring_globals import ( + HAS_FA3, + fa3_attn_func, + ) + + if HAS_FA3 and fa3_attn_func is not None: + out = fa3_attn_func( + fp8_q, fp8_k, fp8_v, + softmax_scale=self.softmax_scale, + causal=self.causal, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + return self._unwrap_flash_output(out) + + # Fallback: no FA3, run standard BF16 path + logger.warning_once( + "No FA3 available for FP8 attention. Running in BF16." 
+ ) + attn_metadata.kv_cache_dtype = None + return self.forward_cuda(query, key, value, attn_metadata) diff --git a/vllm_omni/diffusion/attention/backends/sdpa.py b/vllm_omni/diffusion/attention/backends/sdpa.py index 3585689dd27..7222306d71d 100644 --- a/vllm_omni/diffusion/attention/backends/sdpa.py +++ b/vllm_omni/diffusion/attention/backends/sdpa.py @@ -97,6 +97,9 @@ def _forward_impl( attn_metadata: AttentionMetadata | None = None, mask_mode: SDPAMaskMode = "broadcast_k", ) -> torch.Tensor: + # Note: unsupported kv_cache_dtype is already warned and cleared + # by AttentionImpl._handle_kv_cache_dtype() in the base forward(). + # Normalize mask before permuting q/k/v. # _maybe_reshape_attn_mask expects sequence length on dim=1. attention_mask = None diff --git a/vllm_omni/diffusion/attention/layer.py b/vllm_omni/diffusion/attention/layer.py index 4fdf2ff1612..9bf2cd84d72 100644 --- a/vllm_omni/diffusion/attention/layer.py +++ b/vllm_omni/diffusion/attention/layer.py @@ -93,6 +93,11 @@ def __init__( # Fallback strategy when SP is not active (outside sharded regions) self._no_parallel_strategy = NoParallelAttention() + # KV cache quantization: resolved lazily in forward() because + # forward_context is not available during model loading. + self._kv_cache_dtype: str | None = None + self._kv_cache_dtype_resolved: bool = False + def _get_active_parallel_strategy(self): """Get the parallel strategy based on current SP active state. 
@@ -108,6 +113,34 @@ def _get_active_parallel_strategy(self): return self._no_parallel_strategy return self.parallel_strategy + def _resolve_kv_cache_dtype(self) -> str | None: + """Lazily resolve kv_cache_dtype from forward context.""" + if self._kv_cache_dtype_resolved: + return self._kv_cache_dtype + try: + config = get_forward_context().omni_diffusion_config + dtype = config.kv_cache_dtype + except Exception: + dtype = None + if dtype: + if not self.attn_backend.supports_kv_cache_dtype(dtype): + logger.warning( + "Attention backend %s does not support kv_cache_dtype='%s'. " + "KV quantization will be disabled.", + self.attn_backend.get_name(), + dtype, + ) + dtype = None + elif self.use_ring: + raise ValueError( + "FP8 KV quantization is not compatible with ring attention " + "(ring_degree > 1). Ring kernels do not propagate FP8 descale " + "factors. Use Ulysses SP instead." + ) + self._kv_cache_dtype = dtype + self._kv_cache_dtype_resolved = True + return dtype + def forward( self, query: torch.Tensor, @@ -123,6 +156,13 @@ def forward( # For Ring: Concat joint_q query, key, value, attn_metadata, ctx = strategy.pre_attention(query, key, value, attn_metadata) + # Signal KV quantization to backends via metadata + kv_cache_dtype = self._resolve_kv_cache_dtype() + if kv_cache_dtype: + if attn_metadata is None: + attn_metadata = AttentionMetadata() + attn_metadata.kv_cache_dtype = kv_cache_dtype + # 2. Kernel Execution (Computation) if self.use_ring and strategy is not self._no_parallel_strategy: out = self._run_ring_attention(query, key, value, attn_metadata) diff --git a/vllm_omni/diffusion/data.py b/vllm_omni/diffusion/data.py index 3071fd9d56a..a6290461649 100644 --- a/vllm_omni/diffusion/data.py +++ b/vllm_omni/diffusion/data.py @@ -498,6 +498,13 @@ class OmniDiffusionConfig: # Per-component: {"transformer": {"method": "fp8"}, "vae": None} quantization_config: str | QuantizationConfig | dict[str, Any] | None = None + # KV cache dtype for attention. 
Aligned with upstream vLLM's --kv-cache-dtype. + # None = native dtype (no quantization). + # "fp8" = dynamic FP8 (float8_e4m3fn) quantization per forward pass. + # On Hopper+FA3: native FP8 attention (memory + compute savings). + # On other backends: no benefit, backends skip quantization. + kv_cache_dtype: str | None = None + # Diffusion pipeline Profiling config enable_diffusion_pipeline_profiler: bool = False diff --git a/vllm_omni/engine/async_omni_engine.py b/vllm_omni/engine/async_omni_engine.py index 8cd2d695268..386879b90c9 100644 --- a/vllm_omni/engine/async_omni_engine.py +++ b/vllm_omni/engine/async_omni_engine.py @@ -940,6 +940,7 @@ def _create_default_diffusion_stage_cfg(kwargs: dict[str, Any]) -> list: "enable_multithread_weight_load": kwargs.get("enable_multithread_weight_load", True), "num_weight_load_threads": kwargs.get("num_weight_load_threads", 4), "quantization": kwargs.get("quantization", None), + "kv_cache_dtype": kwargs.get("kv_cache_dtype", None), "enable_diffusion_pipeline_profiler": kwargs.get("enable_diffusion_pipeline_profiler", False), **( { diff --git a/vllm_omni/platforms/cuda/platform.py b/vllm_omni/platforms/cuda/platform.py index 6bf740a0188..de1454632a4 100644 --- a/vllm_omni/platforms/cuda/platform.py +++ b/vllm_omni/platforms/cuda/platform.py @@ -73,6 +73,17 @@ def get_diffusion_attn_backend_cls( logger.info("Using diffusion attention backend '%s'", backend_upper) return backend.get_path() + # Prefer SageAttention (INT8 QK + FP8 PV on Hopper) when available + try: + import sageattention # noqa: F401 + sage_available = True + except ImportError: + sage_available = False + + if sage_available: + logger.info("Defaulting to diffusion attention backend SAGE_ATTN") + return DiffusionAttentionBackendEnum.SAGE_ATTN.get_path() + if flash_attn_supported: logger.info("Defaulting to diffusion attention backend FLASH_ATTN") return DiffusionAttentionBackendEnum.FLASH_ATTN.get_path() diff --git a/vllm_omni/quantization/kv_quant.py 
b/vllm_omni/quantization/kv_quant.py new file mode 100644 index 00000000000..73b83cd74a7 --- /dev/null +++ b/vllm_omni/quantization/kv_quant.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""FP8 quantization utilities for diffusion attention tensors. + +Provides per-tensor dynamic quantization of Q/K/V tensors to +float8_e4m3fn format. Designed for diffusion models where Q/K/V are +computed fresh each forward pass (no persistent KV cache). + +Supports two modes: + - Dynamic: computes amax per call (accurate but ~4ms overhead at 50K tokens) + - Static (delayed scaling): reuses a cached scale from the previous call, + skipping the expensive amax reduction (~0.5ms overhead). +""" + +import torch +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def is_quantized_kv_cache(kv_cache_dtype: str | None) -> bool: + """Check if the KV cache dtype implies quantized storage.""" + return kv_cache_dtype in ("fp8", "fp8_e4m3") + + +# Try to use vLLM's fused CUDA kernel for quantization. +# Falls back to device-agnostic PyTorch ops (works on any platform). +try: + from vllm._custom_ops import scaled_fp8_quant as _vllm_scaled_fp8_quant + + _HAS_FUSED_QUANT = True +except ImportError: + _HAS_FUSED_QUANT = False + + +def _quantize_tensor_fp8( + tensor: torch.Tensor, + cached_scale: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize a single tensor to FP8 with per-tensor scaling. + + Args: + tensor: Input tensor in BF16/FP16. + cached_scale: If provided, use this scale (static mode, skips amax). + If None, compute scale dynamically. + + Returns: + ``(fp8_tensor, inv_scale)`` where inv_scale is the dequant scale. 
+ """ + if _HAS_FUSED_QUANT and tensor.is_cuda: + orig_shape = tensor.shape + flat = tensor.reshape(-1, orig_shape[-1]) + # Pass cached_scale for static quant (no amax), None for dynamic + fp8_flat, scale = _vllm_scaled_fp8_quant(flat, scale=cached_scale) + fp8_out = fp8_flat.reshape(orig_shape) + return fp8_out, scale + else: + finfo = torch.finfo(torch.float8_e4m3fn) + if cached_scale is not None: + # Static: use cached scale directly + inv_scale = cached_scale + scale_factor = 1.0 / inv_scale + fp8 = (tensor * scale_factor).clamp(finfo.min, finfo.max).to( + torch.float8_e4m3fn + ) + return fp8, inv_scale + else: + # Dynamic: compute amax + amax = tensor.abs().amax().clamp(min=1e-12) + scale_factor = finfo.max / amax + fp8 = (tensor * scale_factor).clamp(finfo.min, finfo.max).to( + torch.float8_e4m3fn + ) + inv_scale = amax / finfo.max + return fp8, inv_scale + + +def quantize_qkv_fp8( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cached_scales: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """Quantize Q/K/V tensors to float8_e4m3fn. + + Args: + query: Query tensor in BF16/FP16, shape ``(B, S, H, D)`` + key: Key tensor in BF16/FP16, shape ``(B, S, H, D)`` + value: Value tensor in BF16/FP16, shape ``(B, S, H, D)`` + cached_scales: Optional ``(q_scale, k_scale, v_scale)`` from a + previous call. When provided, skips the expensive amax + reduction (static/delayed scaling mode). + + Returns: + ``(fp8_query, fp8_key, fp8_value, q_scale, k_scale, v_scale)`` + where scales are inverse (dequant) scales. 
+ """ + if cached_scales is not None: + cq, ck, cv = cached_scales + else: + cq = ck = cv = None + fp8_q, q_scale = _quantize_tensor_fp8(query, cq) + fp8_k, k_scale = _quantize_tensor_fp8(key, ck) + fp8_v, v_scale = _quantize_tensor_fp8(value, cv) + return fp8_q, fp8_k, fp8_v, q_scale, k_scale, v_scale + + +def quantize_kv_fp8( + key: torch.Tensor, + value: torch.Tensor, + cached_scales: tuple[torch.Tensor, torch.Tensor] | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Quantize K/V tensors to float8_e4m3fn (joint attention path). + + Returns: + ``(fp8_key, fp8_value, k_scale, v_scale)`` + """ + if cached_scales is not None: + ck, cv = cached_scales + else: + ck = cv = None + fp8_k, k_scale = _quantize_tensor_fp8(key, ck) + fp8_v, v_scale = _quantize_tensor_fp8(value, cv) + return fp8_k, fp8_v, k_scale, v_scale + + +def dequantize_fp8( + tensor: torch.Tensor, + inv_scale: torch.Tensor, + output_dtype: torch.dtype, +) -> torch.Tensor: + """Dequantize an FP8 tensor back to the given dtype. + + Args: + tensor: FP8-quantized tensor (float8_e4m3fn). + inv_scale: Inverse scale (dequant scale). + output_dtype: Target dtype (e.g. ``torch.bfloat16``). + + Returns: + Dequantized tensor: ``tensor.to(output_dtype) * inv_scale``. + """ + return (tensor.to(output_dtype) * inv_scale).to(output_dtype) + + +def quantize_qkv_fp8_fast( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor, torch.Tensor, torch.Tensor]: + """Ultra-fast FP8 quantization using direct saturating cast (no amax). + + For diffusion attention where Q/K/V values are typically in [-10, 10], + well within float8_e4m3fn range (±448). Eliminates the expensive + per-tensor amax reduction that dominates quantization overhead at + large sequence lengths (50K+ tokens). + + Scale is fixed at 1.0 (identity), so descale is also 1.0. 
+ """ + one = torch.ones(1, dtype=torch.float32, device=query.device) + fp8_q = query.to(torch.float8_e4m3fn) + fp8_k = key.to(torch.float8_e4m3fn) + fp8_v = value.to(torch.float8_e4m3fn) + return fp8_q, fp8_k, fp8_v, one, one, one + + +def quantize_kv_fp8_fast( + key: torch.Tensor, + value: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Fast FP8 quantization for K/V only (joint attention path).""" + one = torch.ones(1, dtype=torch.float32, device=key.device) + fp8_k = key.to(torch.float8_e4m3fn) + fp8_v = value.to(torch.float8_e4m3fn) + return fp8_k, fp8_v, one, one