diff --git a/benchmarks/attention_benchmarks/benchmark.py b/benchmarks/attention_benchmarks/benchmark.py index de56cbac8474..0329d110244c 100644 --- a/benchmarks/attention_benchmarks/benchmark.py +++ b/benchmarks/attention_benchmarks/benchmark.py @@ -59,7 +59,9 @@ def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult: """Run MLA benchmark with appropriate backend.""" from mla_runner import run_mla_benchmark as run_mla - return run_mla(config.backend, config, **kwargs) + return run_mla( + config.backend, config, prefill_backend=config.prefill_backend, **kwargs + ) def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult: @@ -440,14 +442,21 @@ def main(): # Backend selection parser.add_argument( "--backends", + "--decode-backends", nargs="+", - help="Backends to benchmark (flash, triton, flashinfer, cutlass_mla, " + help="Decode backends to benchmark (flash, triton, flashinfer, cutlass_mla, " "flashinfer_mla, flashattn_mla, flashmla)", ) parser.add_argument( "--backend", help="Single backend (alternative to --backends)", ) + parser.add_argument( + "--prefill-backends", + nargs="+", + help="Prefill backends to compare (fa2, fa3, fa4). " + "Uses the first decode backend for impl construction.", + ) # Batch specifications parser.add_argument( @@ -502,7 +511,7 @@ def main(): # Override args with YAML values, but CLI args take precedence # Check if CLI provided backends (they would be non-None and not default) - cli_backends_provided = args.backends is not None or args.backend is not None + cli_backends_provided = args.backend is not None or args.backends is not None # Backend(s) - only use YAML if CLI didn't specify if not cli_backends_provided: @@ -512,6 +521,12 @@ def main(): elif "backends" in yaml_config: args.backends = yaml_config["backends"] args.backend = None + elif "decode_backends" in yaml_config: + args.backends = yaml_config["decode_backends"] + args.backend = None + + # Prefill backends (e.g., ["fa3", "fa4"]) + args.prefill_backends = yaml_config.get("prefill_backends", None) # Check for special modes if "mode" in yaml_config: @@ -613,7 +628,10 @@ def main(): # Determine backends backends = args.backends or ([args.backend] if args.backend else ["flash"]) + prefill_backends = getattr(args, "prefill_backends", None) console.print(f"Backends: {', '.join(backends)}") + if prefill_backends: + console.print(f"Prefill backends: {', '.join(prefill_backends)}") console.print(f"Batch specs: {', '.join(args.batch_specs)}") console.print() @@ -850,37 +868,93 @@ def main(): else: # Normal mode: compare backends - total = len(backends) * len(args.batch_specs) + decode_results = [] + prefill_results = [] - with tqdm(total=total, desc="Benchmarking") as pbar: - for spec in args.batch_specs: - for backend in backends: - config = BenchmarkConfig( - backend=backend, - batch_spec=spec, - num_layers=args.num_layers, - head_dim=args.head_dim, - num_q_heads=args.num_q_heads, - num_kv_heads=args.num_kv_heads, - block_size=args.block_size, - device=args.device, - repeats=args.repeats, - warmup_iters=args.warmup_iters, - profile_memory=args.profile_memory, - ) + # Run decode backend comparison + if not prefill_backends: + # No prefill backends specified: compare decode backends as before + total = len(backends) * len(args.batch_specs) - result = run_benchmark(config) - all_results.append(result) + with tqdm(total=total, desc="Benchmarking") as pbar: + for spec in args.batch_specs: + for backend in backends: + config = BenchmarkConfig( + backend=backend, + 
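# NOTE: one BenchmarkConfig per (decode backend, batch spec) pair;
+                            # the remaining fields simply mirror the CLI args.
+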
batch_spec=spec, + num_layers=args.num_layers, + head_dim=args.head_dim, + num_q_heads=args.num_q_heads, + num_kv_heads=args.num_kv_heads, + block_size=args.block_size, + device=args.device, + repeats=args.repeats, + warmup_iters=args.warmup_iters, + profile_memory=args.profile_memory, + ) - if not result.success: - console.print(f"[red]Error {backend} {spec}: {result.error}[/]") + result = run_benchmark(config) + decode_results.append(result) - pbar.update(1) + if not result.success: + console.print( + f"[red]Error {backend} {spec}: {result.error}[/]" + ) - # Display results - console.print("\n[bold green]Results:[/]") - formatter = ResultsFormatter(console) - formatter.print_table(all_results, backends) + pbar.update(1) + + console.print("\n[bold green]Results:[/]") + formatter = ResultsFormatter(console) + formatter.print_table(decode_results, backends) + + # Run prefill backend comparison + if prefill_backends: + # Use first decode backend for impl construction + decode_backend = backends[0] + total = len(prefill_backends) * len(args.batch_specs) + + console.print( + f"[yellow]Prefill comparison mode: " + f"using {decode_backend} for decode impl[/]" + ) + + with tqdm(total=total, desc="Prefill benchmarking") as pbar: + for spec in args.batch_specs: + for pb in prefill_backends: + config = BenchmarkConfig( + backend=decode_backend, + batch_spec=spec, + num_layers=args.num_layers, + head_dim=args.head_dim, + num_q_heads=args.num_q_heads, + num_kv_heads=args.num_kv_heads, + block_size=args.block_size, + device=args.device, + repeats=args.repeats, + warmup_iters=args.warmup_iters, + profile_memory=args.profile_memory, + prefill_backend=pb, + ) + + result = run_benchmark(config) + + # Label result with prefill backend name for display + labeled_config = replace(result.config, backend=pb) + result = replace(result, config=labeled_config) + prefill_results.append(result) + + if not result.success: + console.print(f"[red]Error {pb} {spec}: {result.error}[/]") + + pbar.update(1) + + console.print("\n[bold green]Prefill Backend Results:[/]") + formatter = ResultsFormatter(console) + formatter.print_table( + prefill_results, prefill_backends, compare_to_fastest=True + ) + + all_results = decode_results + prefill_results # Save results if all_results: diff --git a/benchmarks/attention_benchmarks/common.py b/benchmarks/attention_benchmarks/common.py index 9fa22c8d54f0..208d6273c928 100644 --- a/benchmarks/attention_benchmarks/common.py +++ b/benchmarks/attention_benchmarks/common.py @@ -77,6 +77,7 @@ def __init__(self, num_heads: int, qk_nope_head_dim: int, v_head_dim: int): self.qk_nope_head_dim = qk_nope_head_dim self.v_head_dim = v_head_dim self.out_dim = qk_nope_head_dim + v_head_dim + self.weight = torch.empty(0, dtype=torch.bfloat16) def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]: """ @@ -213,6 +214,7 @@ class BenchmarkConfig: use_cuda_graphs: bool = False # MLA-specific + prefill_backend: str | None = None kv_lora_rank: int | None = None qk_nope_head_dim: int | None = None qk_rope_head_dim: int | None = None diff --git a/benchmarks/attention_benchmarks/configs/mla_prefill.yaml b/benchmarks/attention_benchmarks/configs/mla_prefill.yaml index ef6b2cb07dc7..122dbd783c5b 100644 --- a/benchmarks/attention_benchmarks/configs/mla_prefill.yaml +++ b/benchmarks/attention_benchmarks/configs/mla_prefill.yaml @@ -1,4 +1,19 @@ -# MLA prefill-only benchmark configuration for sparse backends +# MLA prefill backend comparison +# +# Compares all available MLA prefill backends: +# FA backends: 
fa2, fa3, fa4 (FlashAttention versions) +# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer) +# +# Uses cutlass_mla as the decode backend for impl construction +# (only the prefill path is exercised). +# +# Backends that aren't available on the current platform will report errors +# in the results table (e.g., fa3 on Blackwell, cudnn without artifactory). +# +# Usage: +# python benchmark.py --config configs/mla_prefill.yaml + +description: "MLA prefill backend comparison" model: name: "deepseek-v3" @@ -12,20 +27,25 @@ model: v_head_dim: 128 block_size: 128 -# Model parameter sweep: simulate tensor parallelism by varying num_q_heads -# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads -model_parameter_sweep: - param_name: "num_q_heads" - values: [128, 64, 32, 16] - label_format: "{backend}_{value}h" +# model: +# name: "deepseek-v2-lite" +# num_layers: 27 +# num_q_heads: 16 +# num_kv_heads: 1 +# head_dim: 576 +# kv_lora_rank: 512 +# qk_nope_head_dim: 128 +# qk_rope_head_dim: 64 +# v_head_dim: 128 +# block_size: 128 batch_specs: # Pure prefill - - "1q512" - - "1q1k" - - "1q2k" - - "1q4k" - - "1q8k" + - "q512" + - "q1k" + - "q2k" + - "q4k" + - "q8k" # Batched pure prefill - "2q512" @@ -44,19 +64,63 @@ batch_specs: - "8q4k" - "8q8k" - # Extend - - "1q512s4k" - - "1q512s8k" - - "1q1ks8k" - - "1q2ks8k" - - "1q2ks16k" - - "1q4ks16k" + # Chunked prefill / extend + # Short context + - "q128s1k" + - "q256s2k" + - "q512s4k" + - "q1ks4k" + - "q2ks8k" + - "2q128s1k" + - "2q256s2k" + - "2q512s4k" + - "2q1ks4k" + - "2q2ks8k" + - "4q128s1k" + - "4q256s2k" + - "4q512s4k" + - "4q1ks4k" + - "4q2ks8k" + - "8q128s1k" + - "8q256s2k" + - "8q512s4k" + - "8q1ks4k" + + # Medium context + - "q128s16k" + - "q512s16k" + - "q1ks16k" + - "q2ks16k" + - "2q128s16k" + - "2q512s16k" + - "2q1ks16k" + - "2q2ks16k" + - "4q128s16k" + - "4q512s16k" + - "4q1ks16k" + - "4q2ks16k" + + # Long context + - "q128s64k" + - "q512s64k" + - "q1ks64k" + - "q2ks64k" + - "2q128s64k" + - "2q512s64k" + - "2q1ks64k" + - "2q2ks64k" + +decode_backends: + - CUTLASS_MLA -backends: - - FLASHMLA_SPARSE - - FLASHINFER_MLA_SPARSE +prefill_backends: + - fa2 + - fa3 + - fa4 + - flashinfer + - cudnn + - trtllm device: "cuda:0" -repeats: 10 -warmup_iters: 3 -profile_memory: true +repeats: 20 +warmup_iters: 5 diff --git a/benchmarks/attention_benchmarks/configs/mla_sparse_prefill.yaml b/benchmarks/attention_benchmarks/configs/mla_sparse_prefill.yaml new file mode 100644 index 000000000000..ef6b2cb07dc7 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_sparse_prefill.yaml @@ -0,0 +1,62 @@ +# MLA prefill-only benchmark configuration for sparse backends + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 + num_kv_heads: 1 + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 + +# Model parameter sweep: simulate tensor parallelism by varying num_q_heads +# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads +model_parameter_sweep: + param_name: "num_q_heads" + values: [128, 64, 32, 16] + label_format: "{backend}_{value}h" + +batch_specs: + # Pure prefill + - "1q512" + - "1q1k" + - "1q2k" + - "1q4k" + - "1q8k" + + # Batched pure prefill + - "2q512" + - "2q1k" + - "2q2k" + - "2q4k" + - "2q8k" + - "4q512" + - "4q1k" + - "4q2k" + - "4q4k" + - "4q8k" + - "8q512" + - "8q1k" + - "8q2k" + - "8q4k" + - "8q8k" + + # Extend + - "1q512s4k" + - "1q512s8k" + - "1q1ks8k" + - "1q2ks8k" + - "1q2ks16k" + - "1q4ks16k" + +backends: + - 
FLASHMLA_SPARSE + - FLASHINFER_MLA_SPARSE + +device: "cuda:0" +repeats: 10 +warmup_iters: 3 +profile_memory: true diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py index 3c1ca4b3dade..0d612e374a12 100644 --- a/benchmarks/attention_benchmarks/mla_runner.py +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -62,6 +62,7 @@ def create_minimal_vllm_config( max_num_seqs: int = 256, mla_dims: dict | None = None, index_topk: int | None = None, + prefill_backend: str | None = None, ) -> VllmConfig: """ Create minimal VllmConfig for MLA benchmarks. @@ -75,6 +76,9 @@ def create_minimal_vllm_config( setup_mla_dims(model_name) index_topk: Optional topk value for sparse MLA backends. If provided, the config will include index_topk for sparse attention. + prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer", + "cudnn", "trtllm"). Configures the attention config to + force the specified prefill backend. Returns: VllmConfig for benchmarking @@ -163,7 +167,7 @@ def create_minimal_vllm_config( compilation_config = CompilationConfig() - return VllmConfig( + vllm_config = VllmConfig( model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, @@ -171,9 +175,84 @@ def create_minimal_vllm_config( compilation_config=compilation_config, ) + if prefill_backend is not None: + prefill_cfg = get_prefill_backend_config(prefill_backend) + if prefill_cfg["flash_attn_version"] is not None: + vllm_config.attention_config.flash_attn_version = prefill_cfg[ + "flash_attn_version" + ] + vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[ + "disable_flashinfer_prefill" + ] + vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[ + "use_cudnn_prefill" + ] + vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = prefill_cfg[ + "use_trtllm_ragged_deepseek_prefill" + ] + + return vllm_config + # ============================================================================ -# Backend Configuration +# Prefill Backend Configuration +# ============================================================================ + +# Maps prefill backend names to attention config overrides. +# FA backends set flash_attn_version and disable non-FA paths. +# Non-FA backends enable their specific path and disable others. 
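+#
+# For example, get_prefill_backend_config("fa3") (defined below) returns:
+#   {"flash_attn_version": 3, "disable_flashinfer_prefill": True,
+#    "use_cudnn_prefill": False, "use_trtllm_ragged_deepseek_prefill": False}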
+_PREFILL_BACKEND_CONFIG: dict[str, dict] = { + "fa2": { + "flash_attn_version": 2, + "disable_flashinfer_prefill": True, + "use_cudnn_prefill": False, + "use_trtllm_ragged_deepseek_prefill": False, + }, + "fa3": { + "flash_attn_version": 3, + "disable_flashinfer_prefill": True, + "use_cudnn_prefill": False, + "use_trtllm_ragged_deepseek_prefill": False, + }, + "fa4": { + "flash_attn_version": 4, + "disable_flashinfer_prefill": True, + "use_cudnn_prefill": False, + "use_trtllm_ragged_deepseek_prefill": False, + }, + "flashinfer": { + "flash_attn_version": None, + "disable_flashinfer_prefill": False, + "use_cudnn_prefill": False, + "use_trtllm_ragged_deepseek_prefill": False, + }, + "cudnn": { + "flash_attn_version": None, + "disable_flashinfer_prefill": True, + "use_cudnn_prefill": True, + "use_trtllm_ragged_deepseek_prefill": False, + }, + "trtllm": { + "flash_attn_version": None, + "disable_flashinfer_prefill": True, + "use_cudnn_prefill": False, + "use_trtllm_ragged_deepseek_prefill": True, + }, +} + + +def get_prefill_backend_config(prefill_backend: str) -> dict: + """Get attention config overrides for a prefill backend.""" + if prefill_backend not in _PREFILL_BACKEND_CONFIG: + raise ValueError( + f"Unknown prefill backend: {prefill_backend!r}. " + f"Available: {list(_PREFILL_BACKEND_CONFIG.keys())}" + ) + return _PREFILL_BACKEND_CONFIG[prefill_backend] + + +# ============================================================================ +# Decode Backend Configuration # ============================================================================ @@ -203,6 +282,7 @@ def _get_backend_config(backend: str) -> dict: Returns: Dict with backend configuration """ + from vllm.v1.attention.backend import MultipleOf from vllm.v1.attention.backends.registry import AttentionBackendEnum try: @@ -219,8 +299,8 @@ def _get_backend_config(backend: str) -> dict: block_sizes = backend_class.get_supported_kernel_block_sizes() # Use first supported block size (backends typically support one for MLA) block_size = block_sizes[0] if block_sizes else None - if hasattr(block_size, "value"): - # Handle MultipleOf enum + if isinstance(block_size, MultipleOf): + # No fixed block size; fall back to config value block_size = None # Check if sparse via class method if available @@ -676,16 +756,11 @@ def _run_single_benchmark( if is_sparse and indexer is not None: indexer.fill_random_indices(total_q, max_kv_len) - # Determine which forward method to use - if is_sparse: - # Sparse backends use forward_mqa + # Determine which forward method to use based on metadata + if metadata.decode is not None: forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer) - elif metadata.decode is not None: - forward_fn = lambda: impl._forward_decode( - decode_inputs, kv_cache, metadata, layer - ) elif metadata.prefill is not None: - forward_fn = lambda: impl._forward_prefill( + forward_fn = lambda: impl.forward_mha( prefill_inputs["q"], prefill_inputs["k_c_normed"], prefill_inputs["k_pe"], @@ -732,6 +807,7 @@ def _run_mla_benchmark_batched( backend: str, configs_with_params: list[tuple], # [(config, threshold, num_splits), ...] index_topk: int = 2048, + prefill_backend: str | None = None, ) -> list[BenchmarkResult]: """ Unified batched MLA benchmark runner for all backends. @@ -743,11 +819,13 @@ def _run_mla_benchmark_batched( to avoid setup/teardown overhead. 
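+
+    Example (sketch; "configs" stands for a list of BenchmarkConfig values
+    built elsewhere in this module):
+
+        results = _run_mla_benchmark_batched(
+            "cutlass_mla",
+            [(cfg, None, None) for cfg in configs],
+            prefill_backend="fa3",
+        )
+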
Args: - backend: Backend name + backend: Backend name (decode backend used for impl construction) configs_with_params: List of (config, threshold, num_splits) tuples - threshold: reorder_batch_threshold (FlashAttn/FlashMLA only) - num_splits: num_kv_splits (CUTLASS only) index_topk: Topk value for sparse MLA backends (default 2048) + prefill_backend: Prefill backend name (e.g., "fa3", "fa4"). + When set, forces the specified FlashAttention version for prefill. Returns: List of BenchmarkResult objects @@ -780,11 +858,25 @@ def _run_mla_benchmark_batched( block_size=block_size, mla_dims=mla_dims, # Use custom dims from config or default index_topk=index_topk if is_sparse else None, + prefill_backend=prefill_backend, ) results = [] with set_current_vllm_config(vllm_config): + # Clear cached prefill backend detection functions so they re-evaluate + # with the current VllmConfig. These are @functools.cache decorated and + # would otherwise return stale results from a previous backend's config. + from vllm.model_executor.layers.attention.mla_attention import ( + use_cudnn_prefill, + use_flashinfer_prefill, + use_trtllm_ragged_deepseek_prefill, + ) + + use_flashinfer_prefill.cache_clear() + use_cudnn_prefill.cache_clear() + use_trtllm_ragged_deepseek_prefill.cache_clear() + # Create backend impl, layer, builder, and indexer (reused across benchmarks) impl, layer, builder_instance, indexer = _create_backend_impl( backend_cfg, @@ -794,6 +886,38 @@ def _run_mla_benchmark_batched( index_topk=index_topk if is_sparse else None, ) + # Verify the actual prefill backend matches what was requested + if prefill_backend is not None: + prefill_cfg = get_prefill_backend_config(prefill_backend) + fa_version = prefill_cfg["flash_attn_version"] + + if fa_version is not None: + # FA backend: verify the impl's FA version + actual_fa_version = getattr(impl, "vllm_flash_attn_version", None) + if actual_fa_version != fa_version: + raise RuntimeError( + f"Prefill backend '{prefill_backend}' requested FA " + f"version {fa_version}, but the impl is using FA " + f"version {actual_fa_version}. Check " + f"vllm/v1/attention/backends/fa_utils.py." + ) + else: + # Non-FA backend: verify the builder picked the right path + expected_flags = { + "flashinfer": "_use_fi_prefill", + "cudnn": "_use_cudnn_prefill", + "trtllm": "_use_trtllm_ragged_prefill", + } + flag_name = expected_flags.get(prefill_backend) + if flag_name and not getattr(builder_instance, flag_name, False): + raise RuntimeError( + f"Prefill backend '{prefill_backend}' was requested " + f"but the metadata builder did not enable it. This " + f"usually means a dependency is missing (e.g., " + f"flashinfer not installed) or the platform doesn't " + f"support it." + ) + # Run each benchmark with the shared impl for config, threshold, num_splits in configs_with_params: # Set threshold for this benchmark (FlashAttn/FlashMLA only) @@ -844,6 +968,7 @@ def run_mla_benchmark( reorder_batch_threshold: int | None = None, num_kv_splits: int | None = None, index_topk: int = 2048, + prefill_backend: str | None = None, ) -> BenchmarkResult | list[BenchmarkResult]: """ Unified MLA benchmark runner for all backends. @@ -861,6 +986,8 @@ def run_mla_benchmark( (single config mode only) num_kv_splits: Number of KV splits for CUTLASS (single config mode only) index_topk: Topk value for sparse MLA backends (default 2048) + prefill_backend: Prefill backend name (e.g., "fa3", "fa4"). + When set, forces the specified FlashAttention version for prefill. 
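+
+    Example (sketch, mirroring the call in benchmark.py):
+
+        result = run_mla_benchmark(
+            config.backend, config, prefill_backend=config.prefill_backend
+        )
+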
Returns: BenchmarkResult (single mode) or list of BenchmarkResult (batched mode) @@ -884,7 +1011,9 @@ def run_mla_benchmark( return_single = True # Use unified batched execution - results = _run_mla_benchmark_batched(backend, configs_with_params, index_topk) + results = _run_mla_benchmark_batched( + backend, configs_with_params, index_topk, prefill_backend=prefill_backend + ) # Return single result or list based on input return results[0] if return_single else results diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index dd184e38eb5e..a7e9e6ff5545 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -39,7 +39,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2 + GIT_TAG 1488682bb545f7d020e958a33116b1419d1cfc83 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 40108e490740..a8d2fd687fff 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -213,5 +213,5 @@ configuration. | `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | +| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | | `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any | diff --git a/requirements/common.txt b/requirements/common.txt index 5e156edb75b0..893d6727ddcf 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -24,7 +24,7 @@ outlines_core == 0.2.11 # required for outlines backend disk cache diskcache == 5.6.3 lark == 1.2.2 -xgrammar == 0.1.29; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" or platform_machine == "s390x" or platform_machine == "ppc64le" +xgrammar >= 0.1.30; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" or platform_machine == "s390x" or platform_machine == "ppc64le" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/tests/entrypoints/openai/parser/test_harmony_utils.py b/tests/entrypoints/openai/parser/test_harmony_utils.py index 7842a1fcd757..21b53dff1507 100644 --- a/tests/entrypoints/openai/parser/test_harmony_utils.py +++ b/tests/entrypoints/openai/parser/test_harmony_utils.py @@ -14,6 +14,7 @@ parse_chat_output, ) from vllm.entrypoints.openai.responses.harmony import ( + response_input_to_harmony, response_previous_input_to_harmony, ) @@ -841,3 +842,89 @@ def test_all_standard_channels_present(self) -> None: assert channel in valid_channels, ( f"{channel} missing when with_custom_tools={with_tools}" ) + + +class TestResponseInputToHarmonyReasoningItem: + """Tests for response_input_to_harmony handling of reasoning input items. 
+ + Per the OpenAI spec, ResponseReasoningItem.content is + Optional[List[Content]] = None. Clients like langchain-openai may omit + this field when constructing multi-turn input from previous responses. + + Reasoning items with content are converted to Harmony messages on the + 'analysis' channel. All content items are concatenated. Items without + content return None (skipped by the caller). + """ + + def test_reasoning_with_single_content(self): + """Test reasoning item with a single content entry.""" + item = { + "type": "reasoning", + "id": "rs_123", + "content": [{"type": "reasoning_text", "text": "Thinking step by step"}], + } + + msg = response_input_to_harmony(item, prev_responses=[]) + + assert msg is not None + assert msg.author.role == Role.ASSISTANT + assert msg.content[0].text == "Thinking step by step" + assert msg.channel == "analysis" + + def test_reasoning_with_multiple_content_items(self): + """Test reasoning item with multiple content entries concatenated.""" + item = { + "type": "reasoning", + "id": "rs_123", + "content": [ + {"type": "reasoning_text", "text": "First, let me analyze"}, + {"type": "reasoning_text", "text": "Second, I should consider"}, + {"type": "reasoning_text", "text": "Finally, the answer is"}, + ], + } + + msg = response_input_to_harmony(item, prev_responses=[]) + + assert msg is not None + assert msg.author.role == Role.ASSISTANT + assert msg.content[0].text == ( + "First, let me analyze\nSecond, I should consider\nFinally, the answer is" + ) + assert msg.channel == "analysis" + + def test_reasoning_without_content_returns_none(self): + """Test reasoning item without content field returns None.""" + item = { + "type": "reasoning", + "id": "rs_123", + "summary": [{"type": "summary_text", "text": "Thinking about math"}], + } + + msg = response_input_to_harmony(item, prev_responses=[]) + + assert msg is None + + def test_reasoning_with_none_content_returns_none(self): + """Test reasoning item with content=None returns None.""" + item = { + "type": "reasoning", + "id": "rs_123", + "content": None, + "summary": [{"type": "summary_text", "text": "Thinking about math"}], + } + + msg = response_input_to_harmony(item, prev_responses=[]) + + assert msg is None + + def test_reasoning_with_empty_content_returns_none(self): + """Test reasoning item with empty content list returns None.""" + item = { + "type": "reasoning", + "id": "rs_123", + "content": [], + } + + msg = response_input_to_harmony(item, prev_responses=[]) + + assert msg is None diff --git a/tests/kernels/attention/test_triton_decode_attention.py b/tests/kernels/attention/test_triton_decode_attention.py index f6b066a7bd1e..a9b881629441 100644 --- a/tests/kernels/attention/test_triton_decode_attention.py +++ b/tests/kernels/attention/test_triton_decode_attention.py @@ -90,3 +90,137 @@ def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): ) assert torch.allclose(o, o1) + + +def _quantize_to_fp8(tensor: torch.Tensor): + """Quantize a BF16 tensor to FP8 e4m3fn with per-tensor scale. 
+ + Returns (fp8_tensor, scale) where: + fp8_tensor ≈ tensor / scale (stored as float8_e4m3fn) + tensor ≈ fp8_tensor.to(float32) * scale (dequantized) + """ + amax = tensor.abs().amax() + # float8_e4m3fn max representable value is 448.0 + scale = (amax / 448.0).clamp(min=1e-12).to(torch.float32) + fp8_tensor = ( + (tensor.to(torch.float32) / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + ) + return fp8_tensor, scale + + +@pytest.mark.parametrize("B", [3]) +@pytest.mark.parametrize("L", [1025]) +@pytest.mark.parametrize("H_Q", [32]) +@pytest.mark.parametrize("H_KV", [32, 8]) +@pytest.mark.parametrize("D_QK", [128, 576]) +@pytest.mark.parametrize("D_V", [128, 512]) +@pytest.mark.parametrize("CACHE_SIZE", [16384]) +@pytest.mark.parametrize("PAGE_SIZE", [1, 16]) +def test_decode_attention_fp8(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE): + """Test FP8 KV cache path: quantize K/V to FP8, run kernel with scales, + and compare against BF16 reference output.""" + assert CACHE_SIZE % PAGE_SIZE == 0 + dtype = torch.bfloat16 + seq_len = L + sm_scale = 1.0 / (D_QK**0.5) + num_kv_splits = 8 + + num_pages_per_batch = cdiv(seq_len, PAGE_SIZE) + req_to_page = torch.randint( + 0, CACHE_SIZE // PAGE_SIZE, (B, num_pages_per_batch, 1), device="cuda" + ) + req_to_token = req_to_page * PAGE_SIZE + req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE) + req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(1, 1, -1) + req_to_token = req_to_token.view(B, -1) + req_to_token = req_to_token[:, :seq_len].contiguous() + + q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda") + + # Create BF16 K/V as reference + k_bf16 = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda") + v_bf16 = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda") + + # --- BF16 reference --- + o_ref = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + lse_ref = torch.zeros(B, H_Q, dtype=dtype, device="cuda") + attn_logits = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda" + ) + + if PAGE_SIZE == 1: + decode_attention_fwd( + q, + k_bf16, + v_bf16, + o_ref, + lse_ref, + req_to_token, + b_seq_len=torch.full((B,), seq_len, device="cuda"), + attn_logits=attn_logits, + num_kv_splits=num_kv_splits, + sm_scale=sm_scale, + ) + else: + k_paged = k_bf16.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK) + v_paged = v_bf16.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V) + decode_attention_fwd( + q, + k_paged, + v_paged, + o_ref, + lse_ref, + req_to_page, + b_seq_len=torch.full((B,), seq_len, device="cuda"), + attn_logits=attn_logits, + num_kv_splits=num_kv_splits, + sm_scale=sm_scale, + page_size=PAGE_SIZE, + ) + + # --- FP8 path --- + k_fp8, k_scale = _quantize_to_fp8(k_bf16) + v_fp8, v_scale = _quantize_to_fp8(v_bf16) + + o_fp8 = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + lse_fp8 = torch.zeros(B, H_Q, dtype=dtype, device="cuda") + attn_logits_fp8 = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), dtype=torch.float32, device="cuda" + ) + + if PAGE_SIZE == 1: + decode_attention_fwd( + q, + k_fp8, + v_fp8, + o_fp8, + lse_fp8, + req_to_token, + b_seq_len=torch.full((B,), seq_len, device="cuda"), + attn_logits=attn_logits_fp8, + num_kv_splits=num_kv_splits, + sm_scale=sm_scale, + k_scale=k_scale, + v_scale=v_scale, + ) + else: + k_fp8_paged = k_fp8.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK) + v_fp8_paged = v_fp8.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V) + decode_attention_fwd( + q, + k_fp8_paged, + v_fp8_paged, + 
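# paged FP8 path: same call as the BF16 reference above, but with
+            # FP8 K/V pages and per-tensor k_scale/v_scale; outputs stay BF16.
+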
o_fp8,
+            lse_fp8,
+            req_to_page,
+            b_seq_len=torch.full((B,), seq_len, device="cuda"),
+            attn_logits=attn_logits_fp8,
+            num_kv_splits=num_kv_splits,
+            sm_scale=sm_scale,
+            page_size=PAGE_SIZE,
+            k_scale=k_scale,
+            v_scale=v_scale,
+        )
+
+    # FP8 tolerances match test_mla_backends.py test_backend_correctness.
+    torch.testing.assert_close(o_ref, o_fp8, atol=5e-1, rtol=1e-2)
diff --git a/tests/kernels/test_fused_recurrent_packed_decode.py b/tests/kernels/test_fused_recurrent_packed_decode.py
index f81f3c776e98..d63186bde118 100644
--- a/tests/kernels/test_fused_recurrent_packed_decode.py
+++ b/tests/kernels/test_fused_recurrent_packed_decode.py
@@ -10,7 +10,7 @@
 )
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA device")
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
 @pytest.mark.parametrize("strided_mixed_qkv", [False, True])
 def test_fused_recurrent_packed_decode_matches_reference(
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index fdc468d3b25d..d3d5b80742cb 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -30,6 +30,6 @@ def register_fake(fn):
 
 
 # page attention ops
 def paged_attention_v1(
     out: torch.Tensor,
     query: torch.Tensor,
@@ -48,3363 +48,20 @@
     tp_rank: int = 0,
     blocksparse_local_blocks: int = 0,
     blocksparse_vert_stride: int = 0,
-    blocksparse_block_size: int = 64,
-    blocksparse_head_sliding_step: int = 0,
-) -> None:
-    torch.ops._C.paged_attention_v1(
-        out,
-        query,
-        key_cache,
-        value_cache,
-        num_kv_heads,
-        scale,
-        block_tables,
-        seq_lens,
-        block_size,
-        max_seq_len,
-        alibi_slopes,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-        tp_rank,
-        blocksparse_local_blocks,
-        blocksparse_vert_stride,
-        blocksparse_block_size,
-        blocksparse_head_sliding_step,
-    )
-
-
-def paged_attention_v2(
-    out: torch.Tensor,
-    exp_sum: torch.Tensor,
-    max_logits: torch.Tensor,
-    tmp_out: torch.Tensor,
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    num_kv_heads: int,
-    scale: float,
-    block_tables: torch.Tensor,
-    seq_lens: torch.Tensor,
-    block_size: int,
-    max_seq_len: int,
-    alibi_slopes: torch.Tensor | None,
-    kv_cache_dtype: str,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-    tp_rank: int = 0,
-    blocksparse_local_blocks: int = 0,
-    blocksparse_vert_stride: int = 0,
-    blocksparse_block_size: int = 64,
-    blocksparse_head_sliding_step: int = 0,
-) -> None:
-    torch.ops._C.paged_attention_v2(
-        out,
-        exp_sum,
-        max_logits,
-        tmp_out,
-        query,
-        key_cache,
-        value_cache,
-        num_kv_heads,
-        scale,
-        block_tables,
-        seq_lens,
-        block_size,
-        max_seq_len,
-        alibi_slopes,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-        tp_rank,
-        blocksparse_local_blocks,
-        blocksparse_vert_stride,
-        blocksparse_block_size,
-        blocksparse_head_sliding_step,
-    )
-
-
-def paged_attention_rocm(
-    out: torch.Tensor,
-    exp_sum: torch.Tensor,
-    max_logits: torch.Tensor,
-    tmp_out: torch.Tensor,
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    num_kv_heads: int,
-    scale: float,
-    block_tables: torch.Tensor,
-    seq_lens: torch.Tensor,
-    query_start_loc: torch.Tensor | None,
-    block_size: int,
-    max_seq_len: int,
-    alibi_slopes: torch.Tensor | None,
-    kv_cache_dtype: str,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-    fp8_out_scale: torch.Tensor | None = None,
-    mfma_type: str = "fp8" if envs.VLLM_ROCM_FP8_MFMA_PAGE_ATTN else "f16",
-) -> 
None: - torch.ops._rocm_C.paged_attention( - out, - exp_sum, - max_logits, - tmp_out, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - query_start_loc, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - fp8_out_scale, - mfma_type, - ) - - -def mla_decode_kvcache_cpu( - out: torch.Tensor, - query: torch.Tensor, - kv_cache: torch.Tensor, - scale: float, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, -) -> None: - torch.ops._C.mla_decode_kvcache(out, query, kv_cache, scale, block_tables, seq_lens) - - -# merge attn states ops -def merge_attn_states( - output: torch.Tensor, - prefix_output: torch.Tensor, - prefix_lse: torch.Tensor, - suffix_output: torch.Tensor, - suffix_lse: torch.Tensor, - output_lse: torch.Tensor | None = None, -) -> None: - torch.ops._C.merge_attn_states( - output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse - ) - - -def convert_vertical_slash_indexes( - q_seqlens: torch.Tensor, # [BATCH, ] - kv_seqlens: torch.Tensor, # [BATCH, ] - vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] - slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] - context_size: int, - block_size_M: int, - block_size_N: int, - causal: bool = True, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - batch_size = slash_indexes.size(0) - num_heads = slash_indexes.size(1) - nnz_slash = slash_indexes.size(2) - nnz_vertical = vertical_indexes.size(2) - num_rows = (context_size + block_size_M - 1) // block_size_M - - block_count = torch.zeros( - batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device - ) - block_offset = torch.zeros( - batch_size, - num_heads, - num_rows, - nnz_slash, - dtype=q_seqlens.dtype, - device=q_seqlens.device, - ) - column_count = torch.zeros( - batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device - ) - column_index = torch.zeros( - batch_size, - num_heads, - num_rows, - nnz_vertical, - dtype=q_seqlens.dtype, - device=q_seqlens.device, - ) - - torch.ops._C.convert_vertical_slash_indexes( - block_count, - block_offset, - column_count, - column_index, - q_seqlens, - kv_seqlens, - vertical_indexes, - slash_indexes, - context_size, - block_size_M, - block_size_N, - causal, - ) - return block_count, block_offset, column_count, column_index - - -def convert_vertical_slash_indexes_mergehead( - q_seqlens: torch.Tensor, # [BATCH, ] - kv_seqlens: torch.Tensor, # [BATCH, ] - vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V] - slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S] - # [N_HEADS] : different head use different number of indices - vertical_indices_count: torch.Tensor, - slash_indices_count: torch.Tensor, - context_size: int, - block_size_M: int, - block_size_N: int, - causal: bool = True, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - batch_size = slash_indexes.size(0) - num_heads = slash_indexes.size(1) - nnz_slash = slash_indexes.size(2) - nnz_vertical = vertical_indexes.size(2) - num_rows = (context_size + block_size_M - 1) // block_size_M - - block_count = torch.empty( - batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device - ) - block_offset = torch.empty( - batch_size, - num_heads, - num_rows, - nnz_slash, - dtype=q_seqlens.dtype, - device=q_seqlens.device, - ) - column_count = torch.empty( - batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device - ) - column_index = torch.empty( - batch_size, - num_heads, - 
num_rows, - nnz_vertical, - dtype=q_seqlens.dtype, - device=q_seqlens.device, - ) - - torch.ops._C.convert_vertical_slash_indexes_mergehead( - block_count, - block_offset, - column_count, - column_index, - q_seqlens, - kv_seqlens, - vertical_indexes, - slash_indexes, - vertical_indices_count, - slash_indices_count, - context_size, - block_size_M, - block_size_N, - causal, - ) - return block_count, block_offset, column_count, column_index - - -# pos encoding ops -def rotary_embedding( - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor | None, - head_size: int, - cos_sin_cache: torch.Tensor, - is_neox: bool, -) -> None: - torch.ops._C.rotary_embedding( - positions, query, key, head_size, cos_sin_cache, is_neox - ) - - -# layer norm ops -def rms_norm( - out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, epsilon: float -) -> None: - torch.ops._C.rms_norm(out, input, weight, epsilon) - - -def fused_add_rms_norm( - input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, epsilon: float -) -> None: - torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) - - -def fused_qk_norm_rope( - qkv: torch.Tensor, - num_heads_q: int, - num_heads_k: int, - num_heads_v: int, - head_dim: int, - eps: float, - q_weight: torch.Tensor, - k_weight: torch.Tensor, - cos_sin_cache: torch.Tensor, - is_neox: bool, - position_ids: torch.Tensor, -) -> None: - torch.ops._C.fused_qk_norm_rope( - qkv, - num_heads_q, - num_heads_k, - num_heads_v, - head_dim, - eps, - q_weight, - k_weight, - cos_sin_cache, - is_neox, - position_ids, - ) - - -def apply_repetition_penalties_torch( - logits: torch.Tensor, - prompt_mask: torch.Tensor, - output_mask: torch.Tensor, - repetition_penalties: torch.Tensor, -) -> None: - repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat( - 1, logits.size(1) - ) - # If token appears in prompt or output, apply, otherwise use 1.0 for no-op. - penalties = torch.where(prompt_mask | output_mask, repetition_penalties, 1.0) - # If logits are positive, divide by penalty, otherwise multiply by penalty. - scaling = torch.where(logits > 0, 1.0 / penalties, penalties) - logits *= scaling - - -def apply_repetition_penalties_cuda( - logits: torch.Tensor, - prompt_mask: torch.Tensor, - output_mask: torch.Tensor, - repetition_penalties: torch.Tensor, -) -> None: - torch.ops._C.apply_repetition_penalties_( - logits, prompt_mask, output_mask, repetition_penalties - ) - - -def apply_repetition_penalties( - logits: torch.Tensor, - prompt_mask: torch.Tensor, - output_mask: torch.Tensor, - repetition_penalties: torch.Tensor, -) -> None: - """Apply repetition penalties to logits in-place. - - Args: - logits: The logits tensor of shape [num_seqs, vocab_size]. - prompt_mask: A boolean tensor indicating which tokens appear in the prompt. - output_mask: A boolean tensor indicating which tokens appear in the output. - repetition_penalties: The repetition penalties of shape (num_seqs, ). 
- """ - if logits.is_cuda and logits.is_contiguous(): - apply_repetition_penalties_cuda( - logits, prompt_mask, output_mask, repetition_penalties - ) - else: - apply_repetition_penalties_torch( - logits, prompt_mask, output_mask, repetition_penalties - ) - - -# fused quant layer norm ops -def rms_norm_dynamic_per_token_quant( - input: torch.Tensor, - weight: torch.Tensor, - epsilon: float, - quant_dtype: torch.dtype, - scale_ub: torch.Tensor | None = None, - residual: torch.Tensor | None = None, -) -> tuple[torch.Tensor, torch.Tensor]: - output = torch.empty(input.shape, dtype=quant_dtype, device=input.device) - scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - - torch.ops._C.rms_norm_dynamic_per_token_quant( - output, input, weight, scales, epsilon, scale_ub, residual - ) - return output, scales - - -# fused quant layer norm ops blocked -def rms_norm_per_block_quant( - input: torch.Tensor, - weight: torch.Tensor, - epsilon: float, - quant_dtype: torch.dtype, - group_size: list[int], - scale_ub: torch.Tensor | None = None, - residual: torch.Tensor | None = None, - is_scale_transposed: bool = False, - tma_alignment: int = 0, -) -> tuple[torch.Tensor, torch.Tensor]: - assert len(group_size) == 2 - output = torch.empty(input.shape, dtype=quant_dtype, device=input.device) - if is_scale_transposed: - if tma_alignment == 0: - scales = torch.empty( - (input.shape[-1] // group_size[1], input.numel() // input.shape[-1]), - device=input.device, - dtype=torch.float32, - ).transpose(0, 1) - else: - m = input.shape[-2] - sf_k = input.shape[-1] // group_size[1] - tma_aligned_m = (m + tma_alignment - 1) // tma_alignment * tma_alignment - shape = input.shape[:-2] + (m, sf_k) - stride = ( - (1, tma_aligned_m) - if input.dim() == 2 - else (tma_aligned_m * sf_k, 1, tma_aligned_m) - ) - scales = torch.empty_strided( - shape, stride, device=input.device, dtype=torch.float32 - ) - else: - scales = torch.empty( - (input.numel() // input.shape[-1], input.shape[-1] // group_size[1]), - device=input.device, - dtype=torch.float32, - ) - - assert tma_alignment in [0, 4], "Expected TMA alignment 0 or 4, but got " + str( - tma_alignment - ) - - torch.ops._C.rms_norm_per_block_quant( - output, - input, - weight, - scales, - epsilon, - scale_ub, - residual, - group_size[1], - is_scale_transposed, - ) - return output, scales - - -# quantization ops -# awq -def awq_dequantize( - qweight: torch.Tensor, - scales: torch.Tensor, - zeros: torch.Tensor, - split_k_iters: int, - thx: int, - thy: int, -) -> torch.Tensor: - if envs.VLLM_USE_TRITON_AWQ: - from vllm.model_executor.layers.quantization.awq_triton import ( - awq_dequantize_triton, - ) - - return awq_dequantize_triton(qweight, scales, zeros) - return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) - - -if hasattr(torch.ops._C, "awq_dequantize"): - - @register_fake("_C::awq_dequantize") - def _awq_dequantize_fake( - qweight: torch.Tensor, - scales: torch.Tensor, - zeros: torch.Tensor, - split_k_iters: torch.SymInt, - thx: int, - thy: int, - ) -> torch.Tensor: - in_c = qweight.size(0) - qout_c = qweight.size(1) - out_c = qout_c * 8 - return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device) - - -def awq_gemm( - input: torch.Tensor, - qweight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - split_k_iters: int, -) -> torch.Tensor: - if envs.VLLM_USE_TRITON_AWQ: - from vllm.model_executor.layers.quantization.awq_triton import awq_gemm_triton - - return 
awq_gemm_triton(input, qweight, scales, qzeros, split_k_iters) - return torch.ops._C.awq_gemm(input, qweight, scales, qzeros, split_k_iters) - - -if hasattr(torch.ops._C, "awq_gemm"): - - @register_fake("_C::awq_gemm") - def _awq_gemm_fake( - input: torch.Tensor, - qweight: torch.Tensor, - scales: torch.Tensor, - qzeros: torch.Tensor, - split_k_iters: torch.SymInt, - ) -> torch.Tensor: - num_in_feats = input.size(0) - return torch.empty( - (split_k_iters, num_in_feats, qweight.size(1) * 8), - dtype=input.dtype, - device=input.device, - ).sum(0) - - -# gptq -def gptq_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, - b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, - use_exllama: bool, - use_v2_format: bool, - bit: int, -) -> torch.Tensor: - return torch.ops._C.gptq_gemm( - a, - b_q_weight, - b_gptq_qzeros, - b_gptq_scales, - b_g_idx, - use_exllama, - use_v2_format, - bit, - ) - - -if hasattr(torch.ops._C, "gptq_gemm"): - - @register_fake("_C::gptq_gemm") - def _gptq_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, - b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, - use_exllama: bool, - use_v2_format: bool, - bit: int, - ) -> torch.Tensor: - return torch.empty( - (a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device - ) - - -def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: - torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) - - -if hasattr(torch.ops._C, "allspark_w8a16_gemm"): - - @register_fake("_C::allspark_w8a16_gemm") - def _allspark_w8a16_gemm_fake( - a: torch.Tensor, - b_qweight: torch.Tensor, - b_scales: torch.Tensor, - b_qzeros: torch.Tensor | None, - n: torch.SymInt, - group_size: torch.SymInt, - sm_count: torch.SymInt, - sm_version: torch.SymInt, - CUBLAS_M_THRESHOLD: torch.SymInt, - has_zp: bool, - n32k16_reorder: bool, - ) -> torch.Tensor: - m = a.size(0) - return torch.empty((m, n), device=a.device, dtype=a.dtype) - - -if hasattr(torch.ops._C, "ggml_dequantize"): - - @register_fake("_C::ggml_dequantize") - def _ggml_dequantize_fake( - W: torch.Tensor, - quant_type: int, - m: torch.SymInt, - n: torch.SymInt, - dtype: torch.dtype | None = None, - ) -> torch.Tensor: - return torch.empty((m, n), dtype=torch.float16, device=W.device) - - @register_fake("_C::ggml_mul_mat_vec_a8") - def _ggml_mul_mat_vec_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device) - - @register_fake("_C::ggml_mul_mat_a8") - def _ggml_mul_mat_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - batch = X.size(0) - return torch.empty((batch, row), dtype=X.dtype, device=W.device) - - @register_fake("_C::ggml_moe_a8") - def _ggml_moe_a8_fake( - X: torch.Tensor, - W: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - quant_type: int, - row: torch.SymInt, - top_k: torch.SymInt, - tokens: torch.SymInt, - ) -> torch.Tensor: - tokens = X.size(0) - return torch.empty((tokens * top_k, row), dtype=torch.float16, device=W.device) - - -if hasattr(torch.ops._C, "ggml_moe_a8_vec"): - - @register_fake("_C::ggml_moe_a8_vec") - def _ggml_moe_a8_vec_fake( - X: torch.Tensor, - W: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - quant_type: int, - row: torch.SymInt, - tokens: torch.SymInt, - ) -> torch.Tensor: - tokens = X.size(0) - return 
torch.empty((tokens * top_k, row), dtype=X.dtype, device=W.device) - - -# cutlass -def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) - - -def cutlass_scaled_fp4_mm( - a: torch.Tensor, - b: torch.Tensor, - block_scale_a: torch.Tensor, - block_scale_b: torch.Tensor, - alpha: torch.Tensor, - out_dtype: torch.dtype, -) -> torch.Tensor: - assert a.ndim == 2 and b.ndim == 2 - m, n = a.shape[0], b.shape[0] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b, alpha) - return out - - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - - -def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) - - -def cutlass_scaled_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: torch.Tensor | None = None, -) -> torch.Tensor: - """ - `cutlass_scaled_mm` implements a fused version of - `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` - where scale_a * a and scale_b * b are implemented using numpy-style - broadcasting. - - In order to support blockwise scaling like found in DeepSeek V3 we also - support extended "group" broadcast rules. We extend the numpy-style - broadcasting rules with the following rule: - "if the extent of a dimension in the source shape is between 1 and - corresponding extent in the target shape we repeat each element along - that dimension src_shape[dim] // target_shape[dim] times consecutively" - example if we have: - a = [[1, 2], and target_shape = (2, 4) - [3, 4]] - then we would expand a to: - a = [[1, 1, 2, 2], - [3, 3, 4, 4]] - currently we only support the case: - scale_a.shape * [1, 128] == a.shape - scale_b.shape * [128, 128] == b.shape - """ - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - - # Massage the input to be 2D - target_shape = (*a.shape[:-1], b.shape[1]) - a = a.view(-1, a.shape[-1]) - - cutlass_compatible_b = b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - if current_platform.is_rocm() or not cutlass_compatible_b: - from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa - triton_scaled_mm, - ) - - out = triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - else: - out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) - torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out.view(*target_shape) - - -def cutlass_scaled_mm_azp( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: torch.Tensor | None = None, - bias: torch.Tensor | None = None, -) -> torch.Tensor: - """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. 
- """ - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - - # Massage the input to be 2D - target_shape = (*a.shape[:-1], b.shape[1]) - a = a.view(-1, a.shape[-1]) - assert azp is None or azp.numel() == a.shape[0] - - out = torch.empty((a.shape[0], b.shape[1]), dtype=out_dtype, device=a.device) - torch.ops._C.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out.view(*target_shape) - - -def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool: - return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability) - - -def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: - if cuda_device_capability < 90 or cuda_device_capability >= 110: - return False - try: - return torch.ops._C.cutlass_group_gemm_supported(cuda_device_capability) - except AttributeError: - # Return False on non-CUDA platforms where it is not available - return False - - -def cutlass_sparse_compress(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compresses a sparse matrix for use with Cutlass sparse operations. - - This function takes a dense tensor and compresses it into two components: - non-zero elements and metadata. The compressed representation is compatible - with Cutlass sparse kernels. - - Args: - a (torch.Tensor): - The input tensor to be compressed. Must have one of the following data types: - - `torch.int8` - - `torch.float8_e4m3fn` - - `torch.bfloat16` - - `torch.float16` - - Returns: - tuple[torch.Tensor, torch.Tensor]: - A tuple containing: - - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`. - - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation. - - Raises: - ValueError: If the compression operation fails. - - Notes: - - The `a_meta` tensor has a data type of `torch.uint8`. - - Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`). - - The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor. - - The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`. - """ - assert a.dtype in [torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16] - assert a.is_contiguous() - - # a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4 - elemsPerMetaElem = 4 - assert a.shape[1] % (2 * elemsPerMetaElem) == 0 - - return torch.ops._C.cutlass_sparse_compress(a) - - -def cutlass_scaled_sparse_mm( - a: torch.Tensor, - bt_nzs: torch.Tensor, - bt_meta: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: torch.Tensor | None = None, -) -> torch.Tensor: - """ - Performs a scaled sparse matrix multiplication using Cutlass. - - Steps: - 1. Create a dense matrix `a` of shape (m, k) on the CUDA device: - `a = torch.randn((m, k), device='cuda')`. - - 2. Create a dense matrix `b` of shape (k, n) on the CUDA device: - `b = torch.randn((k, n), device='cuda')`. - - 3. Prune matrix `b` to 2:4 sparsity along the specified dimension: - `b = prune_to_2_4(b, dim=0)`. - - 4. Compress the transposed sparse matrix `b.t()`: - `bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`. - - 5. Perform sparse matrix multiplication using the compressed matrix, - applying scaling factors for `a` and `b`, and the output data type: - `out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`. 
- - Returns: - - The result of the scaled sparse matrix multiplication. - """ - assert bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.shape[0] == bt_nzs.shape[0] and bias.dtype == out_dtype - - m = a.shape[0] - n = bt_nzs.shape[0] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - torch.ops._C.cutlass_scaled_sparse_mm( - out, a, bt_nzs, bt_meta, scale_a, scale_b, bias - ) - - return out - - -def get_cutlass_moe_mm_data( - topk_ids: torch.Tensor, - expert_offsets: torch.Tensor, - problem_sizes1: torch.Tensor, - problem_sizes2: torch.Tensor, - input_permutation: torch.Tensor, - output_permutation: torch.Tensor, - num_experts: int, - n: int, - k: int, - blockscale_offsets: torch.Tensor | None = None, -): - """ - Prepare data necessary to perform CUTLASS grouped matrix multiplications - used in CUTLASS-based fused MoE. - - The function takes in topk_ids (token-expert mapping) and uses it to - compute: - - expert_offsets: Indices that mark at which token index each expert begins - its computation after the input is sorted with - input_permutation. The number of tokens computed with - expert E is expert_offsets[E + 1] - expert_offsets[E] - - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's - multiplication in two grouped MMs used in - the fused MoE operation. - - input_permutation: Permutation that must be used to shuffle the input - before executing the MMs. - - output_permutation: Permutation that must be used to shuffle the output - after executing the MMs. - - blockscale_offsets: Optional argument passed for fp4 moe. Indices that - mark at which block scale index each expert begins - its computation. The number of block scale rows - computed with expert E is blockscale_offsets[E + 1] - - blockscale_offsets[E] - """ - return torch.ops._C.get_cutlass_moe_mm_data( - topk_ids, - expert_offsets, - problem_sizes1, - problem_sizes2, - input_permutation, - output_permutation, - num_experts, - n, - k, - blockscale_offsets, - ) - - -def get_cutlass_moe_mm_problem_sizes_from_expert_offsets( - expert_first_token_offset: torch.Tensor, - problem_sizes1: torch.Tensor, - problem_sizes2: torch.Tensor, - n: int, - k: int, - swap_ab: bool, -): - """Compute per-expert (M, N, K) problem sizes from expert_first_token_offset""" - return torch.ops._C.get_cutlass_moe_mm_problem_sizes_from_expert_offsets( - expert_first_token_offset, - problem_sizes1, - problem_sizes2, - n, - k, - swap_ab, - ) - - -def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): - """ - Shuffle and expand the input tensor according to the dst2src_map and store the result in output_tensor. - This is used in MoE to permute the input tensor before performing grouped matrix multiplications. - """ - num_tokens_permuted = dst2src_map.shape[0] - output_tensor = torch.empty( - (num_tokens_permuted, input_tensor.shape[1]), - device=input_tensor.device, - dtype=input_tensor.dtype, - ) - torch.ops._moe_C.shuffle_rows(input_tensor, dst2src_map, output_tensor) - return output_tensor - - -def get_cutlass_batched_moe_mm_data( - expert_offsets: torch.Tensor, - problem_sizes1: torch.Tensor, - problem_sizes2: torch.Tensor, - expert_num_tokens: torch.Tensor, - num_local_experts: int, - padded_m: int, - n: int, - k: int, -): - """ - Prepare data necessary to perform CUTLASS grouped matrix multiplications - used in CUTLASS-based fused MoE. 
-
-    The function takes in expert_num_tokens (token count per expert) and
-    uses it to compute:
-    - expert_offsets: Indices that mark at which token index each expert begins
-                      its computation.
-    - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
-                                      multiplication in two grouped MMs used in
-                                      the fused MoE operation.
-    """
-    return torch.ops._C.get_cutlass_batched_moe_mm_data(
-        expert_offsets,
-        problem_sizes1,
-        problem_sizes2,
-        expert_num_tokens,
-        num_local_experts,
-        padded_m,
-        n,
-        k,
-    )
-
-
-def cutlass_moe_mm(
-    out_tensors: torch.Tensor,
-    a_tensors: torch.Tensor,
-    b_tensors: torch.Tensor,
-    a_scales: torch.Tensor,
-    b_scales: torch.Tensor,
-    expert_offsets: torch.Tensor,
-    problem_sizes: torch.Tensor,
-    a_strides: torch.Tensor,
-    b_strides: torch.Tensor,
-    c_strides: torch.Tensor,
-    per_act_token: bool,
-    per_out_ch: bool,
-):
-    """
-    A single grouped matrix multiplication used in CUTLASS-based fused MoE.
-    The function executes fp8-quantized OUT = AB matrix multiplication.
-
-    - expert_offsets: Indices that mark at which token index each expert begins
-                      its computation. The number of tokens computed with
-                      expert E is expert_offsets[E + 1] - expert_offsets[E]
-    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
-                     MMs used in the fused MoE operation.
-    - a/b/c_strides: The data strides passed to grouped matrix multiplication.
-    """
-    return torch.ops._C.cutlass_moe_mm(
-        out_tensors,
-        a_tensors,
-        b_tensors,
-        a_scales,
-        b_scales,
-        expert_offsets,
-        problem_sizes,
-        a_strides,
-        b_strides,
-        c_strides,
-        per_act_token,
-        per_out_ch,
-    )
-
-
-def cutlass_fp4_moe_mm(
-    out_tensors: torch.Tensor,
-    a_tensors: torch.Tensor,
-    b_tensors: torch.Tensor,
-    a_scales: torch.Tensor,
-    b_scales: torch.Tensor,
-    alphas: torch.Tensor,
-    problem_sizes: torch.Tensor,
-    expert_offsets: torch.Tensor,
-    sf_offsets: torch.Tensor,
-):
-    """
-    An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
-    the gemms for each combination based on the specified problem sizes.
-
-    This is used as the MoE gemm during NVFP4 Quantized FusedMoE forward.
-    - a/b_tensors: the NVFP4 a_ptrs and b_ptrs tensors which are quantized
-                   input and expert weights.
-    - a_/b_scales: The blockscales in FP8-E4M3 precision.
-    - expert_offsets/sf_offsets: Indices that mark at which token index
-                    each expert begins its computation. The number of tokens
-                    computed with expert E is expert_offsets[E + 1] -
-                    expert_offsets[E], and the sf_size per expert is
-                    sf_offset[E+1] - sf_offset[E].
-    - problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
-                     MMs used in the fused MoE operation.
- """ - return torch.ops._C.cutlass_fp4_group_mm( - out_tensors, - a_tensors, - b_tensors, - a_scales, - b_scales, - alphas, - problem_sizes, - expert_offsets, - sf_offsets, - ) - - -def mxfp8_experts_quant( - input_tensor: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, - blockscale_offsets: torch.Tensor, - quant_output: torch.Tensor, - scale_factor: torch.Tensor, -) -> None: - torch.ops._C.mxfp8_experts_quant( - input_tensor, - problem_sizes, - expert_offsets, - blockscale_offsets, - quant_output, - scale_factor, - ) - - -def cutlass_mxfp8_grouped_mm( - a_tensors: torch.Tensor, - b_tensors: torch.Tensor, - a_scales: torch.Tensor, - b_scales: torch.Tensor, - out_tensors: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, - blockscale_offsets: torch.Tensor, -) -> None: - torch.ops._C.cutlass_mxfp8_grouped_mm( - a_tensors, - b_tensors, - a_scales, - b_scales, - out_tensors, - problem_sizes, - expert_offsets, - blockscale_offsets, - ) - - -if hasattr(torch.ops._C, "mxfp8_experts_quant"): - - @register_fake("_C::mxfp8_experts_quant") - def _mxfp8_experts_quant_fake( - input_tensor: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, - blockscale_offsets: torch.Tensor, - quant_output: torch.Tensor, - scale_factor: torch.Tensor, - ) -> None: - return None - - -if hasattr(torch.ops._C, "cutlass_mxfp8_grouped_mm"): - - @register_fake("_C::cutlass_mxfp8_grouped_mm") - def _cutlass_mxfp8_grouped_mm_fake( - a_tensors: torch.Tensor, - b_tensors: torch.Tensor, - a_scales: torch.Tensor, - b_scales: torch.Tensor, - out_tensors: torch.Tensor, - problem_sizes: torch.Tensor, - expert_offsets: torch.Tensor, - blockscale_offsets: torch.Tensor, - ) -> None: - return None - - -# gptq_marlin -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, - is_a_8bit: bool = False, -) -> torch.Tensor: - return torch.ops._C.gptq_marlin_repack( - b_q_weight, perm, size_k, size_n, num_bits, is_a_8bit - ) - - -if hasattr(torch.ops._C, "gptq_marlin_repack"): - - @register_fake("_C::gptq_marlin_repack") - def _gptq_marlin_repack_fake( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: torch.SymInt, - size_n: torch.SymInt, - num_bits: int, - is_a_8bit: bool = False, - ) -> torch.Tensor: - pack_factor = 32 // num_bits - marlin_tile_size = 16 - return torch.empty( - (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor), - dtype=b_q_weight.dtype, - device=b_q_weight.device, - ) - - -# awq_marlin -def awq_marlin_repack( - b_q_weight: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, - is_a_8bit: bool = False, -) -> torch.Tensor: - return torch.ops._C.awq_marlin_repack( - b_q_weight, size_k, size_n, num_bits, is_a_8bit - ) - - -if hasattr(torch.ops._C, "awq_marlin_repack"): - - @register_fake("_C::awq_marlin_repack") - def _awq_marlin_repack_fake( - b_q_weight: torch.Tensor, - size_k: torch.SymInt, - size_n: torch.SymInt, - num_bits: int, - is_a_8bit: bool = False, - ) -> torch.Tensor: - pack_factor = 32 // num_bits - marlin_tile_size = 16 - return torch.empty( - (size_k // marlin_tile_size, size_n * marlin_tile_size // pack_factor), - dtype=b_q_weight.dtype, - device=b_q_weight.device, - ) - - -def gptq_marlin_moe_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, - is_a_8bit: bool = False, -) -> torch.Tensor: - num_experts = b_q_weight.shape[0] - assert size_k % 16 == 0 - output = torch.empty( - 
(num_experts, size_k // 16, size_n * (num_bits // 2)), - device=b_q_weight.device, - dtype=b_q_weight.dtype, - ) - for e in range(num_experts): - output[e] = torch.ops._C.gptq_marlin_repack( - b_q_weight[e], perm[e], size_k, size_n, num_bits, is_a_8bit - ) - return output - - -def awq_marlin_moe_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, - is_a_8bit: bool = False, -) -> torch.Tensor: - num_experts = b_q_weight.shape[0] - assert size_k % 16 == 0 - output = torch.empty( - (num_experts, size_k // 16, size_n * (num_bits // 2)), - device=b_q_weight.device, - dtype=b_q_weight.dtype, - ) - for e in range(num_experts): - output[e] = torch.ops._C.awq_marlin_repack( - b_q_weight[e], size_k, size_n, num_bits, is_a_8bit - ) - return output - - -def marlin_int4_fp8_preprocess( - qweight: torch.Tensor, - qzeros_or_none: torch.Tensor | None = None, - inplace: bool = False, -): - return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace) - - -def marlin_gemm( - a: torch.Tensor, - c: torch.Tensor | None, - b_q_weight: torch.Tensor, - b_bias: torch.Tensor | None, - b_scales: torch.Tensor, - a_scales: torch.Tensor | None, - global_scale: torch.Tensor | None, - b_zeros: torch.Tensor | None, - g_idx: torch.Tensor | None, - perm: torch.Tensor | None, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool = True, - use_atomic_add: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return torch.ops._C.marlin_gemm( - a, - c, - b_q_weight, - b_bias, - b_scales, - a_scales, - global_scale, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - use_atomic_add, - use_fp32_reduce, - is_zp_float, - ) - - -if hasattr(torch.ops._C, "marlin_gemm"): - - @register_fake("_C::marlin_gemm") - def _marlin_gemm_fake( - a: torch.Tensor, - c: torch.Tensor | None, - b_q_weight: torch.Tensor, - b_bias: torch.Tensor | None, - b_scales: torch.Tensor, - a_scales: torch.Tensor | None, - global_scale: torch.Tensor | None, - b_zeros: torch.Tensor | None, - g_idx: torch.Tensor | None, - perm: torch.Tensor | None, - workspace: torch.Tensor, - b_q_type_id: int, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool = True, - use_atomic_add: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, - ) -> torch.Tensor: - dtype = a.dtype - if dtype not in [torch.half, torch.bfloat16]: - dtype = b_scales.dtype - return torch.empty((size_m, size_n), device=a.device, dtype=dtype) - - -# machete -def machete_supported_schedules( - a_type: torch.dtype, - b_type: ScalarType, - group_scales_type: torch.dtype | None, - group_zeros_type: torch.dtype | None = None, - channel_scales_type: torch.dtype | None = None, - token_scales_type: torch.dtype | None = None, - out_type: torch.dtype | None = None, -) -> list[str]: - return torch.ops._C.machete_supported_schedules( - a_type, - b_type.id, - group_scales_type, - group_zeros_type, - channel_scales_type, - token_scales_type, - out_type, - ) - - -def machete_mm( - a: torch.Tensor, - # b_q Should be the tensor returned by machete_prepack_B - b_q: torch.Tensor, - b_type: ScalarType, - out_type: torch.dtype | None = None, - b_group_scales: torch.Tensor | None = None, - b_group_zeros: torch.Tensor | None = None, - b_group_size: int | None = None, - b_channel_scales: torch.Tensor | None = None, - a_token_scales: torch.Tensor | 
None = None, - schedule: str | None = None, -) -> torch.Tensor: - return torch.ops._C.machete_mm( - a, - b_q, - b_type.id, - out_type, - b_group_scales, - b_group_zeros, - b_group_size, - b_channel_scales, - a_token_scales, - schedule, - ) - - -if hasattr(torch.ops._C, "machete_mm"): - - @register_fake("_C::machete_mm") - def machete_mm_fake( - a: torch.Tensor, - # b_q Should be the tensor returned by machete_prepack_B - b_q: torch.Tensor, - b_type: ScalarType, - out_type: torch.dtype | None = None, - b_group_scales: torch.Tensor | None = None, - b_group_zeros: torch.Tensor | None = None, - b_group_size: int | None = None, - b_channel_scales: torch.Tensor | None = None, - a_token_scales: torch.Tensor | None = None, - schedule: str | None = None, - ) -> torch.Tensor: - m = a.size(0) - n = b_q.size(1) - return torch.empty((m, n), device=a.device, dtype=a.dtype) - - -def machete_prepack_B( - b_q_weight: torch.Tensor, - a_type: torch.dtype, - b_type: ScalarType, - group_scales_type: torch.dtype | None, -) -> torch.Tensor: - return torch.ops._C.machete_prepack_B( - b_q_weight, a_type, b_type.id, group_scales_type - ) - - -if hasattr(torch.ops._C, "machete_prepack_B"): - - @register_fake("_C::machete_prepack_B") - def machete_prepack_B_fake( - b_q_weight: torch.Tensor, - a_type: torch.dtype, - b_type: ScalarType, - group_scales_type: torch.dtype | None, - ) -> torch.Tensor: - return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) - - -# CUTLASS W4A8 -def cutlass_w4a8_mm( - a: torch.Tensor, - # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b - b_q: torch.Tensor, - b_group_scales: torch.Tensor, - b_group_size: int, - b_channel_scales: torch.Tensor, - a_token_scales: torch.Tensor, - out_type: torch.dtype | None = None, - maybe_schedule: str | None = None, -) -> torch.Tensor: - return torch.ops._C.cutlass_w4a8_mm( - a, - b_q, - b_group_scales, - b_group_size, - b_channel_scales, - a_token_scales, - out_type, - maybe_schedule, - ) - - -if hasattr(torch.ops._C, "cutlass_w4a8_mm"): - - @register_fake("_C::cutlass_w4a8_mm") - def cutlass_w4a8_mm_fake( - a: torch.Tensor, - # b_q Should be the tensor returned by cutlass_encode_and_reorder_int4b - b_q: torch.Tensor, - b_group_scales: torch.Tensor, - b_group_size: int, - b_channel_scales: torch.Tensor, - a_token_scales: torch.Tensor, - out_type: torch.dtype | None = None, - maybe_schedule: str | None = None, - ) -> torch.Tensor: - m = a.size(0) - n = b_q.size(1) - out_dtype = out_type if out_type is not None else torch.bfloat16 - return torch.empty((m, n), device=a.device, dtype=out_dtype) - - -def cutlass_pack_scale_fp8(scales: torch.Tensor) -> torch.Tensor: - return torch.ops._C.cutlass_pack_scale_fp8(scales) - - -if hasattr(torch.ops._C, "cutlass_pack_scale_fp8"): - - @register_fake("_C::cutlass_pack_scale_fp8") - def cutlass_pack_scale_fp8_fake(scales: torch.Tensor) -> torch.Tensor: - return torch.empty_like(scales, memory_format=torch.contiguous_format) - - -def cutlass_encode_and_reorder_int4b(b: torch.Tensor) -> torch.Tensor: - return torch.ops._C.cutlass_encode_and_reorder_int4b(b) - - -if hasattr(torch.ops._C, "cutlass_encode_and_reorder_int4b"): - - @register_fake("_C::cutlass_encode_and_reorder_int4b") - def cutlass_encode_and_reorder_int4b_fake(b: torch.Tensor) -> torch.Tensor: - return torch.empty_like(b, memory_format=torch.contiguous_format) - - -def cutlass_w4a8_moe_mm( - out_tensors: torch.Tensor, - a_tensors: torch.Tensor, - b_tensors: torch.Tensor, - a_scales: torch.Tensor, - b_scales: 
torch.Tensor,
-    b_group_scales: torch.Tensor,
-    b_group_size: int,
-    expert_offsets: torch.Tensor,
-    problem_sizes: torch.Tensor,
-    a_strides: torch.Tensor,
-    b_strides: torch.Tensor,
-    c_strides: torch.Tensor,
-    group_scale_strides: torch.Tensor,
-    maybe_schedule: str | None = None,
-):
-    """
-    Executes the CUTLASS-based fused-MoE grouped matrix multiplication for the
-    W4A8 quantization scheme. Uses group-wise quantization (INT4 -> FP8)
-    and both per-channel + per-token scaling in the epilogue.
-
-    Args:
-        out_tensors:
-            Output buffer for all experts (updated in-place).
-        a_tensors:
-            FP8 (E4M3FN) activations for all experts.
-        b_tensors:
-            INT4-packed weight matrix for all experts, packed into INT32.
-        a_scales:
-            Per-token FP8 activation scales, applied in the epilogue.
-        b_scales:
-            Per-channel FP8 weight scales for each expert, applied in the epilogue.
-        b_group_scales:
-            FP8 scale values for group-wise INT4 weight blocks.
-        b_group_size:
-            Number of elements grouped under each entry of b_group_scales.
-        expert_offsets:
-            Cumulative token offsets marking where each expert's tokens begin.
-        problem_sizes:
-            Per-expert (M, N, K) GEMM sizes used by the grouped GEMM launcher.
-        a/b/c/group_scale_strides:
-            Strides describing the memory layout of the input tensors.
-        maybe_schedule:
-            Optional override to choose a specific kernel or epilogue schedule.
-
-    Returns:
-        out_tensors updated in-place with the dequantized INT4xFP8 grouped GEMM result.
-    """
-    return torch.ops._C.cutlass_w4a8_moe_mm(
-        out_tensors,
-        a_tensors,
-        b_tensors,
-        a_scales,
-        b_scales,
-        b_group_scales,
-        b_group_size,
-        expert_offsets,
-        problem_sizes,
-        a_strides,
-        b_strides,
-        c_strides,
-        group_scale_strides,
-        maybe_schedule,
-    )
-
-
-def cutlass_encode_and_reorder_int4b_grouped(
-    b_tensors: torch.Tensor,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    return torch.ops._C.cutlass_encode_and_reorder_int4b_grouped(b_tensors)
-
-
-if hasattr(torch.ops._C, "cutlass_encode_and_reorder_int4b_grouped"):
-
-    @register_fake("_C::cutlass_encode_and_reorder_int4b_grouped")
-    def cutlass_encode_and_reorder_int4b_grouped_fake(b: torch.Tensor) -> torch.Tensor:
-        return torch.empty_like(b, memory_format=torch.contiguous_format)
-
-
-def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
-    return torch.ops._C.permute_cols(a, perm)
-
-
-if hasattr(torch.ops._C, "permute_cols"):
-
-    @register_fake("_C::permute_cols")
-    def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
-        return torch.empty_like(a)
-
-
-# fp4
-def scaled_fp4_quant(
-    input: torch.Tensor,
-    input_global_scale: torch.Tensor,
-    is_sf_swizzled_layout: bool = True,
-    backend: str = "none",
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP4 and return quantized tensor and scale.
-
-    This function quantizes the last dimension of the given tensor `input`. For
-    every 16 consecutive elements, a single dynamically computed scaling factor
-    is shared. This scaling factor is quantized using the `input_global_scale`
-    and is stored in a swizzled layout (see
-    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
-
-    Args:
-        input: The input tensor to be quantized to FP4.
-        input_global_scale: A scalar scaling factor for the entire tensor.
-        is_sf_swizzled_layout: Whether to store the scaling factors in the
-            swizzled (128x4) layout rather than a linear layout.
-        backend: Backend hint; "trtllm" backends use the 8x4 scale layout for
-            small batches (m <= 32).
-
-    Returns:
-        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 (every two
-        values packed into a uint8) and the float8_e4m3 scaling factors in the
-        swizzled layout.
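The swizzled scale shape can be derived with the rounding sketch below, which mirrors the 128x4 tile padding performed in the function body (illustrative only, not part of the kernel):

    def nvfp4_scale_shape(m: int, n: int) -> tuple[int, int]:
        # One fp8 scale per 16-element block; rows pad to 128, cols to 4,
        # then four fp8 scales pack into each int32.
        round_up = lambda x, y: (x + y - 1) // y * y
        return round_up(m, 128), round_up(n // 16, 4) // 4

    assert nvfp4_scale_shape(32, 64) == (128, 1)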
- """ - assert not current_platform.is_rocm() - assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}." - other_dims = 1 if input.ndim == 1 else -1 - input = input.reshape(other_dims, input.shape[-1]) - m, n = input.shape - block_size = 16 - device = input.device - - assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}." - assert input.dtype in (torch.float16, torch.bfloat16), ( - f"input.dtype needs to be fp16 or bf16 but got {input.dtype}." - ) - - use_8x4_sf_layout = True if "trtllm" in backend and m <= 32 else False # noqa: SIM210 - - if use_8x4_sf_layout: - output, output_scale = flashinfer_quant_nvfp4_8x4_sf_layout( - input, input_global_scale - ) - else: - # Two fp4 values will be packed into an uint8. - output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) - if is_sf_swizzled_layout: - # We use the rounded values to store the swizzled values. Due to the - # requirement of the Tensor Core, the minimum tile is 128x4 for the scales. - # So, we first pad the scales to multiples of 128 and 4. Then, the scales - # (in float8_e4m3fn) are packed into an int32 for every 4 values. More: - # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x - round_up = lambda x, y: (x + y - 1) // y * y - rounded_m = round_up(m, 128) - scale_n = n // block_size - rounded_n = round_up(scale_n, 4) - output_scale = torch.empty( - (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 - ) - else: - output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8) - - torch.ops._C.scaled_fp4_quant( - output, input, output_scale, input_global_scale, is_sf_swizzled_layout - ) - - output_scale = output_scale.view(torch.float8_e4m3fn) - return output, output_scale - - -def scaled_fp4_experts_quant( - input_tensor: torch.Tensor, - input_global_scale: torch.Tensor, - expert_offsets: torch.Tensor, - blockscale_offsets: torch.Tensor, - topk: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to NVFP4 and return quantized tensor and scale, for - packed MoE Inputs. - Args: - input_tensor: The input tensor to be quantized to NVFP4 - input_global_scale: A scalar scaling factor for the entire tensor. - expert_offsets: The expert offsets tensor - blockscale_offsets: The blockscale offsets tensor - Outputs: - output: The quantized tensor in NVFP4 - output_scales: The blockscale tensor in FP8-E4M3 - """ - assert not current_platform.is_rocm() - assert input_tensor.ndim == 2, ( - f"input.ndim needs to be == 2, but got {input_tensor.ndim}." - ) - - # Control the maximum number of tokens per expert supported by the - # NVFP4 MoE Expert Quantization. This is used to prevent the kernel - # from running out of memory. This value can also be increased to support - # larger models. - MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE - m_numtopk, k = input_tensor.shape - - assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, ( - f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT(" - f"{MAX_TOKENS_PER_EXPERT})" - f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use" - f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value." 
- ) - scales_k = k // 16 - padded_k = (scales_k + (4 - 1)) // 4 - - # output is uint8 and packed fp4 values - output = torch.empty( - m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8 - ) - output_scales = torch.empty( - MAX_TOKENS_PER_EXPERT * topk, - padded_k, - dtype=torch.int32, - device=input_tensor.device, - ) - torch.ops._C.scaled_fp4_experts_quant( - output, - output_scales, - input_tensor, - input_global_scale, - expert_offsets, - blockscale_offsets, - ) - output_scales = output_scales.view(torch.float8_e4m3fn) - return output, output_scales - - -def silu_and_mul_scaled_fp4_experts_quant( - input_tensor: torch.Tensor, - input_global_scale: torch.Tensor, - expert_offsets: torch.Tensor, - blockscale_offsets: torch.Tensor, - topk: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Fused SiLU+Mul+NVFP4 quantization for MoE intermediate activations. - - Args: - input_tensor: The input tensor with gate || up layout [m_topk, k*2] - input_global_scale: A per-expert scaling factor [n_experts] - expert_offsets: The expert offsets tensor [n_experts+1] - blockscale_offsets: The blockscale offsets tensor [n_experts+1] - topk: Number of top-k experts selected - Outputs: - output: The quantized tensor in NVFP4 [m_topk, k/2] - output_scales: The blockscale tensor in FP8-E4M3 - """ - assert not current_platform.is_rocm() - assert input_tensor.ndim == 2, ( - f"input.ndim needs to be == 2, but got {input_tensor.ndim}." - ) - - # Control the maximum number of tokens per expert supported by the - # NVFP4 MoE Expert Quantization. This is used to prevent the kernel - # from running out of memory. This value can also be increased to support - # larger models. - MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE - m_numtopk, k_times_2 = input_tensor.shape - assert k_times_2 % 2 == 0, "input width must be even (gate || up layout)" - k = k_times_2 // 2 - - assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, ( - f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT(" - f"{MAX_TOKENS_PER_EXPERT})" - f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use" - f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value." - ) - scales_k = k // 16 - padded_k = (scales_k + (4 - 1)) // 4 - - # output is uint8 and packed fp4 values - output = torch.empty( - m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8 - ) - output_scales = torch.empty( - MAX_TOKENS_PER_EXPERT * topk, - padded_k, - dtype=torch.int32, - device=input_tensor.device, - ) - torch.ops._C.silu_and_mul_scaled_fp4_experts_quant( - output, - output_scales, - input_tensor, - input_global_scale, - expert_offsets, - blockscale_offsets, - ) - output_scales = output_scales.view(torch.float8_e4m3fn) - return output, output_scales - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: torch.Tensor | None = None, - num_token_padding: int | None = None, - scale_ub: torch.Tensor | None = None, - use_per_token_if_dynamic: bool = False, - output: torch.Tensor | None = None, - group_shape: tuple[int, int] | None = None, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. 
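A hedged usage sketch of the two modes: dynamic quantization computes the scale in the kernel, while static quantization reuses a precomputed (e.g. calibration) scale:

    import torch

    x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

    # Dynamic per-tensor: the kernel computes and returns the scale.
    q_dyn, s = scaled_fp8_quant(x)

    # Static per-tensor: reuse a previously computed scale.
    q_static, _ = scaled_fp8_quant(x, scale=s)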
-
-    Args:
-        input: The input tensor to be quantized to FP8 (must be 2D: [M, N])
-        scale: Optional scaling factor for the FP8 quantization. Supports:
-            - 0D or [1]: per-tensor scaling
-            - 1D: requires explicit group_shape to disambiguate per-channel
-              vs per-token (use (-1, 1) for per-channel, (1, -1) for per-token)
-            - 2D [M/group_m, N/group_n]: group scaling (e.g. [M, N/128] for
-              DeepSeek-style (1,128) groups, or [M/128, N/128] for (128,128))
-        scale_ub: Optional upper bound for the scaling factor in the dynamic
-            per-token case
-        num_token_padding: If specified, pad the first dimension
-            of the output to at least this value.
-        use_per_token_if_dynamic: Whether to do per_tensor or per_token
-            in the dynamic quantization case.
-        group_shape: Optional tuple (group_m, group_n) specifying the group
-            shape for static quantization. Use -1 for "full extent" (e.g.,
-            (-1, -1) for per-tensor, (-1, 1) for per-channel, etc.)
-            Required for 1D scales; optional for 2D scales.
-
-    Returns:
-        tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
-        scaling factor.
-    """
-    # This code assumes batch_dim and num_tokens are flattened
-    assert input.ndim == 2
-    shape: tuple[int, int] | torch.Size = input.shape
-    # For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
-    out_dtype: torch.dtype = current_platform.fp8_dtype()
-    if num_token_padding:
-        shape = (max(num_token_padding, input.shape[0]), shape[1])
-    if output is None:
-        output = torch.empty(shape, device=input.device, dtype=out_dtype)
-    else:
-        assert num_token_padding is None, "padding not supported if output passed in"
-        assert output.dtype == out_dtype
-
-    if scale is None:
-        if use_per_token_if_dynamic:
-            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
-            torch.ops._C.dynamic_per_token_scaled_fp8_quant(
-                output, input, scale, scale_ub
-            )
-        else:
-            scale = torch.empty(1, device=input.device, dtype=torch.float32)
-            torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
-    else:
-        torch.ops._C.static_scaled_fp8_quant(output, input, scale, group_shape)
-
-    return output, scale
-
-
-# gptq allspark
-def allspark_repack_weight(
-    qweight: torch.Tensor,
-    scale: torch.Tensor,
-    zero_point: torch.Tensor | None = None,
-    has_zp: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
-    """
-    Rearrange qweight, scale, and zero_point (if asymmetric) to n32k16 format
-    for the Ampere W8A16 fused GEMM kernel.
-
-    Args:
-        qweight: uint8 weight tensor, original k x n format.
-        scale: fp16/bf16 weight scale tensor, 1 x n format.
-        zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
-            Must be provided for asymmetric quantization.
-        has_zp: False for symmetric quantization,
-            True for asymmetric quantization.
-
-    Returns:
-        tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] :
-        rearranged weight, scale, and optionally zero_point.
-    """
-    K = qweight.shape[0]
-    N = qweight.shape[1]
-    N_32align = (N + 32 - 1) // 32 * 32
-
-    qweight_reorder = torch.empty(
-        (N_32align, K), device=qweight.device, dtype=qweight.dtype
-    )
-    scale_reorder = torch.empty((1, N_32align), device=scale.device, dtype=scale.dtype)
-    zero_point_reorder = None
-    if has_zp:
-        assert zero_point is not None, (
-            "zero_point must be provided for asymmetric quantization."
- ) - zero_point_reorder = torch.empty( - (1, N_32align), device=zero_point.device, dtype=zero_point.dtype - ) - - torch.ops._C.rearrange_kn_weight_as_n32k16_order( - qweight, - scale, - zero_point, - has_zp, - qweight_reorder, - scale_reorder, - zero_point_reorder, - K, - N, - N_32align, - ) - - return qweight_reorder, scale_reorder, zero_point_reorder - - -def allspark_w8a16_gemm( - a: torch.Tensor, - b_qweight: torch.Tensor, - b_scales: torch.Tensor, - b_qzeros: torch.Tensor | None, - n: int, - group_size: int, - sm_count: int, - sm_version: int, - CUBLAS_M_THRESHOLD: int, - has_zp: bool, - n32k16_reorder: bool, -) -> torch.Tensor: - return torch.ops._C.allspark_w8a16_gemm( - a, - b_qweight, - b_scales, - b_qzeros, - n, - group_size, - sm_count, - sm_version, - CUBLAS_M_THRESHOLD, - has_zp, - n32k16_reorder, - ) - - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: torch.Tensor | None = None, - azp: torch.Tensor | None = None, - symmetric: bool = True, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == (azp is None), ( - "azp must only be provided for asymmetric quantization." - ) - torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. 
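Illustrative call patterns for this wrapper, assuming `x` is a 2D activation tensor and `s`/`azp` come from calibration (a sketch):

    out, scales, _ = scaled_int8_quant(x)                     # dynamic, symmetric
    out, scales, azp = scaled_int8_quant(x, symmetric=False)  # dynamic, asymmetric
    out, scales, azp = scaled_int8_quant(x, scale=s, azp=azp, symmetric=False)  # static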
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - torch.ops._C.dynamic_scaled_int8_quant( - output, input.contiguous(), input_scales, input_azp - ) - return output, input_scales, input_azp - - -# gguf -def ggml_dequantize( - W: torch.Tensor, quant_type: int, m: int, n: int, dtype: torch.dtype | None -) -> torch.Tensor: - return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype) - - -def ggml_mul_mat_vec_a8( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: int, -) -> torch.Tensor: - return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row) - - -def ggml_mul_mat_a8( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: int, -) -> torch.Tensor: - return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) - - -def ggml_moe_a8( - X: torch.Tensor, - W: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_padded: torch.Tensor, - quant_type: int, - row: int, - top_k: int, - tokens: int, -) -> torch.Tensor: - return torch.ops._C.ggml_moe_a8( - X, - W, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - quant_type, - row, - top_k, - tokens, - ) - - -def ggml_moe_a8_vec( - X: torch.Tensor, - W: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - quant_type: int, - row: torch.SymInt, - tokens: torch.SymInt, -) -> torch.Tensor: - return torch.ops._C.ggml_moe_a8_vec(X, W, topk_ids, top_k, quant_type, row, tokens) - - -def ggml_moe_get_block_size(quant_type: int) -> int: - return torch.ops._C.ggml_moe_get_block_size(quant_type) - - -# mamba -def selective_scan_fwd( - u: torch.Tensor, - delta: torch.Tensor, - A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - D_: torch.Tensor | None, - z_: torch.Tensor | None, - delta_bias_: torch.Tensor | None, - delta_softplus: bool, - query_start_loc: torch.Tensor | None, - cache_indices: torch.Tensor | None, - has_initial_state: torch.Tensor | None, - ssm_states: torch.Tensor, - pad_slot_id: int, - block_size: int = 1024, - block_idx_first_scheduled_token: torch.Tensor | None = None, - block_idx_last_scheduled_token: torch.Tensor | None = None, - initial_state_idx: torch.Tensor | None = None, - cu_chunk_seqlen: torch.Tensor | None = None, - last_chunk_indices: torch.Tensor | None = None, -): - torch.ops._C.selective_scan_fwd( - u, - delta, - A, - B, - C, - D_, - z_, - delta_bias_, - delta_softplus, - query_start_loc, - cache_indices, - has_initial_state, - ssm_states, - pad_slot_id, - block_size, - block_idx_first_scheduled_token, - block_idx_last_scheduled_token, - initial_state_idx, - cu_chunk_seqlen, - last_chunk_indices, - ) - - -# ROCm skinny gemms -def LLMM1(a: torch.Tensor, b: torch.Tensor, rows_per_block: int) -> torch.Tensor: - return torch.ops._rocm_C.LLMM1(a, b, rows_per_block) - - -def wvSplitK( - a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None -) -> torch.Tensor: - return torch.ops._rocm_C.wvSplitK(a, b, bias, cu_count) - - -def wvSplitKrc( - a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None -) -> torch.Tensor: - return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count) - - -def wvSplitKQ( - a: torch.Tensor, - b: torch.Tensor, - out_dtype: torch.dtype, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - cu_count: int, - bias: torch.Tensor = None, -) -> torch.Tensor: - out = torch.empty((b.shape[0], a.shape[0]), dtype=out_dtype, device=b.device) - 
torch.ops._rocm_C.wvSplitKQ(a, b, bias, out, scale_a, scale_b, cu_count) - return out - - -# moe -def moe_sum(input: torch.Tensor, output: torch.Tensor): - torch.ops._moe_C.moe_sum(input, output) - - -def moe_align_block_size( - topk_ids: torch.Tensor, - num_experts: int, - block_size: int, - sorted_token_ids: torch.Tensor, - experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor, - expert_map: torch.Tensor | None = None, -) -> None: - torch.ops._moe_C.moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_token_ids, - experts_ids, - num_tokens_post_pad, - expert_map, - ) - - -def batched_moe_align_block_size( - max_tokens_per_batch: int, - block_size: int, - expert_num_tokens: torch.Tensor, - sorted_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor, -) -> None: - torch.ops._moe_C.batched_moe_align_block_size( - max_tokens_per_batch, - block_size, - expert_num_tokens, - sorted_ids, - expert_ids, - num_tokens_post_pad, - ) - - -def moe_lora_align_block_size( - topk_ids: torch.Tensor, - token_lora_mapping: torch.Tensor, - num_experts: int, - block_size: int, - max_loras: int, - max_num_tokens_padded: int, - max_num_m_blocks: int, - sorted_token_ids: torch.Tensor, - experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor, - adapter_enabled: torch.Tensor, - lora_ids: torch.Tensor, - expert_map: torch.Tensor | None = None, -) -> None: - torch.ops._moe_C.moe_lora_align_block_size( - topk_ids, - token_lora_mapping, - num_experts, - block_size, - max_loras, - max_num_tokens_padded, - max_num_m_blocks, - sorted_token_ids, - experts_ids, - num_tokens_post_pad, - adapter_enabled, - lora_ids, - expert_map, - ) - - -def moe_wna16_gemm( - input: torch.Tensor, - output: torch.Tensor, - b_qweight: torch.Tensor, - b_scales: torch.Tensor, - b_qzeros: torch.Tensor | None, - topk_weights: torch.Tensor | None, - sorted_token_ids: torch.Tensor, - experts_ids: torch.Tensor, - num_tokens_post_pad: torch.Tensor, - top_k: int, - BLOCK_SIZE_M: int, - BLOCK_SIZE_N: int, - BLOCK_SIZE_K: int, - bit: int, -) -> torch.Tensor: - if not current_platform.is_cuda(): - raise NotImplementedError( - "The optimized moe_wna16_gemm kernel is only available on CUDA platforms" - ) - torch.ops._moe_C.moe_wna16_gemm( - input, - output, - b_qweight, - b_scales, - b_qzeros, - topk_weights, - sorted_token_ids, - experts_ids, - num_tokens_post_pad, - top_k, - BLOCK_SIZE_M, - BLOCK_SIZE_N, - BLOCK_SIZE_K, - bit, - ) - - -def router_gemm_bf16_fp32(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: - """bf16 x bf16 -> fp32 GEMM via cuBLAS. 
weight shape: (N, K).""" - return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight) - - -if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"): - - @register_fake("_moe_C::router_gemm_bf16_fp32") - def router_gemm_bf16_fp32_fake( - input: torch.Tensor, - weight: torch.Tensor, - ) -> torch.Tensor: - return torch.empty( - input.shape[0], weight.shape[0], dtype=torch.float32, device=input.device - ) - - -def dsv3_router_gemm( - hidden_states: torch.Tensor, - router_weight: torch.Tensor, - output_dtype: torch.dtype, -) -> torch.Tensor: - output = torch.empty( - hidden_states.shape[0], - router_weight.shape[0], - device=hidden_states.device, - dtype=output_dtype, - ) - torch.ops._moe_C.dsv3_router_gemm(output, hidden_states, router_weight) - return output - - -def topk_softmax( - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool = False, - e_score_correction_bias: torch.Tensor | None = None, -) -> None: - torch.ops._moe_C.topk_softmax( - topk_weights, - topk_ids, - token_expert_indices, - gating_output, - renormalize, - e_score_correction_bias, - ) - - -def topk_sigmoid( - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - token_expert_indices: torch.Tensor, - gating_output: torch.Tensor, - renormalize: bool = False, - e_score_correction_bias: torch.Tensor | None = None, -) -> None: - torch.ops._moe_C.topk_sigmoid( - topk_weights, - topk_ids, - token_expert_indices, - gating_output, - renormalize, - e_score_correction_bias, - ) - - -def grouped_topk( - scores: torch.Tensor, - num_expert_group: int, - topk_group: int, - topk: int, - renormalize: bool, - routed_scaling_factor: float, - bias: torch.Tensor, - scoring_func: int = 0, -): - """ - Perform grouped top-k routing for mixture of experts. - - Args: - scores: Raw inputs (logits if scoring_func=1, scores if scoring_func=0) - num_expert_group: Number of expert groups - topk_group: Number of groups to select - topk: Number of experts to select per token - renormalize: Whether to renormalize the output weights - routed_scaling_factor: Scaling factor for routing weights - bias: Bias tensor (e_score_correction_bias). Always fused in kernel. 
- scoring_func: 0=none (no activation), 1=sigmoid - """ - if not current_platform.is_cuda(): - raise NotImplementedError( - "The fused grouped_topk kernel is only available on CUDA platforms" - ) - return torch.ops._moe_C.grouped_topk( - scores, - num_expert_group, - topk_group, - topk, - renormalize, - routed_scaling_factor, - bias, - scoring_func, - ) - - -def moe_wna16_marlin_gemm( - input: torch.Tensor, - output: torch.Tensor | None, - b_qweight: torch.Tensor, - b_bias: torch.Tensor | None, - b_scales: torch.Tensor, - a_scales: torch.Tensor | None, - global_scale: torch.Tensor | None, - b_qzeros: torch.Tensor | None, - g_idx: torch.Tensor | None, - perm: torch.Tensor | None, - workspace: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_past_padded: torch.Tensor, - topk_weights: torch.Tensor, - moe_block_size: int, - top_k: int, - mul_topk_weights: bool, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - use_atomic_add: bool, - use_fp32_reduce: bool, - is_zp_float: bool, - thread_k: int = -1, - thread_n: int = -1, - blocks_per_sm: int = -1, -) -> torch.Tensor: - return torch.ops._moe_C.moe_wna16_marlin_gemm( - input, - output, - b_qweight, - b_bias, - b_scales, - a_scales, - global_scale, - b_qzeros, - g_idx, - perm, - workspace, - sorted_token_ids, - expert_ids, - num_tokens_past_padded, - topk_weights, - moe_block_size, - top_k, - mul_topk_weights, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - use_atomic_add, - use_fp32_reduce, - is_zp_float, - thread_k, - thread_n, - blocks_per_sm, - ) - - -if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): - - @register_fake("_moe_C::marlin_gemm_moe") - def marlin_gemm_moe_fake( - a: torch.Tensor, - b_q_weights: torch.Tensor, - sorted_ids: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - b_scales: torch.Tensor, - b_zero_points: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - num_experts: int, - topk: int, - moe_block_size: int, - replicate_input: bool, - apply_weights: bool, - ) -> torch.Tensor: - return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device) - - @register_fake("_moe_C::moe_wna16_marlin_gemm") - def moe_wna16_marlin_gemm_fake( - input: torch.Tensor, - output: torch.Tensor | None, - b_qweight: torch.Tensor, - b_bias: torch.Tensor | None, - b_scales: torch.Tensor, - a_scales: torch.Tensor | None, - global_scale: torch.Tensor | None, - b_qzeros: torch.Tensor | None, - g_idx: torch.Tensor | None, - perm: torch.Tensor | None, - workspace: torch.Tensor, - sorted_token_ids: torch.Tensor, - expert_ids: torch.Tensor, - num_tokens_past_padded: torch.Tensor, - topk_weights: torch.Tensor, - moe_block_size: int, - top_k: int, - mul_topk_weights: bool, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - use_atomic_add: bool, - use_fp32_reduce: bool, - is_zp_float: bool, - ): - return torch.empty( - (size_m * top_k, size_n), dtype=input.dtype, device=input.device - ) - - -def reshape_and_cache( - key: torch.Tensor, - value: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: torch.Tensor, - v_scale: torch.Tensor, -) -> None: - torch.ops._C_cache_ops.reshape_and_cache( - key, - value, - key_cache, - value_cache, - 
slot_mapping,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-    )
-
-
-def reshape_and_cache_flash(
-    key: torch.Tensor,
-    value: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    kv_cache_dtype: str,
-    k_scale: torch.Tensor,
-    v_scale: torch.Tensor,
-) -> None:
-    torch.ops._C_cache_ops.reshape_and_cache_flash(
-        key,
-        value,
-        key_cache,
-        value_cache,
-        slot_mapping,
-        kv_cache_dtype,
-        k_scale,
-        v_scale,
-    )
-
-
-def concat_and_cache_mla(
-    kv_c: torch.Tensor,
-    k_pe: torch.Tensor,
-    kv_cache: torch.Tensor,
-    slot_mapping: torch.Tensor,
-    kv_cache_dtype: str,
-    scale: torch.Tensor,
-) -> None:
-    torch.ops._C_cache_ops.concat_and_cache_mla(
-        kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale
-    )
-
-
-def concat_and_cache_mla_rope_fused(
-    positions: torch.Tensor,
-    q_pe: torch.Tensor,
-    k_pe: torch.Tensor,
-    kv_c: torch.Tensor,
-    cos_sin_cache: torch.Tensor,
-    is_neox: bool,
-    slot_mapping: torch.Tensor,
-    kv_cache: torch.Tensor,
-    kv_cache_dtype: str,
-    kv_cache_scale: torch.Tensor,
-) -> None:
-    torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused(
-        positions,
-        q_pe,
-        k_pe,
-        kv_c,
-        cos_sin_cache,
-        is_neox,
-        slot_mapping,
-        kv_cache,
-        kv_cache_dtype,
-        kv_cache_scale,
-    )
-
-
-def swap_blocks(
-    src: torch.Tensor,
-    dst: torch.Tensor,
-    block_size_in_bytes: int,
-    block_mapping: torch.Tensor,
-) -> None:
-    """
-    Copy specific blocks from one tensor to another.
-
-    This method assumes each of the two input tensors is composed of
-    consecutive, contiguous blocks of size block_size_in_bytes, i.e. the
-    memory layout for each tensor is:
-    [block0] [block1] ... [block N]
-
-    block_mapping determines which blocks of the source tensor to copy and
-    the matching destination block number for each. It is expected to be a
-    tensor of shape (num_blocks_to_copy, 2) where each block_mapping[i]
-    represents a single copy operation, copying block #block_mapping[i][0]
-    of the source tensor to block #block_mapping[i][1] of the destination
-    tensor. block_mapping should have dtype int64.
-
-    The source and destination tensors can each be on CPU or GPU, but they
-    must not both be on CPU. The block mapping tensor must be on CPU.
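For example, the sketch below copies source blocks 0 and 2 into destination blocks 5 and 7; `src`, `dst`, and `block_size_in_bytes` are assumed to be set up elsewhere:

    import torch

    block_mapping = torch.tensor([[0, 5], [2, 7]], dtype=torch.int64)  # must stay on CPU
    swap_blocks(src, dst, block_size_in_bytes, block_mapping)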
- """ - torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping) - - -def convert_fp8( - output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8" -) -> None: - torch.ops._C_cache_ops.convert_fp8(output, input, scale, kv_dtype) - - -def gather_and_maybe_dequant_cache( - src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, - token_to_seq: torch.Tensor, - num_tokens: int, - kv_cache_dtype: str, - scale: torch.Tensor, - seq_starts: torch.Tensor | None = None, -) -> None: - torch.ops._C_cache_ops.gather_and_maybe_dequant_cache( - src_cache, - dst, - block_table, - cu_seq_lens, - token_to_seq, - num_tokens, - kv_cache_dtype, - scale, - seq_starts, - ) - - -def cp_gather_cache( - src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, - batch_size: int, - seq_starts: torch.Tensor | None = None, -) -> None: - torch.ops._C_cache_ops.cp_gather_cache( - src_cache, dst, block_table, cu_seq_lens, batch_size, seq_starts - ) - - -def cp_gather_and_upconvert_fp8_kv_cache( - src_cache: torch.Tensor, - dst: torch.Tensor, - block_table: torch.Tensor, - seq_lens: torch.Tensor, - workspace_starts: torch.Tensor, - batch_size: int, -) -> None: - """Gather and upconvert FP8 KV cache to BF16 workspace. - - Args: - src_cache: FP8 KV cache [num_blocks, block_size, 656] - dst: BF16 output workspace [total_tokens, 576] - block_table: Block indices [num_reqs, max_blocks] - seq_lens: Sequence lengths [num_reqs] - workspace_starts: Workspace start offsets [num_reqs] - batch_size: Number of requests - """ - torch.ops._C_cache_ops.cp_gather_and_upconvert_fp8_kv_cache( - src_cache, dst, block_table, seq_lens, workspace_starts, batch_size - ) - - -def concat_mla_q( - ql_nope: torch.Tensor, - q_pe: torch.Tensor, - q_out: torch.Tensor, -) -> None: - """Concatenate query nope and rope for MLA/DSA attention. 
- - Args: - ql_nope: Query nope component [num_tokens, num_heads, nope_dim] - q_pe: Query rope component [num_tokens, num_heads, rope_dim] - q_out: Output tensor [num_tokens, num_heads, nope_dim + rope_dim] - """ - torch.ops._C_cache_ops.concat_mla_q(ql_nope, q_pe, q_out) - - -def indexer_k_quant_and_cache( - k: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - quant_block_size: int, - kv_cache_dtype: str, -) -> None: - torch.ops._C_cache_ops.indexer_k_quant_and_cache( - k, kv_cache, slot_mapping, quant_block_size, kv_cache_dtype - ) - - -def cp_gather_indexer_k_quant_cache( - kv_cache: torch.Tensor, - dst_k: torch.Tensor, - dst_scale: torch.Tensor, - block_table: torch.Tensor, - cu_seq_lens: torch.Tensor, -) -> None: - torch.ops._C_cache_ops.cp_gather_indexer_k_quant_cache( - kv_cache, dst_k, dst_scale, block_table, cu_seq_lens - ) - - -def get_device_attribute(attribute: int, device: int) -> int: - return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) - - -def get_max_shared_memory_per_block_device_attribute(device: int) -> int: - # ruff: noqa: E501 - return torch.ops._C_cuda_utils.get_max_shared_memory_per_block_device_attribute( - device - ) - - -# custom ar -def init_custom_ar( - ipc_tensors: list[torch.Tensor], - rank_data: torch.Tensor, - rank: int, - fully_connected: bool, -) -> int: - return torch.ops._C_custom_ar.init_custom_ar( - ipc_tensors, rank_data, rank, fully_connected - ) - - -def all_reduce( - fa: int, - inp: torch.Tensor, - out: torch.Tensor, - reg_buffer: int, - reg_buffer_sz_bytes: int, -) -> None: - torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes) - - -def dispose(fa: int) -> None: - torch.ops._C_custom_ar.dispose(fa) - - -def meta_size() -> int: - return torch.ops._C_custom_ar.meta_size() - - -def register_buffer(fa: int, ipc_tensors: list[int]) -> None: - return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) - - -def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]: - return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) - - -def register_graph_buffers( - fa: int, handles: list[list[int]], offsets: list[list[int]] -) -> None: - torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) - - -def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]: - return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size) - - -def open_mem_handle(mem_handle: torch.Tensor): - return torch.ops._C_custom_ar.open_mem_handle(mem_handle) - - -def free_shared_buffer(ptr: int) -> None: - torch.ops._C_custom_ar.free_shared_buffer(ptr) - - -# quick all reduce -def init_custom_qr(rank: int, world_size: int, qr_max_size: int | None = None) -> int: - return torch.ops._C_custom_ar.init_custom_qr(rank, world_size, qr_max_size) - - -def qr_destroy(fa: int) -> None: - torch.ops._C_custom_ar.qr_destroy(fa) - - -def qr_all_reduce( - fa: int, - inp: torch.Tensor, - out: torch.Tensor, - quant_level: int, - cast_bf2half: bool = False, -) -> None: - torch.ops._C_custom_ar.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half) - - -def qr_get_handle(fa: int) -> torch.Tensor: - return torch.ops._C_custom_ar.qr_get_handle(fa) - - -def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None: - return torch.ops._C_custom_ar.qr_open_handles(fa, handles) - - -def qr_max_size() -> int: - return torch.ops._C_custom_ar.qr_max_size() - - -def get_flash_mla_metadata( - cache_seqlens: torch.Tensor, - num_heads_per_head_k: int, - num_heads_k: int, -) -> 
tuple[torch.Tensor, torch.Tensor]:
-    """
-    Arguments:
-        cache_seqlens: (batch_size), dtype torch.int32.
-        num_heads_per_head_k: Equal to seq_len_q * num_heads_q // num_heads_k.
-        num_heads_k: Number of key heads.
-
-    Return:
-        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
-        num_splits: (batch_size + 1), dtype torch.int32.
-    """
-    return torch.ops._C.get_flash_mla_metadata(
-        cache_seqlens, num_heads_per_head_k, num_heads_k
-    )
-
-
-def flash_mla_with_kvcache(
-    q: torch.Tensor,
-    k_cache: torch.Tensor,
-    block_table: torch.Tensor,
-    cache_seqlens: torch.Tensor,
-    head_dim_v: int,
-    tile_scheduler_metadata: torch.Tensor,
-    num_splits: torch.Tensor,
-    softmax_scale: float | None = None,
-    causal: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Arguments:
-        q: (batch_size, seq_len_q, num_heads_q, head_dim).
-        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
-        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
-        cache_seqlens: (batch_size), torch.int32.
-        head_dim_v: Head dim of v.
-        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_flash_mla_metadata.
-        num_splits: (batch_size + 1), torch.int32, returned by get_flash_mla_metadata.
-        softmax_scale: float. The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
-        causal: bool. Whether to apply a causal attention mask.
-
-    Return:
-        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
-        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
-    """
-    if softmax_scale is None:
-        softmax_scale = q.shape[-1] ** (-0.5)
-    out, softmax_lse = torch.ops._C.flash_mla_fwd_kvcache(
-        q,
-        k_cache,
-        None,
-        head_dim_v,
-        cache_seqlens,
-        block_table,
-        softmax_scale,
-        causal,
-        tile_scheduler_metadata,
-        num_splits,
-    )
-    return out, softmax_lse
-
-
-def sm100_cutlass_mla_decode(
-    out: torch.Tensor,
-    lse: torch.Tensor,
-    q_nope: torch.Tensor,
-    q_pe: torch.Tensor,
-    kv_c_and_k_pe_cache: torch.Tensor,
-    seq_lens: torch.Tensor,
-    page_table: torch.Tensor,
-    workspace: torch.Tensor,
-    scale: float,
-    num_kv_splits: int,
-) -> torch.Tensor:
-    torch.ops._C.sm100_cutlass_mla_decode(
-        out,
-        lse,
-        q_nope,
-        q_pe,
-        kv_c_and_k_pe_cache,
-        seq_lens,
-        page_table,
-        workspace,
-        scale,
-        num_kv_splits,
-    )
-    return out
-
-
-def sm100_cutlass_mla_get_workspace_size(
-    max_seq_len: int, num_batches: int, sm_count: int, num_kv_splits: int
-) -> int:
-    return torch.ops._C.sm100_cutlass_mla_get_workspace_size(
-        max_seq_len, num_batches, sm_count, num_kv_splits
-    )
-
-
-def dsv3_fused_a_gemm(
-    output: torch.Tensor,
-    mat_a: torch.Tensor,
-    mat_b: torch.Tensor,
-) -> None:
-    """DeepSeek V3 fused A GEMM (SM 9.0+, bf16 only, 1-16 tokens).
-
-    Computes output = mat_a @ mat_b where:
-        mat_a: [num_tokens, 7168] row-major bf16 (hidden states)
-        mat_b: [7168, 2112] column-major bf16 (weight transposed)
-        output: [num_tokens, 2112] row-major bf16
-
-    Optimized for the DeepSeek V2/V3 QKV A-projection at small batch sizes.
-    Requires SM 9.0+ (Hopper).
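A hedged call sketch: the column-major requirement on mat_b is met by viewing a row-major [2112, 7168] weight through .t(); assumes a Hopper GPU and at most 16 tokens:

    import torch

    num_tokens = 8
    mat_a = torch.randn(num_tokens, 7168, device="cuda", dtype=torch.bfloat16)
    weight = torch.randn(2112, 7168, device="cuda", dtype=torch.bfloat16)
    mat_b = weight.t()  # [7168, 2112] view, column-major in memory
    output = torch.empty(num_tokens, 2112, device="cuda", dtype=torch.bfloat16)
    dsv3_fused_a_gemm(output, mat_a, mat_b)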
- """ - torch.ops._C.dsv3_fused_a_gemm(output, mat_a, mat_b) - - -if hasattr(torch.ops._C, "weight_packed_linear"): - - @register_fake("_C::weight_packed_linear") - def weight_packed_linear_fake( - mat1: torch.Tensor, - mat2: torch.Tensor, - bias: torch.Tensor | None, - is_vnni: bool, - ) -> torch.Tensor: - return torch.empty( - (mat1.size(0), mat2.size(0)), dtype=mat1.dtype, device=mat2.device - ) - - -if hasattr(torch.ops._C, "fused_experts_cpu"): - - @register_fake("_C::fused_experts_cpu") - def fused_experts_cpu_fake( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace: bool, - use_int8_w8a8: bool, - use_fp8_w8a16: bool, - w1_scale: torch.Tensor | None, - w2_scale: torch.Tensor | None, - block_size: list[int] | None, - a1_scale: torch.Tensor | None, - a2_scale: torch.Tensor | None, - is_vnni: bool, - ) -> torch.Tensor: - return torch.empty_like(hidden_states) - - -if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): - - @register_fake("_C::int8_scaled_mm_with_quant") - def int8_scaled_mm_with_quant_fake( - mat1: torch.Tensor, - mat2: torch.Tensor, - scales2: torch.Tensor, - bias: torch.Tensor | None, - out_dtype: torch.dtype, - is_vnni: bool, - ) -> torch.Tensor: - M = mat1.size(0) - N = mat2.size(0) - return torch.empty((M, N), dtype=out_dtype) - - -class CPUDNNLGEMMHandler: - def __init__(self) -> None: - self.handler_tensor: torch.Tensor | None = None - self.n = -1 - self.k = -1 - - def __del__(self): - if self.handler_tensor is not None: - torch.ops._C.release_dnnl_matmul_handler(self.handler_tensor.item()) - - -_supports_onednn = bool(hasattr(torch.ops._C, "create_onednn_mm_handler")) - - -def is_onednn_acl_supported(): - return torch.ops._C.is_onednn_acl_supported() - - -def create_onednn_mm( - weight: torch.Tensor, # [K, N] - primitive_cache_size: int = 128, -) -> CPUDNNLGEMMHandler: - handler = CPUDNNLGEMMHandler() - handler.k, handler.n = weight.size() - # store the handler pointer in a tensor it doesn't get inlined - handler.handler_tensor = torch.tensor( - torch.ops._C.create_onednn_mm_handler(weight, primitive_cache_size), - dtype=torch.int64, - ) - return handler - - -def onednn_mm( - dnnl_handler: CPUDNNLGEMMHandler, - x: torch.Tensor, - bias: torch.Tensor | None, -) -> torch.Tensor: - output = torch.empty((*x.shape[0:-1], dnnl_handler.n), dtype=x.dtype) - torch.ops._C.onednn_mm( - output, x.reshape(-1, dnnl_handler.k), bias, dnnl_handler.handler_tensor - ) - - return output - - -def create_onednn_scaled_mm( - weight: torch.Tensor, # [K, N] - weight_scales: torch.Tensor, - output_type: torch.dtype, - dynamic_quant: bool, - use_azp: bool, - primitive_cache_size: int = 128, -) -> CPUDNNLGEMMHandler: - handler = CPUDNNLGEMMHandler() - handler.k, handler.n = weight.size() - # store the handler pointer in a tensor so it doesn't get inlined - handler.handler_tensor = torch.tensor( - torch.ops._C.create_onednn_scaled_mm_handler( - weight, - weight_scales, - output_type, - dynamic_quant, - use_azp, - primitive_cache_size, - ), - dtype=torch.int64, - ) - return handler - - -def onednn_scaled_int8_quant( - input: torch.Tensor, - scale: torch.Tensor | None = None, - azp: torch.Tensor | None = None, - symmetric: bool = True, ): """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. 
             When not provided, we invoke dynamic-per-token quantization.
         azp: Optional zero-point for the int8 quantization.
             Must be provided for asymmetric quantization if `scale` is provided.
         symmetric: Whether to use symmetric quantization (scale only, azp ignored).
 
     Returns:
         tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: Output int8 tensor, scales, and optionally azp.
     """
     output = torch.empty_like(input, dtype=torch.int8)
     token_num = input.numel() // input.shape[-1]
     input = input.view((token_num, input.shape[-1]))
     if scale is not None:
         # static-per-tensor quantization.
         assert symmetric == (azp is None), (
             "azp must only be provided for asymmetric quantization."
         )
         torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
         return output, scale, azp
 
     # dynamic-per-token quantization.
     input_scales = torch.empty((token_num, 1), device=input.device, dtype=torch.float32)
     input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32)
     torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, input_azp)
     return output, input_scales, input_azp
 
 
 def onednn_scaled_mm(
     dnnl_handler: CPUDNNLGEMMHandler,
     x: torch.Tensor,
     output: torch.Tensor,
     input_scale: torch.Tensor | None,
     input_zp: torch.Tensor | None,
     input_zp_adj: torch.Tensor | None,
     bias: torch.Tensor | None,
 ) -> torch.Tensor:
     torch.ops._C.onednn_scaled_mm(
         output,
         x,
         input_scale,
         input_zp,
         input_zp_adj,
         bias,
         dnnl_handler.handler_tensor,
     )
 
     return output
 
 
 def cpu_attn_get_scheduler_metadata(
     num_reqs: int,
     num_heads: int,
     num_kv_heads: int,
     head_dim: int,
     seq_lens: torch.Tensor,
     dtype: torch.dtype,
     query_start_loc: torch.Tensor,
     causal: bool,
     sliding_window_size: int,
     isa: str,
     enable_kv_split: bool,
 ) -> torch.Tensor:
     scheduler_metadata = torch.ops._C.get_scheduler_metadata(
         num_reqs,
         num_heads,
         num_kv_heads,
         head_dim,
         seq_lens,
         dtype,
         query_start_loc,
         causal,
         sliding_window_size,
         isa,
         enable_kv_split,
     )
     return scheduler_metadata
 
 
 def cpu_attn_reshape_and_cache(
     key: torch.Tensor,
     value: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     slot_mapping: torch.Tensor,
     isa: str,
 ) -> None:
     torch.ops._C.cpu_attn_reshape_and_cache(
         key,
         value,
         key_cache,
         value_cache,
         slot_mapping,
         isa,
     )
 
 
 def cpu_attention_with_kv_cache(
     query: torch.Tensor,
     key_cache: torch.Tensor,
     value_cache: torch.Tensor,
     output: torch.Tensor,
     query_start_loc: torch.Tensor,
     seq_lens: torch.Tensor,
     scale: float,
     causal: bool,
     alibi_slopes: torch.Tensor | None,
     sliding_window: tuple[int, int],
     block_table: torch.Tensor,
     softcap: float,
     scheduler_metadata: torch.Tensor,
     s_aux: torch.Tensor | None,
 ) -> None:
     torch.ops._C.cpu_attention_with_kv_cache(
         query,
         key_cache,
         value_cache,
         output,
         query_start_loc,
         seq_lens,
         scale,
         causal,
         alibi_slopes,
         sliding_window[0],
         sliding_window[1],
         block_table,
         softcap,
         scheduler_metadata,
         s_aux,
     )
 
 
 def cpu_gemm_wna16(
     input: torch.Tensor,
     q_weight: torch.Tensor,
     scales: torch.Tensor,
     zeros: torch.Tensor | None,
     g_idx: torch.Tensor | None,
     bias: torch.Tensor | None,
     pack_factor: int,
     isa_hint: str,
 ) -> torch.Tensor:
     output = torch.empty((input.size(0), scales.size(1)), dtype=input.dtype)
     torch.ops._C.cpu_gemm_wna16(
         input,
         q_weight,
         output,
         scales,
         zeros,
         g_idx,
         bias,
         pack_factor,
         isa_hint,
     )
     return output
 
 
 def cpu_prepack_moe_weight(
     weight: torch.Tensor,
     isa: str,
 ) -> torch.Tensor:
     output = torch.empty_like(weight)
     torch.ops._C.prepack_moe_weight(weight, output, isa)
     return output
 
 
 def cpu_fused_moe(
     input: torch.Tensor,
     w13: torch.Tensor,
     w2: torch.Tensor,
     w13_bias: torch.Tensor | None,
     w2_bias: torch.Tensor | None,
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     act: str,
     isa: str,
     skip_weighted: bool = False,
 ) -> torch.Tensor:
     output = torch.empty_like(input)
     torch.ops._C.cpu_fused_moe(
         output,
         input,
         w13,
         w2,
         w13_bias,
         w2_bias,
         topk_weights,
         topk_ids,
         skip_weighted,
         act,
         isa,
     )
     return output
 
 
 if hasattr(torch.ops._qutlass_C, "matmul_mxf4_bf16_tn"):
 
     @register_fake("_qutlass_C::matmul_mxf4_bf16_tn")
     def _fake_matmul_mxf4_bf16_tn(
         a: torch.Tensor,
         b: torch.Tensor,
         a_sf: torch.Tensor,
         b_sf: torch.Tensor,
         alpha: torch.Tensor,
     ):
         return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16)
 
 
 def matmul_mxf4_bf16_tn(
     a: torch.Tensor,
     b: torch.Tensor,
     a_sf: torch.Tensor,
     b_sf: torch.Tensor,
     alpha: torch.Tensor,
 ) -> torch.Tensor:
     return torch.ops._qutlass_C.matmul_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha)
 
 
 if hasattr(torch.ops._qutlass_C, "matmul_ada_mxf4_bf16_tn"):
 
     @register_fake("_qutlass_C::matmul_ada_mxf4_bf16_tn")
     def _fake_matmul_ada_mxf4_bf16_tn(
         a: torch.Tensor,
         b: torch.Tensor,
         a_sf: torch.Tensor,
         b_sf: torch.Tensor,
         alpha: torch.Tensor,
     ):
         return a.new_empty(*a.shape[:-1], b.shape[0], dtype=torch.bfloat16)
 
 
 def matmul_ada_mxf4_bf16_tn(
     a: torch.Tensor,
     b: torch.Tensor,
     a_sf: torch.Tensor,
     b_sf: torch.Tensor,
     alpha: torch.Tensor,
 ) -> torch.Tensor:
     return torch.ops._qutlass_C.matmul_ada_mxf4_bf16_tn(a, b, a_sf, b_sf, alpha)
 
 
 if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxQuest"):
 
     @register_fake("_qutlass_C::fusedQuantizeMxQuest")
     def _fake_fused_quantize_mx_quest(
         a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor
     ):
         return xh_e2m1, xh_e8m0
 
 
 if hasattr(torch.ops._qutlass_C, "fusedQuantizeMxAbsMax"):
 
     @register_fake("_qutlass_C::fusedQuantizeMxAbsMax")
     def _fake_fused_quantize_mx_absmax(
         a: torch.Tensor, b: torch.Tensor, xh_e2m1: torch.Tensor, xh_e8m0: torch.Tensor
     ):
         return xh_e2m1, xh_e8m0
 
 
 def fusedQuantizeMx(
     a: torch.Tensor, b: torch.Tensor, *, method: Literal["quest", "abs_max"] = "quest"
 ) -> tuple[torch.Tensor, torch.Tensor]:
     if a.dim() == 0:
         raise ValueError("`a` must have at least 1 dimension.")
     if a.size(-1) % 32 != 0:
         raise ValueError(f"last dim of `a` must be divisible by 32, got {a.size(-1)}.")
     if b.device != a.device:
         raise ValueError("`a` and `b` must be on the same device.")
 
     xh_e2m1 = torch.empty(
         *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device
     )
 
     rows, cols = a.numel() // a.size(-1), a.size(-1) // 32
     n_row_blocks = cdiv(rows, 128)
     n_col_blocks = cdiv(cols, 4)
     padded_rows = n_row_blocks * 128
     padded_cols = n_col_blocks * 4
 
     xh_e8m0 = torch.empty(
         padded_rows, padded_cols, dtype=torch.float8_e8m0fnu, device=a.device
     )
 
     if not hasattr(torch.ops, "_qutlass_C"):
         raise RuntimeError(
             "The `_qutlass_C` extension is not loaded. "
             "Make sure your custom op library is imported before calling fusedQuantizeMx."
         )
 
     if method == "quest":
         return torch.ops._qutlass_C.fusedQuantizeMxQuest(a, b, xh_e2m1, xh_e8m0)
     elif method == "abs_max":
         return torch.ops._qutlass_C.fusedQuantizeMxAbsMax(a, b, xh_e2m1, xh_e8m0)
     else:
         raise ValueError(f"invalid method {method!r}, must be 'quest' or 'abs_max'")
 
 
 if hasattr(torch.ops._qutlass_C, "fusedQuantizeNv"):
 
     @register_fake("_qutlass_C::fusedQuantizeNv")
     def _fake_fused_quantize_nv(
         a: torch.Tensor,
         b: torch.Tensor,
         xh_e2m1: torch.Tensor,
         xh_e4m3: torch.Tensor,
         global_scale: torch.Tensor,
     ):
         return xh_e2m1, xh_e4m3
 
 
 def fusedQuantizeNv(
     a: torch.Tensor, b: torch.Tensor, global_scale: torch.Tensor
 ) -> tuple[torch.Tensor, torch.Tensor]:
     xh_e2m1 = torch.empty(
         *a.shape[:-1], a.size(-1) // 2, dtype=torch.uint8, device=a.device
     )
 
     rows, cols = a.numel() // a.size(-1), a.size(-1) // 16
     n_row_blocks = cdiv(rows, 128)
     n_col_blocks = cdiv(cols, 4)
     padded_rows = n_row_blocks * 128
     padded_cols = n_col_blocks * 4
     xh_e4m3 = torch.empty(
         padded_rows, padded_cols, dtype=torch.float8_e4m3fn, device=a.device
     )
 
     return torch.ops._qutlass_C.fusedQuantizeNv(a, b, xh_e2m1, xh_e4m3, global_scale)
 
 
 def hadacore_transform(x: torch.Tensor, inplace: bool = True) -> torch.Tensor:
     """
     Perform Hadamard transforms using [Hadacore](https://arxiv.org/abs/2412.08832)
     kernels. Note that these kernels exploit the recursive properties of
-    Sylvester Hadamards, and therefore do not require transform weight data
+    Sylvester Hadamards, and therefore do not require transform weight data.
 
-    Note that sylvester hadamard transforms are also symmetric, which means that
-    this function is also applies the (transpose <=> inverse) transform.
+    Note that Sylvester Hadamard transforms are also symmetric, which means that
+    this function also applies the (transpose <=> inverse) transform.
 
     :param x: value to be transformed inplace
     :param inplace: modify value in place
     :return: value after transformation
     """
     return torch.ops._C.hadacore_transform(x, inplace)
 
 
 if hasattr(torch.ops._C, "hadacore_transform"):
 
     @register_fake("_C::hadacore_transform")
     def _hadacore_transform_fake(x: torch.Tensor, inplace: bool) -> torch.Tensor:
         return torch.empty_like(x) if not inplace else x
diff --git a/vllm/config/attention.py b/vllm/config/attention.py
index e05544f08e10..85673f384adf 100644
--- a/vllm/config/attention.py
+++ b/vllm/config/attention.py
@@ -30,14 +30,14 @@ class AttentionConfig:
     use_cudnn_prefill: bool = False
     """Whether to use cudnn prefill."""
 
-    use_trtllm_ragged_deepseek_prefill: bool = True
+    use_trtllm_ragged_deepseek_prefill: bool = False
     """Whether to use TRTLLM ragged deepseek prefill."""
 
     use_trtllm_attention: bool | None = None
     """If set to True/False, use or don't use the TRTLLM attention backend in flashinfer.
     If None, auto-detect the attention backend in flashinfer."""
 
-    disable_flashinfer_prefill: bool = False
+    disable_flashinfer_prefill: bool = True
     """Whether to disable flashinfer prefill."""
 
     disable_flashinfer_q_quantization: bool = False
diff --git a/vllm/entrypoints/openai/responses/harmony.py b/vllm/entrypoints/openai/responses/harmony.py
index 460f310926ad..faab2f7f4cc7 100644
--- a/vllm/entrypoints/openai/responses/harmony.py
+++ b/vllm/entrypoints/openai/responses/harmony.py
@@ -138,8 +138,12 @@ def _parse_chat_format_message(chat_msg: dict) -> list[Message]:
 def response_input_to_harmony(
     response_msg: ResponseInputOutputItem,
     prev_responses: list[ResponseOutputItem | ResponseReasoningItem],
-) -> Message:
-    """Convert a single ResponseInputOutputItem into a Harmony Message."""
+) -> Message | None:
+    """Convert a single ResponseInputOutputItem into a Harmony Message.
+
+    Returns None for reasoning items with empty or absent content so
+    the caller can skip them.
+    """
     if not isinstance(response_msg, dict):
         response_msg = response_msg.model_dump()
     if "type" not in response_msg or response_msg["type"] == "message":
@@ -172,9 +176,13 @@ def response_input_to_harmony(
             response_msg["output"],
         )
     elif response_msg["type"] == "reasoning":
-        content = response_msg["content"]
-        assert len(content) == 1
-        msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"])
+        content = response_msg.get("content")
+        if content and len(content) >= 1:
+            reasoning_text = "\n".join(item["text"] for item in content)
+            msg = Message.from_role_and_content(Role.ASSISTANT, reasoning_text)
+            msg = msg.with_channel("analysis")
+        else:
+            return None
     elif response_msg["type"] == "function_call":
         msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
         msg = msg.with_channel("commentary")
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py
index a7eaccd83db7..6d0041813e35 100644
--- a/vllm/entrypoints/openai/responses/serving.py
+++ b/vllm/entrypoints/openai/responses/serving.py
@@ -1086,7 +1086,7 @@ def _construct_input_messages_with_harmony(
             prev_outputs = []
             for response_msg in request.input:
                 new_msg = response_input_to_harmony(response_msg, prev_outputs)
-                if new_msg.author.role != "system":
+                if new_msg is not None and new_msg.author.role != "system":
                     messages.append(new_msg)
 
             # User passes in a tool call request and its output. We need
diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index 36ccc649f930..3794bde4101e 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -1282,8 +1282,6 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
 
 @functools.cache
 def use_flashinfer_prefill() -> bool:
-    # For blackwell default to flashinfer prefill if it's available since
-    # it is faster than FA2.
     from vllm.config import get_current_vllm_config
 
     vllm_config = get_current_vllm_config()
@@ -2154,13 +2152,16 @@ def __init__(
 
         # For MLA the v head dim is smaller than qk head dim so we pad out
         # v with 0s to match the qk head dim for attention backends that do
-        # not support different headdims
-        # We don't need to pad V if we are on a hopper system with FA3
+        # not support different headdims.
+        # FA3 on Hopper (SM90) and FA4 natively handle diff headdims.
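+        # In short: pad V only when neither FA3-on-SM90 nor FA4 is in
+        # use, since padding spends memory traffic on all-zero columns.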
         device_capability = current_platform.get_device_capability()
         self._pad_v = self.vllm_flash_attn_version is None or not (
-            self.vllm_flash_attn_version == 3
-            and device_capability is not None
-            and device_capability[0] == 9
+            (
+                self.vllm_flash_attn_version == 3
+                and device_capability is not None
+                and device_capability[0] == 9
+            )
+            or self.vllm_flash_attn_version == 4
         )
 
         self.dcp_world_size: int = -1
diff --git a/vllm/tokenizers/kimi_audio.py b/vllm/tokenizers/kimi_audio.py
index ef3f9efb8326..d2b0a2a557ef 100644
--- a/vllm/tokenizers/kimi_audio.py
+++ b/vllm/tokenizers/kimi_audio.py
@@ -4,6 +4,7 @@
 
 import contextlib
 import json
+from collections.abc import Sequence
 from pathlib import Path
 from typing import Any, overload
 
@@ -299,7 +300,9 @@ def encode(
         tokens = self._maybe_truncate(tokens, max_length)
         return tokens
 
-    def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:
+    def decode(
+        self, ids: Sequence[int] | int, skip_special_tokens: bool = False
+    ) -> str:
         """Decode token IDs to text, optionally skipping special tokens."""
         if isinstance(ids, int):
             ids = [ids]
@@ -321,7 +324,7 @@ def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]:
         return [self._token_to_id.get(token, self._unk_token_id) for token in tokens]
 
     def convert_ids_to_tokens(
-        self, ids: list[int], skip_special_tokens: bool = False
+        self, ids: Sequence[int], skip_special_tokens: bool = False
     ) -> list[str]:
         tokens = []
         for token_id in ids:
diff --git a/vllm/tool_parsers/abstract_tool_parser.py b/vllm/tool_parsers/abstract_tool_parser.py
index 75cffd3297f6..81ee4ea671e6 100644
--- a/vllm/tool_parsers/abstract_tool_parser.py
+++ b/vllm/tool_parsers/abstract_tool_parser.py
@@ -68,7 +68,7 @@ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionReques
         # tool_choice: "Forced Function" or "required" will override
         # structured output json settings to make tool calling work correctly
         request.structured_outputs = StructuredOutputsParams(
-            json=json_schema_from_tool
+            json=json_schema_from_tool  # type: ignore[call-arg]
         )
         request.response_format = None
         if isinstance(request, ResponsesRequest):
diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py
index 20502cbf0feb..cd8c46d032c0 100644
--- a/vllm/v1/attention/backends/fa_utils.py
+++ b/vllm/v1/attention/backends/fa_utils.py
@@ -125,11 +125,14 @@ def get_flash_attn_version(
     # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict
     # supported head dimensions.
     # See: https://github.com/Dao-AILab/flash-attention/issues/1959
+    # Exception: hdim 192 is supported for MLA's diff-headdim case
+    # (qk=192, v=128), added upstream in commits 1a15733e/1b36ab19.
     if (
         fa_version == 4
         and device_capability.major >= 10
         and head_size is not None
         and head_size > 128
+        and head_size != 192
     ):
         logger.warning_once(
             "FA4 on Blackwell does not support head_size=%d due to TMEM "
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index 2da2bbd6bb5a..ca9f7452e311 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -32,6 +32,8 @@ class TritonMLABackend(MLACommonBackend):
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "bfloat16",
+        "fp8",
+        "fp8_e4m3",
     ]
 
     @classmethod
@@ -108,10 +110,11 @@ def __init__(
                 "TritonMLAImpl"
             )
 
+        # For FP8 KV cache, we dequantize to BF16 on load inside the
+        # Triton kernel. Tell the common layer not to quantize queries
+        # to FP8; we handle FP8 KV cache with BF16 queries (Mode 1).
         if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "TritonMLA V1 with FP8 KV cache not yet supported"
-            )
+            self.supports_quant_query_input = False
 
     def _flash_attn_varlen_diff_headdims(
         self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
@@ -135,9 +138,6 @@ def forward_mqa(
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
 
-        if self.kv_cache_dtype.startswith("fp8"):
-            raise NotImplementedError("FP8 Triton MLA not yet supported")
-
         if type(q) is tuple:
             q = torch.cat(q, dim=-1)
@@ -171,7 +171,8 @@ def forward_mqa(
         kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
         PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
 
-        # Run MQA
+        # Run MQA; always pass layer scales. When KV cache is
+        # BF16 the kernel's `if dtype.is_fp8()` check is a no-op.
         decode_attention_fwd(
             q,
             kv_c_and_k_pe_cache,
@@ -184,6 +185,8 @@ def forward_mqa(
             num_kv_splits,
             self.scale,
             PAGE_SIZE,
+            k_scale=layer._k_scale,
+            v_scale=layer._v_scale,
         )
 
         return o, lse
diff --git a/vllm/v1/attention/ops/triton_decode_attention.py b/vllm/v1/attention/ops/triton_decode_attention.py
index 1ed9698c507a..63263bc92e24 100644
--- a/vllm/v1/attention/ops/triton_decode_attention.py
+++ b/vllm/v1/attention/ops/triton_decode_attention.py
@@ -31,6 +31,7 @@
 
 import logging
 
+import torch
 from packaging import version
 
 from vllm.platforms import current_platform
@@ -74,6 +75,8 @@ def _fwd_kernel_stage1(
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
+    k_scale,
+    v_scale,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
     BLOCK_DV: tl.constexpr,
@@ -109,6 +112,8 @@ def _fwd_kernel_stage1(
     acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        ks = tl.load(k_scale)
+        vs = tl.load(v_scale)
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_page_number = tl.load(
@@ -129,6 +134,8 @@ def _fwd_kernel_stage1(
                 mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
                 other=0.0,
             )
+            if k.dtype.is_fp8():
+                k = (k.to(tl.float32) * ks).to(q.dtype)
             qk = tl.sum(q[None, :] * k, 1)
             qk *= sm_scale
@@ -147,6 +154,8 @@ def _fwd_kernel_stage1(
                 mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                 other=0.0,
             )
+            if v.dtype.is_fp8():
+                v = (v.to(tl.float32) * vs).to(q.dtype)
 
             n_e_max = tl.maximum(tl.max(qk, 0), e_max)
             re_scale = tl.exp(e_max - n_e_max)
@@ -194,6 +203,8 @@ def _decode_att_m_fwd(
     sm_scale,
     page_size,
     logit_cap,
+    k_scale,
+    v_scale,
 ):
     BLOCK = 64 if not is_hip_ else 8
@@ -231,6 +242,8 @@ def _decode_att_m_fwd(
         att_out.stride(0),
         att_out.stride(1),
         att_out.stride(2),
+        k_scale,
+        v_scale,
         kv_group_num=kv_group_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DV=BLOCK_DV,
@@ -264,6 +277,8 @@ def _fwd_grouped_kernel_stage1(
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
+    k_scale,
+    v_scale,
     kv_group_num: tl.constexpr,
     q_head_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
@@ -316,6 +331,8 @@ def _fwd_grouped_kernel_stage1(
     acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
 
     if split_kv_end > split_kv_start:
+        ks = tl.load(k_scale)
+        vs = tl.load(v_scale)
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_page_number = tl.load(
@@ -336,6 +353,8 @@ def _fwd_grouped_kernel_stage1(
                 mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                 other=0.0,
             )
+            if k.dtype.is_fp8():
+                k = (k.to(tl.float32) * ks).to(q.dtype)
             qk = tl.dot(q, k.to(q.dtype))
 
             if BLOCK_DPE > 0:
                 offs_buf_kpe = (
@@ -348,6 +367,8 @@ def _fwd_grouped_kernel_stage1(
                     mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
                     other=0.0,
                 )
+                if kpe.dtype.is_fp8():
+                    kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
                 qk += tl.dot(qpe, kpe.to(qpe.dtype))
 
             qk *= sm_scale
@@ -368,6 +389,8 @@ def _fwd_grouped_kernel_stage1(
                 mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                 other=0.0,
             )
+            if v.dtype.is_fp8():
+                v = (v.to(tl.float32) * vs).to(q.dtype)
 
             n_e_max = tl.maximum(tl.max(qk, 1), e_max)
             re_scale = tl.exp(e_max - n_e_max)
@@ -416,6 +439,8 @@ def _decode_grouped_att_m_fwd(
     sm_scale,
     page_size,
     logit_cap,
+    k_scale,
+    v_scale,
 ):
     BLOCK = 32
     Lk = k_buffer.shape[-1]
@@ -473,6 +498,8 @@ def _decode_grouped_att_m_fwd(
         att_out.stride(0),
         att_out.stride(1),
         att_out.stride(2),
+        k_scale,
+        v_scale,
         kv_group_num=kv_group_num,
         q_head_num=head_num,
         BLOCK_DMODEL=BLOCK_DMODEL,
@@ -609,6 +636,8 @@ def decode_attention_fwd_normal(
     sm_scale,
     page_size,
     logit_cap=0.0,
+    k_scale=None,
+    v_scale=None,
 ):
     _decode_att_m_fwd(
         q,
@@ -621,6 +650,8 @@ def decode_attention_fwd_normal(
         sm_scale,
         page_size,
         logit_cap,
+        k_scale,
+        v_scale,
     )
     _decode_softmax_reducev_fwd(
         attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -640,6 +671,8 @@ def decode_attention_fwd_grouped(
     sm_scale,
     page_size,
     logit_cap=0.0,
+    k_scale=None,
+    v_scale=None,
 ):
     _decode_grouped_att_m_fwd(
         q,
@@ -652,6 +685,8 @@ def decode_attention_fwd_grouped(
         sm_scale,
         page_size,
         logit_cap,
+        k_scale,
+        v_scale,
     )
     _decode_softmax_reducev_fwd(
         attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -671,8 +706,16 @@ def decode_attention_fwd(
     sm_scale,
     page_size=1,
     logit_cap=0.0,
+    k_scale=None,
+    v_scale=None,
 ):
     assert num_kv_splits == attn_logits.shape[2]
+
+    if k_scale is None:
+        k_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
+    if v_scale is None:
+        v_scale = torch.tensor(1.0, dtype=torch.float32, device=q.device)
+
     kv_group_num = q.shape[1] // v_buffer.shape[-2]
 
     if kv_group_num == 1:
@@ -690,6 +733,8 @@ def decode_attention_fwd(
             sm_scale,
             page_size,
             logit_cap,
+            k_scale,
+            v_scale,
         )
     else:
         # GQA/MQA/MLA
@@ -706,4 +751,6 @@ def decode_attention_fwd(
             sm_scale,
             page_size,
             logit_cap,
+            k_scale,
+            v_scale,
         )
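
Illustrative sketch (reviewer addition, not part of the patch): the Triton changes above keep queries in BF16 and dequantize FP8 K/V on load using per-layer scales. The snippet below mirrors the kernel's cast chain in plain PyTorch; the helper name and tensor shapes are hypothetical, while `layer._k_scale` / `layer._v_scale` are the scalar fp32 tensors the call site in triton_mla.py forwards.

import torch

def dequant_on_load(kv_fp8: torch.Tensor, scale: torch.Tensor,
                    q_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Same math as the kernel's `k = (k.to(tl.float32) * ks).to(q.dtype)`:
    # upcast to fp32, apply the per-layer scale, cast to the query dtype.
    return (kv_fp8.to(torch.float32) * scale).to(q_dtype)

# Hypothetical FP8 KV block plus a per-layer scale (scalar fp32 tensor).
kv_block = torch.randn(64, 512).to(torch.float8_e4m3fn)
k_scale = torch.tensor(0.02, dtype=torch.float32)
k_bf16 = dequant_on_load(kv_block, k_scale)

# When callers omit the scales, decode_attention_fwd substitutes 1.0
# tensors, and for BF16 caches the `is_fp8()` branch never fires, so
# the pre-existing BF16 path is numerically unchanged.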