diff --git a/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py index 1975eb61b25d..e4bf13ba8597 100644 --- a/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py +++ b/tests/model_executor/model_loader/fastsafetensors_loader/test_weight_utils.py @@ -8,6 +8,9 @@ import pytest import torch +import vllm.model_executor.model_loader.weight_utils as weight_utils +from vllm.config.load import LoadConfig +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_weights_from_hf, fastsafetensors_weights_iterator, @@ -16,6 +19,77 @@ from vllm.platforms import current_platform +def test_default_loader_filters_fastsafetensors_before_materializing(monkeypatch): + class FakeProcessGroup: + def size(self): + return 1 + + class FakeFileBuffer: + def __init__(self): + self.key_to_rank_lidx = { + "model.layers.0.self_attn.q_proj.weight": (0, 0), + "model.layers.0.mlp.experts.0.gate_proj.weight": (0, 1), + "model.layers.0.mlp.experts.1.gate_proj.weight": (0, 2), + "model.mtp.0.weight": (0, 3), + } + self.loaded_keys: list[str] = [] + self.closed = False + + def get_tensor(self, key: str): + self.loaded_keys.append(key) + return torch.tensor([len(self.loaded_keys)]) + + def close(self): + self.closed = True + + class FakeLoader: + def __init__(self, file_buffer): + self.file_buffer = file_buffer + self.closed = False + + def copy_files_to_device(self): + return self.file_buffer + + def close(self): + self.closed = True + + file_buffer = FakeFileBuffer() + loader = FakeLoader(file_buffer) + + model_loader = DefaultModelLoader(LoadConfig(load_format="fastsafetensors")) + model_loader.local_expert_ids = {0} + monkeypatch.setattr( + model_loader, + "_prepare_weights", + lambda *_args: ("/weights", ["model.safetensors"], True), + ) + monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False) + monkeypatch.setattr(weight_utils, "SingleGroup", FakeProcessGroup) + monkeypatch.setattr( + weight_utils, + "_init_fastsafetensors_loader", + lambda *_args, **_kwargs: loader, + ) + + loaded = dict( + model_loader._get_weights_iterator( + DefaultModelLoader.Source("model", revision=None), + weight_name_filter=lambda name: "model.mtp." 
in name, + ) + ) + + assert set(loaded) == { + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.mlp.experts.0.gate_proj.weight", + } + assert file_buffer.loaded_keys == [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.mlp.experts.0.gate_proj.weight", + ] + assert file_buffer.closed + assert loader.closed + + @pytest.mark.skipif( not current_platform.is_cuda_alike(), reason="fastsafetensors requires NVIDIA/AMD GPUs", diff --git a/tests/model_executor/model_loader/test_ep_weight_filter.py b/tests/model_executor/model_loader/test_ep_weight_filter.py index 2ac38192a4b0..d032cd569019 100644 --- a/tests/model_executor/model_loader/test_ep_weight_filter.py +++ b/tests/model_executor/model_loader/test_ep_weight_filter.py @@ -319,6 +319,20 @@ def test_ep2_rank0_gets_half_experts(self, synthetic_moe_files): assert "model.layers.0.input_layernorm.weight" in loaded assert "model.layers.0.mlp.shared_experts.gate_proj.weight" in loaded + def test_weight_name_filter_skips_dense_weights(self, synthetic_moe_files): + files, _ = synthetic_moe_files + loaded = dict( + safetensors_weights_iterator( + files, + False, + weight_name_filter=lambda name: "self_attn.q_proj" in name, + ) + ) + + assert "model.layers.0.self_attn.q_proj.weight" not in loaded + assert "model.embed_tokens.weight" in loaded + assert "model.layers.0.mlp.shared_experts.gate_proj.weight" in loaded + def test_ep2_rank1_gets_other_half(self, synthetic_moe_files): files, expected = synthetic_moe_files local_ids = compute_local_expert_ids(8, ep_size=2, ep_rank=1) diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py new file mode 100644 index 000000000000..8cdaea82862e --- /dev/null +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -0,0 +1,288 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm.utils.deep_gemm as deep_gemm_utils +from vllm.model_executor.layers.sparse_attn_indexer import ( + _decode_logits_width, + _decode_topk_logits_width, + _sparse_indexer_requires_deep_gemm, +) +from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.mla import indexer as mla_indexer +from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + +def _make_indexer_kv_cache( + kv_fp8: torch.Tensor, + kv_scale: torch.Tensor, +) -> torch.Tensor: + num_blocks, block_size, num_kv_heads, head_dim = kv_fp8.shape + assert num_kv_heads == 1 + fused_kv = torch.empty( + num_blocks, + block_size, + 1, + head_dim + torch.float32.itemsize, + device=kv_fp8.device, + dtype=torch.uint8, + ) + block_stride = fused_kv.stride(0) + kv_values = torch.as_strided( + fused_kv, + size=kv_fp8.shape, + stride=(block_stride, head_dim, head_dim, 1), + ) + kv_scales = torch.as_strided( + fused_kv, + size=(num_blocks, block_size, 1, torch.float32.itemsize), + stride=(block_stride, torch.float32.itemsize, torch.float32.itemsize, 1), + storage_offset=block_size * head_dim, + ) + kv_values.copy_(kv_fp8.view(torch.uint8)) + kv_scales.copy_(kv_scale.contiguous().view(torch.uint8)) + return fused_kv + + +def _reference_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_fp8: torch.Tensor, + kv_scale: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + batch_size, next_n, _, _ = q_fp8.shape + _, block_size, _, _ = kv_fp8.shape + 
logits = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device=q_fp8.device, + dtype=torch.float32, + ) + q = q_fp8.float() + kv = kv_fp8.float() * kv_scale.float() + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + context_len = min( + int(context_lens[batch_idx, next_idx].item()), + max_model_len, + ) + for token_idx in range(context_len): + block_idx = block_tables[batch_idx, token_idx // block_size] + block_offset = token_idx % block_size + k = kv[block_idx, block_offset, 0] + scores = (q[batch_idx, next_idx] * k).sum(dim=-1).relu() + logits[row, token_idx] = (scores * weights[row]).sum() + return logits + + +def test_decode_logits_width_uses_active_context_bound(): + assert _decode_logits_width(262144, 1024) == 1024 + assert _decode_logits_width(4096, 8192) == 4096 + assert _decode_logits_width(4096, 0) == 4096 + assert _decode_logits_width(0, 1024) == 0 + + +def test_decode_topk_logits_width_keeps_topk_kernel_width(): + assert _decode_topk_logits_width(262144, 1024, 512) == 1024 + assert _decode_topk_logits_width(262144, 128, 512) == 512 + assert _decode_topk_logits_width(300, 128, 512) == 300 + assert _decode_topk_logits_width(0, 128, 512) == 0 + + +def test_sm120_sparse_indexer_does_not_require_deep_gemm(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + + assert _sparse_indexer_requires_deep_gemm() is False + + +def test_non_sm120_cuda_sparse_indexer_still_requires_deep_gemm(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: False, + ) + + assert _sparse_indexer_requires_deep_gemm() is True + + +def test_sm120_mla_indexer_skips_deep_gemm_scheduler_metadata(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + monkeypatch.setattr(mla_indexer, "has_deep_gemm", lambda: True) + + assert not mla_indexer._uses_deep_gemm_scheduler_metadata() + + +def test_cuda_mla_indexer_uses_deep_gemm_scheduler_metadata_off_sm12x(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: False, + ) + monkeypatch.setattr(mla_indexer, "has_deep_gemm", lambda: True) + + assert mla_indexer._uses_deep_gemm_scheduler_metadata() + + +def test_sm120_fp8_mqa_fallbacks_do_not_initialize_deep_gemm(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + + def fail_lazy_init(): + raise AssertionError("SM120 FP8 MQA should not initialize DeepGEMM") + + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", fail_lazy_init) + + mqa_result = torch.empty(1) + paged_result = torch.empty(1) + calls = [] + + def fake_mqa_fallback(*args, **kwargs): + calls.append("mqa") + return mqa_result + + def fake_paged_fallback(*args, **kwargs): + calls.append("paged") + return paged_result + + monkeypatch.setattr(deep_gemm_utils, "_fp8_mqa_logits_sm12x", fake_mqa_fallback) + monkeypatch.setattr( + deep_gemm_utils, "_fp8_paged_mqa_logits_sm12x", fake_paged_fallback + ) + + assert ( + deep_gemm_utils.fp8_fp4_mqa_logits( + (torch.empty(1, 1, 1), 
None), + (torch.empty(1, 1), torch.empty(1)), + torch.empty(1, 1), + torch.empty(1, dtype=torch.int32), + torch.empty(1, dtype=torch.int32), + clean_logits=False, + ) + is mqa_result + ) + assert ( + deep_gemm_utils.fp8_fp4_paged_mqa_logits( + (torch.empty(1, 1, 1, 1), None), + torch.empty(1, 1, 1, 5, dtype=torch.uint8), + torch.empty(1, 1), + torch.empty(1, 1, dtype=torch.int32), + torch.empty(1, 1, dtype=torch.int32), + torch.empty(1, dtype=torch.int32), + max_model_len=1, + clean_logits=False, + ) + is paged_result + ) + assert calls == ["mqa", "paged"] + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_paged_mqa_direct_topk_matches_truncated_decode_width( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(7) + batch_size, next_n, num_heads, head_dim = 2, 2, 8, 32 + block_size, max_model_len, num_blocks = 4, 64, 16 + active_max_len = 13 + topk_tokens = 6 + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr( + sm12x_deep_gemm_fallbacks, + "_SM120_PAGED_MQA_TOPK_CHUNK_SIZE", + 7, + ) + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn( + num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = _make_indexer_kv_cache(kv_fp8, kv_scale) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor( + [[7, active_max_len], [9, 12]], device="cuda", dtype=torch.int32 + ) + block_tables = ( + torch.arange( + batch_size * cdiv(max_model_len, block_size), + device="cuda", + dtype=torch.int32, + ).reshape(batch_size, -1) + % num_blocks + ) + + full_width_topk = torch.empty( + batch_size * next_n, topk_tokens, device="cuda", dtype=torch.int32 + ) + truncated_width_topk = torch.empty_like(full_width_topk) + + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8, None), + fused_kv, + weights, + context_lens, + block_tables, + max_model_len, + full_width_topk, + ) + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8, None), + fused_kv, + weights, + context_lens, + block_tables, + active_max_len, + truncated_width_topk, + ) + + reference_logits = _reference_paged_mqa_logits( + q_fp8, + kv_fp8, + kv_scale, + weights, + context_lens, + block_tables, + active_max_len, + ) + expected_topk = torch.topk(reference_logits, topk_tokens, dim=1).indices.to( + torch.int32 + ) + + torch.testing.assert_close(truncated_width_topk, full_width_topk, rtol=0, atol=0) + torch.testing.assert_close(truncated_width_topk, expected_topk, rtol=0, atol=0) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index c35c38911a1a..6dc102373106 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -36,6 +36,8 @@ KVCacheConfig, KVCacheGroupSpec, MambaSpec, + MLAAttentionSpec, + SlidingWindowMLASpec, SlidingWindowSpec, ) @@ -2573,6 +2575,207 @@ def test_can_fit_full_sequence_swa_cap_admits_long_prompt(): ) +def test_deepseek_v4_mla_prompt_cache_survives_decode_pressure(): + hash_block_size = 2 + full_block_size = 8 + swa_block_size = 2 + prompt_tokens = 35 + chunk_tokens = 4 * full_block_size + expected_hit_tokens = (prompt_tokens - 1) // full_block_size * 
full_block_size + + config = KVCacheConfig( + num_blocks=70, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=full_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_0"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + sliding_window=2 * swa_block_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_1"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + sliding_window=2 * swa_block_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_compressor_state"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=2 * swa_block_size, + ), + ), + ], + ) + manager = KVCacheManager( + config, + max_model_len=128, + max_num_batched_tokens=chunk_tokens, + enable_caching=True, + hash_block_size=hash_block_size, + ) + + def run_request(request: Request, num_decode_tokens: int) -> int: + computed_blocks, num_computed_tokens = manager.get_computed_blocks(request) + computed_so_far = num_computed_tokens + remaining_prompt_tokens = request.num_prompt_tokens - num_computed_tokens + first_chunk = True + while remaining_prompt_tokens > 0: + num_new_tokens = min(chunk_tokens, remaining_prompt_tokens) + allocated = manager.allocate_slots( + request, + num_new_tokens, + num_computed_tokens if first_chunk else 0, + computed_blocks if first_chunk else None, + ) + assert allocated is not None + computed_so_far += num_new_tokens + request.num_computed_tokens = computed_so_far + remaining_prompt_tokens -= num_new_tokens + first_chunk = False + + for i in range(num_decode_tokens): + request.append_output_token_ids(10_000 + i) + allocated = manager.allocate_slots(request, 1) + assert allocated is not None + computed_so_far += 1 + request.num_computed_tokens = computed_so_far + return num_computed_tokens + + prompt_a = list(range(prompt_tokens)) + req_a = make_request("a", prompt_a, hash_block_size, sha256) + assert run_request(req_a, num_decode_tokens=0) == 0 + manager.free(req_a) + + warm_a = make_request("warm_a", prompt_a, hash_block_size, sha256) + assert run_request(warm_a, num_decode_tokens=8) == expected_hit_tokens + assert manager.get_num_common_prefix_blocks("warm_a")[0] >= ( + expected_hit_tokens // full_block_size + ) + manager.free(warm_a) + + pressure_blocks = manager.block_pool.get_new_blocks( + manager.block_pool.get_num_free_blocks() + ) + manager.block_pool.free_blocks(reversed(pressure_blocks)) + + req_a_again = make_request("a_again", prompt_a, hash_block_size, sha256) + _, num_computed_tokens = manager.get_computed_blocks(req_a_again) + assert num_computed_tokens == expected_hit_tokens + + +def test_deepseek_v4_mla_cached_prompts_do_not_block_admission(): + block_size = 8 + prompt_tokens = 4 * block_size + 3 + num_prompts = 10 + num_blocks = 80 + manager = KVCacheManager( + KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + 
max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + + for i in range(num_prompts): + prompt = list(range(i * 1000, i * 1000 + prompt_tokens)) + req = make_request(f"protected_{i}", prompt, block_size, sha256) + assert manager.allocate_slots(req, prompt_tokens) is not None + req.num_computed_tokens = prompt_tokens + manager.free(req) + + assert manager.block_pool.get_num_free_blocks() < 64 + + long_req = make_request( + "long", + list(range(100_000, 100_000 + 64 * block_size)), + block_size, + sha256, + ) + assert ( + manager.allocate_slots(long_req, block_size, full_sequence_must_fit=True) + is not None + ) + + +def test_reset_prefix_cache_after_deepseek_v4_mla_prompt_cache(): + block_size = 8 + prompt_tokens = 4 * block_size + 3 + manager = KVCacheManager( + KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + + req = make_request("protected", list(range(prompt_tokens)), block_size, sha256) + assert manager.allocate_slots(req, prompt_tokens) is not None + req.num_computed_tokens = prompt_tokens + manager.free(req) + + assert manager.reset_prefix_cache() + + def test_can_fit_full_sequence_full_attention_still_gates_oversized(): """The cap only loosens the SWA group; a prompt that exceeds the full-attention pool capacity must still be rejected.""" diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 0f02a92681c1..91e5a8913e88 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -750,6 +750,7 @@ class CompilationConfig: "vllm::sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer", "vllm::deepseek_v4_attention", + "vllm::deepseek_v4_fp8_einsum", ] def compute_hash(self) -> str: diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index 618084029159..6baedd3bbcbc 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -7,6 +7,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) @@ -26,6 +29,20 @@ ) +def _is_sm12x_compute_capability(compute_capability) -> bool: + if compute_capability is None: + return current_platform.is_device_capability_family(120) + + if isinstance(compute_capability, tuple): + return compute_capability[0] == 12 + + to_int = getattr(compute_capability, "to_int", None) + if callable(to_int): + return to_int() // 10 == 12 + + return int(compute_capability) // 10 == 12 + + class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def is_supported( @@ -196,6 +213,9 @@ def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None: @classmethod def is_supported(cls, compute_capability=None): + if _is_sm12x_compute_capability(compute_capability): + return False, "CUTLASS block-scaled FP8 GEMM is not supported on SM12x." 
+ if not CUTLASS_BLOCK_FP8_SUPPORTED: return ( False, @@ -219,6 +239,31 @@ def can_implement(cls, config: FP8ScaledMMLinearLayerConfig): ) return True, None + def process_weights_after_loading(self, layer: torch.nn.Module): + super().process_weights_after_loading(layer) + params = self._get_layer_params(layer) + weight_scale = ( + params.weight_scale + if params.weight_scale_inv is None + else params.weight_scale_inv + ) + scale_attr_name = ( + params.WEIGHT_SCALE + if params.weight_scale_inv is None + else params.WEIGHT_SCALE_INV + ) + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and weight_scale is not None + and weight_scale.dtype == e8m0_dtype + ): + replace_parameter( + layer, + scale_attr_name, + _upcast_e8m0_to_fp32(weight_scale), + ) + def apply_block_scaled_mm( self, A: torch.Tensor, diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 494d61338084..fab0fb772e99 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -18,7 +18,7 @@ ReplicatedLinear, ) from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer -from vllm.utils.deep_gemm import fp8_einsum +from vllm.platforms import current_platform from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.deepseek_v4_ops import ( combine_topk_swa_indices, @@ -27,6 +27,10 @@ fused_indexer_q_rope_quant, fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, + sparse_prefill_combined_topk_size, +) +from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import ( + deepseek_v4_fp8_einsum_config, ) from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( rocm_forward_decode_fallback, @@ -58,7 +62,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) -from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( execute_in_parallel, maybe_execute_in_parallel, @@ -73,6 +76,21 @@ DeepseekV4IndexerBackend, get_max_prefill_buffer_size, ) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + disable_triton_sparse_mla_cudagraphs_if_enabled, + is_triton_sparse_mla_enabled, + is_triton_sparse_mla_enabled_for_platform, + triton_sparse_mla_query_chunk_size, + triton_sparse_mla_topk_chunk_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead, + accumulate_indexed_sparse_mla_attention_chunk, + finish_sparse_mla_attention_with_sink, + finish_two_sparse_mla_attention_states_with_sink, + sparse_mla_decode_head_block_size, +) from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache from vllm.v1.attention.ops.flashmla import ( flash_mla_sparse_fwd, @@ -83,10 +101,64 @@ logger = init_logger(__name__) + +def _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu: torch.Tensor, + gather_lens_cpu: torch.Tensor, + compress_ratio: int, + swa_only: bool, +) -> tuple[int, int]: + if seq_lens_cpu.numel() == 0: + return 0, 0 + + max_gather_len = int(gather_lens_cpu.max().item()) + if swa_only: + return 0, max_gather_len + + compressed_region_size = int((seq_lens_cpu // compress_ratio).max().item()) + return compressed_region_size, compressed_region_size + max_gather_len + + +def _sparse_mla_prefill_gather_len_upper_bound( + *, + max_model_len: int, + max_num_batched_tokens: int, + window_size: int, +) -> 
tuple[int, int]: + max_query_chunk_tokens = max(1, min(max_model_len, max_num_batched_tokens)) + max_prefix_len = max(max_model_len - max_query_chunk_tokens, 0) + max_gather_len = max_query_chunk_tokens + min( + max_prefix_len, + max(window_size - 1, 0), + ) + return max_query_chunk_tokens, max_gather_len + + +def _allocate_deepseek_v4_wo_a_output( + num_tokens: int, + num_groups: int, + output_rank: int, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + shape = (num_tokens, num_groups, output_rank) + if torch.compiler.is_compiling(): + # Workspace growth can call torch.accelerator.empty_cache(), which + # Dynamo intentionally refuses to trace. During compilation this is a + # normal graph allocation, matching the o_padded allocation above. + return torch.empty(shape, dtype=dtype, device=device) + + (output,) = current_workspace_manager().get_simultaneous( + (shape, dtype), + ) + return output + + # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather # workspace allocated at _forward_prefill (and the matching profile-time # reservation in attention_impl's dummy-run branch). PREFILL_CHUNK_SIZE = 4 +_DEFAULT_SPARSE_MLA_TOPK_TOKENS = 2048 @dataclass @@ -172,6 +244,8 @@ def __init__( self.compress_ratio = compress_ratio if compress_ratio is not None else 1 self.prefix = prefix + disable_triton_sparse_mla_cudagraphs_if_enabled(mla_modules.vllm_config) + # Extract config from vllm_config config = mla_modules.vllm_config.model_config.hf_config tp_size = get_tensor_model_parallel_world_size() @@ -202,12 +276,13 @@ def __init__( self.wo_b = mla_modules.wo_b # Pick fp8_einsum recipe based on GPU arch: - # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 - # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1 + # SM90/SM120: FP32 block scales stay [g, r/128, d/128]. + # SM100: INT32 packed scales become [g, r, ...]. cap = current_platform.get_device_capability() assert cap is not None, "DeepseekV4 attention requires a CUDA device" - self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) - self._tma_aligned_scales = cap.major >= 10 + self._einsum_recipe, self._tma_aligned_scales = deepseek_v4_fp8_einsum_config( + cap.major + ) self.rotary_emb = mla_modules.rotary_emb self.indexer_rotary_emb = mla_modules.indexer_rotary_emb @@ -336,10 +411,12 @@ def forward( wo_a_fp8 = self.wo_a.weight wo_a_scale = self.wo_a.weight_scale_inv - z = torch.empty( - (num_tokens, self.n_local_groups, self.o_lora_rank), - device=o.device, - dtype=torch.bfloat16, + z = _allocate_deepseek_v4_wo_a_output( + num_tokens, + self.n_local_groups, + self.o_lora_rank, + torch.bfloat16, + hidden_states.device, ) torch.ops.vllm.deepseek_v4_fp8_einsum( o_fp8, @@ -494,21 +571,8 @@ def wq_b_kv_insert() -> torch.Tensor: # Handle dummy run (no metadata). if not isinstance(attn_metadata, dict): - # Reserve _forward_prefill's bf16-gather workspace; the dummy - # run returns before mla_attn runs, so without this the shared - # workspace locks below the real prefill size. 
- sub = self.mla_attn - swa_only = sub.compress_ratio <= 1 - N = ( - 0 - if swa_only - else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio - ) - M = N + sub.window_size + sub.max_num_batched_tokens - current_workspace_manager().get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - ) out.zero_() + self.mla_attn._reserve_prefill_workspace() return # Pad q to FlashMLA-required head count (64 or 128) @@ -585,38 +649,6 @@ def deepseek_v4_attention_fake( ) -def deepseek_v4_fp8_einsum( - a: torch.Tensor, - a_scale: torch.Tensor, - b: torch.Tensor, - b_scale: torch.Tensor, - out: torch.Tensor, - equation: str, - recipe: list[int], -) -> None: - fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) - - -def deepseek_v4_fp8_einsum_fake( - a: torch.Tensor, - a_scale: torch.Tensor, - b: torch.Tensor, - b_scale: torch.Tensor, - out: torch.Tensor, - equation: str, - recipe: list[int], -) -> None: - return None - - -direct_register_custom_op( - op_name="deepseek_v4_fp8_einsum", - op_func=deepseek_v4_fp8_einsum, - mutates_args=["out"], - fake_impl=deepseek_v4_fp8_einsum_fake, -) - - class DeepseekV4MLAAttention(nn.Module, AttentionLayerBase): # FlashMLA FP8 sparse only supports 64 or 128 heads SUPPORTED_HEAD_COUNTS = (64, 128) @@ -711,7 +743,11 @@ def __init__( assert cache_config is not None cache_config.cache_dtype = "fp8_ds_mla" kv_cache_dtype = "fp8_ds_mla" - logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") + logger.info_once( + "Using DeepSeek's fp8_ds_mla KV cache format. To use standard " + "fp8 kv-cache format, please set `--attention-backend " + "FLASHINFER_MLA_SPARSE`" + ) self.kv_cache_dtype = kv_cache_dtype @@ -724,6 +760,73 @@ def __init__( self.kv_cache = torch.tensor([]) + def _prefill_workspace_topk_bound(self) -> int: + if self.compress_ratio <= 1: + return 0 + if ( + self.topk_indices_buffer is not None + and self.topk_indices_buffer.ndim > 0 + and self.topk_indices_buffer.shape[-1] > 0 + ): + return int(self.topk_indices_buffer.shape[-1]) + indexer_topk = getattr(self.indexer, "topk_tokens", None) + if indexer_topk is not None: + return int(indexer_topk) + return _DEFAULT_SPARSE_MLA_TOPK_TOKENS + + def _prefill_workspace_reservation_specs( + self, + ) -> tuple[tuple[tuple[int, ...], torch.dtype], ...]: + max_model_len = max(1, int(self.max_model_len)) + max_num_batched_tokens = max(1, int(self.max_num_batched_tokens)) + window_size = max(1, int(self.window_size)) + compress_ratio = max(1, int(self.compress_ratio)) + head_dim = int(self.head_dim) + num_heads = int(self.num_heads) + + max_query_chunk_tokens, max_gather_len = ( + _sparse_mla_prefill_gather_len_upper_bound( + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + window_size=window_size, + ) + ) + if compress_ratio <= 1: + m_bound = max_gather_len + else: + compressed_region_size = max_model_len // compress_ratio + m_bound = compressed_region_size + max_gather_len + + combined_topk = sparse_prefill_combined_topk_size( + DeepseekV4MLAAttention._prefill_workspace_topk_bound(self), + window_size, + ) + specs: list[tuple[tuple[int, ...], torch.dtype]] = [ + ((PREFILL_CHUNK_SIZE, m_bound, head_dim), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ] + if is_triton_sparse_mla_enabled_for_platform(): + query_chunk_size = min( + max_query_chunk_tokens, + triton_sparse_mla_query_chunk_size(), + ) + specs.extend( + [ + ((query_chunk_size, num_heads), 
torch.float32), + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads, head_dim), torch.float32), + ] + ) + return tuple(specs) + + def _reserve_prefill_workspace(self) -> None: + try: + workspace_manager = current_workspace_manager() + except AssertionError: + return + workspace_manager.get_simultaneous(*self._prefill_workspace_reservation_specs()) + def get_attn_backend(self) -> type[AttentionBackend]: return DeepseekV4FlashMLASparseBackend @@ -743,6 +846,254 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: model_version="deepseek_v4", ) + def _forward_sparse_mla_swa_decode_triton( + self, + q: torch.Tensor, + swa_k_cache: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + output: torch.Tensor, + ) -> None: + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + + ( + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ) + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + if mtp_decode: + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + else: + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + block_size=swa_metadata.block_size, + candidate_offset=0, + num_candidates=swa_metadata.decode_swa_indices.shape[-1], + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_sparse_mla_attention_with_sink( + swa_max_score, + swa_denom, + swa_acc, + self.attn_sink, + output=output, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + + def _forward_sparse_mla_compressed_decode_triton( + self, + q: torch.Tensor, + compressed_k_cache: torch.Tensor, + swa_k_cache: torch.Tensor, + topk_indices: torch.Tensor, + topk_lens: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata, + output: torch.Tensor, + ) -> None: + if self.compress_ratio not in (4, 128): + raise NotImplementedError( + "Triton sparse MLA compressed decode currently supports " + f"compress_ratio=4 or 128, got {self.compress_ratio}" + ) + + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + compressed_block_size = attn_metadata.block_size // self.compress_ratio + compressed_topk = topk_indices.shape[-1] + topk_chunk_size = min( + compressed_topk, + triton_sparse_mla_topk_chunk_size(), + ) + compressed_slot_ids = topk_indices[:, 0, :] + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = 
swa_metadata.decode_swa_indices[:num_decode_tokens] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + ( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ) + comp_max_score.fill_(float("-inf")) + comp_denom.zero_() + comp_acc.zero_() + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + + for chunk_start in range(0, compressed_topk, topk_chunk_size): + chunk_end = min(chunk_start + topk_chunk_size, compressed_topk) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids[:, chunk_start:chunk_end], + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=chunk_start, + scale=self.scale, + max_score=comp_max_score, + denom=comp_denom, + acc=comp_acc, + head_block_size=head_block_size, + ) + if mtp_decode: + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + else: + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + block_size=swa_metadata.block_size, + candidate_offset=0, + num_candidates=max_swa_len, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_two_sparse_mla_attention_states_with_sink( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + self.attn_sink, + output=output, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + + def _forward_sparse_mla_prefill_triton( + self, + q: torch.Tensor, + kv: torch.Tensor, + combined_indices: torch.Tensor, + combined_lens: torch.Tensor, + output: torch.Tensor, + state_buffers: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> None: + kv_flat = kv.reshape(-1, q.shape[-1]) + topk_chunk_size = min( + combined_indices.shape[-1], + triton_sparse_mla_topk_chunk_size(), + ) + query_chunk_size = min( + q.shape[0], + triton_sparse_mla_query_chunk_size(), + ) + if state_buffers is None: + ( + max_score_buffer, + denom_buffer, + output_buffer, + ) = current_workspace_manager().get_simultaneous( + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads, q.shape[-1]), torch.float32), + ) + else: + max_score_buffer, denom_buffer, output_buffer = state_buffers + + for token_start in range(0, q.shape[0], query_chunk_size): + token_end = min(token_start + query_chunk_size, q.shape[0]) + q_chunk = q[token_start:token_end] + indices_chunk_full = combined_indices[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + num_tokens = token_end - token_start + max_score = max_score_buffer[:num_tokens] + denom = 
denom_buffer[:num_tokens] + subset_acc = output_buffer[:num_tokens] + max_score.fill_(float("-inf")) + denom.zero_() + subset_acc.zero_() + + for index_start in range(0, combined_indices.shape[-1], topk_chunk_size): + index_end = min( + index_start + topk_chunk_size, + combined_indices.shape[-1], + ) + accumulate_indexed_sparse_mla_attention_chunk( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk_full[:, index_start:index_end], + lens=lens_chunk, + candidate_offset=index_start, + scale=self.scale, + max_score=max_score, + denom=denom, + acc=subset_acc, + ) + + finish_sparse_mla_attention_with_sink( + max_score, + denom, + subset_acc, + self.attn_sink, + output=output[token_start:token_end], + ) + if output.shape[1] > self.num_heads: + output[token_start:token_end, self.num_heads :].zero_() + def forward( self, q: torch.Tensor, @@ -823,12 +1174,14 @@ def _forward_decode( if self.compress_ratio == 4: # C4A: local indices differ per layer (filled by Indexer). assert self.topk_indices_buffer is not None + local_topk_indices = self.topk_indices_buffer[:num_decode_tokens] global_indices, topk_lens = compute_global_topk_indices_and_lens( - self.topk_indices_buffer[:num_decode_tokens], + local_topk_indices, swa_metadata.token_to_req_indices, attn_metadata.block_table[:num_decodes], block_size, is_valid, + global_topk_indices=local_topk_indices, ) topk_indices = global_indices.view(num_decode_tokens, 1, -1) else: @@ -867,9 +1220,35 @@ def _forward_decode( # Use unsqueeze to preserve strides (handles padded blocks correctly) swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2) # Reshape KV cache to (num_blocks, block_size, 1, head_bytes) + compressed_k_cache = kv_cache if kv_cache is not None: kv_cache = kv_cache.unsqueeze(-2) + if is_triton_sparse_mla_enabled(q.device): + if swa_only: + self._forward_sparse_mla_swa_decode_triton( + q=q, + swa_k_cache=self.swa_cache_layer.kv_cache, + swa_metadata=swa_metadata, + output=output, + ) + return + if self.compress_ratio in (4, 128): + assert compressed_k_cache is not None + assert attn_metadata is not None + assert topk_indices is not None + assert topk_lens is not None + self._forward_sparse_mla_compressed_decode_triton( + q=q, + compressed_k_cache=compressed_k_cache, + swa_k_cache=self.swa_cache_layer.kv_cache, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + output=output, + ) + return # One FlashMLASchedMeta per layer type, shared across all same-type # layers within this decode step. The first forward call per type # triggers the in-kernel planner (allocating tile_scheduler_metadata @@ -932,8 +1311,12 @@ def _forward_prefill( # Use pre-computed prefill metadata. seq_lens = swa_metadata.prefill_seq_lens gather_lens = swa_metadata.prefill_gather_lens + seq_lens_cpu = swa_metadata.prefill_seq_lens_cpu + gather_lens_cpu = swa_metadata.prefill_gather_lens_cpu assert seq_lens is not None assert gather_lens is not None + assert seq_lens_cpu is not None + assert gather_lens_cpu is not None # Derive prefill-local token offsets from the full query_start_loc_cpu. query_start_loc_cpu = swa_metadata.query_start_loc_cpu @@ -952,24 +1335,69 @@ def _forward_prefill( assert attn_metadata is not None topk_indices = attn_metadata.c128a_prefill_topk_indices top_k = topk_indices.shape[-1] - # Compressed region must fit the full compressed pool (seq_len // - # compress_ratio), not just top_k. top_k bounds how many indices - # the indexer selects, not the pool size it indexes into. 
- N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio else: # NOTE(woosuk): topk_indices will not be used for SWA-only layers. assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[num_decode_tokens:] top_k = 0 - N = 0 - M = N + self.window_size + self.max_num_batched_tokens + N, M = _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu=seq_lens_cpu, + gather_lens_cpu=gather_lens_cpu, + compress_ratio=self.compress_ratio, + swa_only=swa_only, + ) num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE + max_query_chunk_tokens = 0 + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * PREFILL_CHUNK_SIZE + chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills) + query_start = ( + query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base + ) + query_end = ( + query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base + ) + max_query_chunk_tokens = max( + max_query_chunk_tokens, int(query_end - query_start) + ) + combined_topk = sparse_prefill_combined_topk_size(top_k, self.window_size) workspace_manager = current_workspace_manager() - kv = workspace_manager.get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - )[0] + triton_sparse_mla_enabled = is_triton_sparse_mla_enabled(q.device) + if triton_sparse_mla_enabled: + query_chunk_size = min(q.shape[0], triton_sparse_mla_query_chunk_size()) + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + max_score_buffer, + denom_buffer, + output_buffer, + ) = workspace_manager.get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads, q.shape[-1]), torch.float32), + ) + prefill_state_buffers = ( + max_score_buffer, + denom_buffer, + output_buffer, + ) + else: + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + ) = workspace_manager.get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ) + prefill_state_buffers = None for chunk_idx in range(num_chunks): chunk_start = chunk_idx * PREFILL_CHUNK_SIZE chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills) @@ -1008,6 +1436,7 @@ def _forward_prefill( query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base ) + query_tokens = query_end - query_start combined_indices, combined_lens = combine_topk_swa_indices( topk_indices[query_start:query_end], query_start_loc[ @@ -1020,8 +1449,21 @@ def _forward_prefill( top_k, M, N, + combined_indices=combined_indices_buffer[:query_tokens], + combined_lens=combined_lens_buffer[:query_tokens], ) + if triton_sparse_mla_enabled: + self._forward_sparse_mla_prefill_triton( + q=q[query_start:query_end], + kv=kv[:chunk_size], + combined_indices=combined_indices, + combined_lens=combined_lens, + output=output[query_start:query_end], + state_buffers=prefill_state_buffers, + ) + continue + if current_platform.is_rocm(): rocm_sparse_attn_prefill( q=q[query_start:query_end], @@ -1033,16 +1475,17 @@ def _forward_prefill( attn_sink=self.attn_sink, output=output[query_start:query_end], ) - else: - output_chunk, _, _ = flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - 
indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], - ) + continue + + output_chunk, _, _ = flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 3487ac1766e6..416d871e24f4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -769,6 +769,7 @@ def apply( sort_indices2=self.w2_g_idx_sort_indices, is_k_full=self.is_k_full, input_dtype=self.input_dtype, + clamp_limit=self.gemm1_clamp_limit, ) return diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 456f40bbf7a3..79edfa6f2d92 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -786,6 +786,19 @@ def update_expert_map(self): dp_size=get_dp_group().world_size, ) + @staticmethod + def _normalize_loaded_weight_for_copy( + expert_data: torch.Tensor, loaded_weight: torch.Tensor + ) -> torch.Tensor: + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and expert_data.dtype == torch.uint8 + and loaded_weight.dtype == e8m0_dtype + ): + return loaded_weight.view(torch.uint8) + return loaded_weight + def _load_per_tensor_weight_scale( self, shard_id: str, @@ -799,10 +812,12 @@ def _load_per_tensor_weight_scale( # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. 
idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight + target = param_data[expert_id][idx] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) # If we are in the row parallel case (down_proj) elif shard_id == "w2": - param_data[expert_id] = loaded_weight + target = param_data[expert_id] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) def _load_combined_w13_weight_scale( self, @@ -819,7 +834,7 @@ def _load_combined_w13_weight_scale( loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size ) - param.copy_(loaded_weight) + param.copy_(self._normalize_loaded_weight_for_copy(param, loaded_weight)) def _load_model_weight_or_group_weight_scale( self, @@ -986,7 +1001,9 @@ def _load_w13( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_w2( self, @@ -1022,7 +1039,9 @@ def _load_w2( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_single_value( self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index d6fef0b3d3d5..db2abc43056a 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -381,6 +381,9 @@ def process_weights_after_loading(self, layer): return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) + del w13, w2, w13_scale, w2_scale, w13_bias, w2_bias + if torch.cuda.is_available(): + torch.cuda.empty_cache() def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -709,6 +712,9 @@ def process_weights_after_loading(self, layer): return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) + del w13, w2, w13_scale, w2_scale, w13_bias, w2_bias + if torch.cuda.is_available(): + torch.cuda.empty_cache() def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index d9aab35c25f4..e2df12488652 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -817,6 +817,35 @@ def get_w8a8_block_fp8_configs( return None +def _get_default_w8a8_block_fp8_config( + M: int, + block_n: int, + block_k: int, +) -> dict[str, Any]: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_n and + # BLOCK_SIZE_K must be divisible by block_k. + # M-aware tuning for low-M decode: BLOCK_SIZE_M=64 wastes most of the + # M-dim for single-request decode and short MTP-style draft batches. SM12x + # keeps benefiting from the low-M tile through M=32 on DeepSeek V4 shapes. 
+ capability = current_platform.get_device_capability() + capability_major = getattr(capability, "major", None) + if capability_major is None and capability is not None: + capability_major = capability[0] + low_m_limit = 32 if capability_major == 12 else 8 + if low_m_limit >= M: + block_m, num_stages = 16, (2 if current_platform.is_rocm() else 3) + else: + block_m, num_stages = 64, 2 + return { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": num_stages, + } + + def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -861,6 +890,12 @@ def w8a8_triton_block_scaled_mm( N, K = B.shape assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if e8m0_dtype is not None: + if As.dtype == e8m0_dtype: + As = _upcast_e8m0_to_fp32(As) + if Bs.dtype == e8m0_dtype: + Bs = _upcast_e8m0_to_fp32(Bs) C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) @@ -870,17 +905,7 @@ def w8a8_triton_block_scaled_mm( # Get the optimal config if there is one config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: - # Default config - # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] - # BLOCK_SIZE_K must be divisible by block_size[1] - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": block_size[0], - "BLOCK_SIZE_K": block_size[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 2, - } + config = _get_default_w8a8_block_fp8_config(M, block_size[0], block_size[1]) def grid(META): return ( @@ -1215,6 +1240,8 @@ def create_fp8_scale_parameter( if dtype == torch.float32: scale[:] = torch.finfo(torch.float32).min + elif dtype == getattr(torch, "float8_e8m0fnu", None): + scale[:] = 0 set_weight_attrs(scale, {"scale_type": "weight_scale"}) return scale diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 4bf52a49c43f..f4285a5e5c42 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -4,7 +4,6 @@ import torch -import vllm.envs as envs from vllm._aiter_ops import rocm_aiter_ops from vllm.forward_context import get_forward_context from vllm.logger import init_logger @@ -12,7 +11,9 @@ from vllm.platforms import current_platform from vllm.utils.deep_gemm import ( fp8_fp4_mqa_logits, + fp8_fp4_mqa_topk_indices, fp8_fp4_paged_mqa_logits, + fp8_fp4_paged_mqa_topk_indices, has_deep_gemm, ) from vllm.utils.torch_utils import ( @@ -23,6 +24,7 @@ ) from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerMetadata, + sparse_indexer_max_logits_bytes, ) from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.worker.workspace import current_workspace_manager @@ -35,11 +37,58 @@ logger = init_logger(__name__) RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024 +SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH = 4096 +SM120_SHORT_ROW_TOPK_MAX_WIDTH = 12288 # MXFP4 layout: 2 values packed per byte, ue8m0 (1-byte) scale per block of 32. 
MXFP4_BLOCK_SIZE = 32 +def _should_use_sm120_short_row_topk_decode( + topk_tokens: int, + logits_width: int, + is_cuda_sm120: bool, +) -> bool: + if not is_cuda_sm120 or topk_tokens != 512: + return False + if logits_width <= SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH: + return True + return logits_width < SM120_SHORT_ROW_TOPK_MAX_WIDTH + + +def _use_sm120_short_row_topk_decode( + logits: torch.Tensor, + topk_tokens: int, +) -> bool: + return _should_use_sm120_short_row_topk_decode( + topk_tokens, + logits.shape[1], + current_platform.is_cuda() + and current_platform.is_device_capability_family(120), + ) + + +def _decode_logits_width(max_model_len: int, max_seq_len: int) -> int: + if max_model_len <= 0: + return 0 + if max_seq_len <= 0: + return max_model_len + return min(max_model_len, max_seq_len) + + +def _decode_topk_logits_width( + max_model_len: int, max_seq_len: int, topk_tokens: int +) -> int: + logits_width = _decode_logits_width(max_model_len, max_seq_len) + return min(max_model_len, max(logits_width, topk_tokens)) + + +def _sparse_indexer_requires_deep_gemm() -> bool: + return current_platform.is_cuda() and not ( + current_platform.is_device_capability_family(120) + ) + + def _gather_workspace_shapes( total_seq_lens: int, head_dim: int, @@ -118,7 +167,7 @@ def sparse_attn_indexer( # Dummy allocation to simulate for peak logits tensor memory during inference. # FP8 elements so elements == bytes - max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_elems = sparse_indexer_max_logits_bytes() _ = torch.empty( max_logits_elems, dtype=torch.uint8, device=hidden_states.device ) @@ -220,6 +269,19 @@ def sparse_attn_indexer( q_slice_cast = q_slice k_quant_cast = k_quant k_scale_cast = k_scale.view(torch.float32).squeeze(-1) + topk_indices = topk_indices_buffer[ + chunk.token_start : chunk.token_end, :topk_tokens + ] + if fp8_fp4_mqa_topk_indices( + (q_slice_cast, q_scale_slice), + (k_quant_cast, k_scale_cast), + weights[chunk.token_start : chunk.token_end], + chunk.cu_seqlen_ks, + chunk.cu_seqlen_ke, + topk_indices, + ): + continue + logits = fp8_fp4_mqa_logits( (q_slice_cast, q_scale_slice), (k_quant_cast, k_scale_cast), @@ -230,10 +292,6 @@ def sparse_attn_indexer( ) num_rows = logits.shape[0] - topk_indices = topk_indices_buffer[ - chunk.token_start : chunk.token_end, :topk_tokens - ] - if current_platform.is_xpu(): xpu_ops.top_k_per_row_prefill( # type: ignore[attr-defined] logits, @@ -307,35 +365,38 @@ def sparse_attn_indexer( if use_fp4_cache else padded_q_quant_decode_tokens ) - logits = fp8_fp4_paged_mqa_logits( - (padded_q_quant_cast, padded_q_scale), - kv_cache, - weights[:num_padded_tokens], - seq_lens, - decode_metadata.block_table, - decode_metadata.schedule_metadata, - max_model_len=max_model_len, - clean_logits=False, - ) - num_rows = logits.shape[0] topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] - - if current_platform.is_cuda() and topk_tokens in (512, 1024, 2048): - workspace_manager = current_workspace_manager() - (topk_workspace,) = workspace_manager.get_simultaneous( - ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), - ) - torch.ops._C.persistent_topk( - logits, + logits_width = _decode_topk_logits_width( + max_model_len, attn_metadata_narrowed.max_seq_len, topk_tokens + ) + logits_bytes = num_padded_tokens * logits_width * torch.float32.itemsize + used_direct_topk = False + if logits_bytes > sparse_indexer_max_logits_bytes(): + used_direct_topk = fp8_fp4_paged_mqa_topk_indices( + (padded_q_quant_cast, padded_q_scale), + 
kv_cache, + weights[:num_padded_tokens], seq_lens, + decode_metadata.block_table, + logits_width, topk_indices, - topk_workspace, - topk_tokens, - attn_metadata_narrowed.max_seq_len, ) - else: - if current_platform.is_xpu(): - xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined] + + if not used_direct_topk: + logits = fp8_fp4_paged_mqa_logits( + (padded_q_quant_cast, padded_q_scale), + kv_cache, + weights[:num_padded_tokens], + seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=logits_width, + clean_logits=False, + ) + num_rows = logits.shape[0] + + if _use_sm120_short_row_topk_decode(logits, topk_tokens): + torch.ops._C.top_k_per_row_decode( logits, next_n, seq_lens, @@ -345,17 +406,42 @@ def sparse_attn_indexer( logits.stride(1), topk_tokens, ) - else: - torch.ops._C.top_k_per_row_decode( + elif current_platform.is_cuda() and topk_tokens in (512, 2048): + workspace_manager = current_workspace_manager() + (topk_workspace,) = workspace_manager.get_simultaneous( + ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), + ) + torch.ops._C.persistent_topk( logits, - next_n, seq_lens, topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), + topk_workspace, topk_tokens, + logits_width, ) + else: + if current_platform.is_xpu(): + xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined] + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + else: + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) if decode_metadata.requires_padding: # if padded, we need to unpack @@ -438,7 +524,7 @@ def __init__( self.topk_indices_buffer = topk_indices_buffer self.skip_k_cache_insert = skip_k_cache_insert self.use_fp4_cache = use_fp4_cache - if current_platform.is_cuda() and not has_deep_gemm(): + if _sparse_indexer_requires_deep_gemm() and not has_deep_gemm(): raise RuntimeError( "Sparse Attention Indexer CUDA op requires DeepGEMM to be installed." 
) diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index a76092028671..5fa8e7d74158 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -4,7 +4,7 @@ import glob import os import time -from collections.abc import Generator, Iterable +from collections.abc import Callable, Generator, Iterable from typing import cast import torch @@ -209,7 +209,9 @@ def _prepare_weights( return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, source: "Source" + self, + source: "Source", + weight_name_filter: Callable[[str], bool] | None = None, ) -> Generator[tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" extra_config = self.load_config.model_loader_extra_config @@ -235,6 +237,8 @@ def _get_weights_iterator( weights_iterator = fastsafetensors_weights_iterator( hf_weights_files, self.load_config.use_tqdm_on_load, + local_expert_ids=self.local_expert_ids, + weight_name_filter=weight_name_filter, ) elif self.load_config.load_format == "instanttensor": weights_iterator = instanttensor_weights_iterator( @@ -256,6 +260,7 @@ def _get_weights_iterator( self.load_config.use_tqdm_on_load, self.load_config.safetensors_load_strategy, local_expert_ids=self.local_expert_ids, + weight_name_filter=weight_name_filter, ) else: if extra_config.get("enable_multithread_load"): @@ -284,6 +289,9 @@ def get_all_weights( model_config: ModelConfig, model: nn.Module, ) -> Generator[tuple[str, torch.Tensor], None, None]: + weight_name_filter = getattr(model, "skip_weight_name_before_load", None) + if not callable(weight_name_filter): + weight_name_filter = None primary_weights = DefaultModelLoader.Source( model_config.model, model_config.revision, @@ -291,14 +299,14 @@ def get_all_weights( fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True), allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None), ) - yield from self._get_weights_iterator(primary_weights) + yield from self._get_weights_iterator(primary_weights, weight_name_filter) secondary_weights = cast( Iterable[DefaultModelLoader.Source], getattr(model, "secondary_weights", ()), ) for source in secondary_weights: - yield from self._get_weights_iterator(source) + yield from self._get_weights_iterator(source, weight_name_filter) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights( diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 31b00df4e4c3..628b7bcb02db 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -875,17 +875,31 @@ def _run_prefetch() -> None: threading.Thread(target=_run_prefetch, daemon=True).start() +def _should_skip_safetensors_weight( + weight_name: str, + local_expert_ids: set[int] | None, + weight_name_filter: Callable[[str], bool] | None, +) -> bool: + if should_skip_weight(weight_name, local_expert_ids): + return True + return weight_name_filter is not None and weight_name_filter(weight_name) + + def safetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, safetensors_load_strategy: str | None = None, local_expert_ids: set[int] | None = None, + weight_name_filter: Callable[[str], bool] | None = None, ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files. 
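(Editorial sketch, not part of the patch: roughly how the new `weight_name_filter` hook threaded through `get_all_weights` above is consumed; `ToyModel` and `filtered_names` are hypothetical stand-ins, while the real iterators skip matching names before any tensor is materialized.)

class ToyModel:
    # Hypothetical hook mirroring skip_weight_name_before_load: return True
    # for weight names that should never be read from disk on this rank.
    def skip_weight_name_before_load(self, name: str) -> bool:
        return "mtp." in name

def filtered_names(names, model):
    flt = getattr(model, "skip_weight_name_before_load", None)
    if not callable(flt):
        flt = None
    for name in names:
        if flt is not None and flt(name):
            continue  # dropped before get_tensor() would be called
        yield name

# list(filtered_names(["model.mtp.0.weight", "model.embed_tokens.weight"], ToyModel()))
# -> ["model.embed_tokens.weight"]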
When *local_expert_ids* is provided, expert weights not belonging to this rank are skipped **before** reading from disk, which drastically reduces storage I/O for MoE models under EP. + + When *weight_name_filter* is provided, names for which the callback returns + ``True`` are also skipped before tensor materialization. """ loading_desc = "Loading safetensors checkpoint shards" if safetensors_load_strategy == "eager": @@ -964,7 +978,9 @@ def safetensors_weights_iterator( with open(st_file, "rb") as f: state_dict = load(f.read()) for name, param in state_dict.items(): - if not should_skip_weight(name, local_expert_ids): + if not _should_skip_safetensors_weight( + name, local_expert_ids, weight_name_filter + ): yield name, param elif safetensors_load_strategy == "torchao": # we can't load flattened torchao tensor subclasses directly into the model @@ -981,7 +997,9 @@ def safetensors_weights_iterator( with safe_open(st_file, framework="pt") as f: state_dict = {} for name in f.keys(): # noqa: SIM118 - if should_skip_weight(name, local_expert_ids): + if _should_skip_safetensors_weight( + name, local_expert_ids, weight_name_filter + ): continue state_dict[name] = f.get_tensor(name) @@ -999,7 +1017,9 @@ def safetensors_weights_iterator( else: with safe_open(st_file, framework="pt") as f: for name in f.keys(): # noqa: SIM118 - if should_skip_weight(name, local_expert_ids): + if _should_skip_safetensors_weight( + name, local_expert_ids, weight_name_filter + ): continue param = f.get_tensor(name) yield name, param @@ -1087,6 +1107,8 @@ def _init_fastsafetensors_loader( def fastsafetensors_weights_iterator( hf_weights_files: list[str], use_tqdm_on_load: bool, + local_expert_ids: set[int] | None = None, + weight_name_filter: Callable[[str], bool] | None = None, ) -> Generator[tuple[str, torch.Tensor], None, None]: """Iterate over the weights in the model safetensor files using fastsafetensor library.""" @@ -1133,6 +1155,12 @@ def fastsafetensors_weights_iterator( try: keys = list(fb.key_to_rank_lidx.keys()) for k in keys: + if _should_skip_safetensors_weight( + k, + local_expert_ids, + weight_name_filter, + ): + continue t = fb.get_tensor(k) yield k, t finally: diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index cef4038dc2e6..c31312e86838 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -1579,6 +1579,12 @@ def get_mtp_target_hidden_states(self) -> torch.Tensor | None: forward(); valid after each target step.""" return getattr(self.model, "_mtp_hidden_buffer", None) + def skip_weight_name_before_load(self, name: str) -> bool: + mapped_names = self.hf_to_vllm_mapper.apply_list([name]) + if not mapped_names: + return True + return all("mtp." 
in mapped_name for mapped_name in mapped_names) + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_substrs=["mtp."]) loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 6b89f5c33203..5068ae2969df 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -338,6 +338,48 @@ def transform_sf_into_required_layout(*args, **kwargs): ) +def fp8_fp4_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 MQA top-k indices without materializing full logits.""" + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks.fp8_fp4_mqa_topk_indices( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices, + ) + + +def _fp8_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks._fp8_mqa_logits_sm12x( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) + + def fp8_fp4_mqa_logits( q: tuple[torch.Tensor, torch.Tensor | None], kv: tuple[torch.Tensor, torch.Tensor], @@ -370,6 +412,10 @@ def fp8_fp4_mqa_logits( Returns: Logits tensor of shape [M, N], dtype `torch.float32`. 
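(Editorial sketch, not part of the patch: the `*_topk_indices` helpers above use a return value of ``False`` to mean "not handled on this platform"; a caller can therefore be written as below, where the tensor arguments are placeholders and the final ``topk`` line is a simplified stand-in for the dedicated top-k kernels used by the indexer.)

handled = fp8_fp4_mqa_topk_indices(q, kv, weights, ks, ke, topk_indices)
if not handled:
    # Fall back to materializing the full [M, N] logits and running a
    # separate top-k over them.
    logits = fp8_fp4_mqa_logits(q, kv, weights, ks, ke)
    topk_indices.copy_(logits.topk(topk_indices.shape[1], dim=-1).indices)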
""" + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_mqa_logits_sm12x( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) _lazy_init() if _fp8_fp4_mqa_logits_impl is None: return _missing() @@ -404,6 +450,50 @@ def get_paged_mqa_logits_metadata( return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) +def _fp8_paged_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks._fp8_paged_mqa_logits_sm12x( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) + + +def fp8_fp4_paged_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 paged MQA top-k indices without full logits.""" + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks.fp8_fp4_paged_mqa_topk_indices( + q, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + topk_indices, + ) + + def fp8_fp4_paged_mqa_logits( q: tuple[torch.Tensor, torch.Tensor | None], kv_cache: torch.Tensor, @@ -425,9 +515,10 @@ def fp8_fp4_paged_mqa_logits( [B, next_n, H, D] float8_e4m3fn and q_scale is None. FP4 path: q_values is packed uint8 and q_scale is the companion block-scale tensor. - kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, 1, - D+4], dtype `torch.uint8`, with the last 4 bytes per (block, pos) - storing the float dequant scale. + kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, D+4] + or [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. Within + each block, the D-byte FP8 values for every token are stored first, + followed by per-token fp32 scale bytes. weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. context_lens: Tensor of shape [B], dtype int32; effective context length for each batch element. @@ -442,6 +533,10 @@ def fp8_fp4_paged_mqa_logits( Logits tensor of shape [B * next_n, max_model_len], dtype `torch.float32`. 
""" + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_paged_mqa_logits_sm12x( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) _lazy_init() if _fp8_fp4_paged_mqa_logits_impl is None: return _missing() @@ -457,6 +552,20 @@ def fp8_fp4_paged_mqa_logits( ) +def _tf32_hc_prenorm_gemm_sm12x( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + from vllm.v1.attention.ops.deepseek_v4_ops import sm12x_deep_gemm_fallbacks + + return sm12x_deep_gemm_fallbacks._tf32_hc_prenorm_gemm_sm12x( + x, fn, out, sqrsum, num_split + ) + + def tf32_hc_prenorm_gemm( x: torch.Tensor, fn: torch.Tensor, @@ -471,6 +580,8 @@ def tf32_hc_prenorm_gemm( See the caller function for shape requirement """ + if current_platform.is_device_capability_family(120): + return _tf32_hc_prenorm_gemm_sm12x(x, fn, out, sqrsum, num_split) _lazy_init() if _tf32_hc_prenorm_gemm_impl is None: return _missing() @@ -570,7 +681,9 @@ def should_use_deepgemm_for_fp8_linear( "m_grouped_fp8_fp4_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "fp8_fp4_mqa_logits", + "fp8_fp4_mqa_topk_indices", "fp8_fp4_paged_mqa_logits", + "fp8_fp4_paged_mqa_topk_indices", "get_paged_mqa_logits_metadata", "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 474a5b2d421e..e629efcb51df 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -30,6 +30,9 @@ SparseMLAAttentionImpl, ) from vllm.v1.attention.backends.mla.compressor_utils import get_compressed_slot_mapping +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled_for_platform, +) from vllm.v1.attention.backends.mla.sparse_utils import ( triton_convert_req_index_to_global_index, ) @@ -266,6 +269,19 @@ def get_prefill_workspace_size(max_model_len: int): class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + @classmethod + def get_cudagraph_support( + cls, + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + if ( + getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" + and is_triton_sparse_mla_enabled_for_platform() + ): + return AttentionCGSupport.NEVER + return cls._cudagraph_support + def __init__( self, kv_cache_spec: AttentionSpec, diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 7c0715a9e8b6..3f43b65d26fb 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from dataclasses import dataclass import torch -import vllm.envs as envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -32,6 +32,28 @@ logger = init_logger(__name__) +def sparse_indexer_max_logits_bytes(is_sm12x: bool | None = None) -> int: + configured_mb = os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB") + if configured_mb is not None: + return int(configured_mb) * 1024 * 1024 + + if is_sm12x is None: + is_sm12x = ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + ) + default_mb = 256 if is_sm12x 
else 512 + return default_mb * 1024 * 1024 + + +def _uses_deep_gemm_scheduler_metadata() -> bool: + return ( + current_platform.is_cuda() + and has_deep_gemm() + and not current_platform.is_device_capability_family(120) + ) + + @triton.jit def _prepare_uniform_decode_kernel( seq_lens_ptr, @@ -269,13 +291,12 @@ def __init__(self, *args, **kwargs): self.reorder_batch_threshold += self.num_speculative_tokens # NOTE(zyongye) fp4 indexer cache only natively supports next_n in # natively_supported_next_n_fp4; for other next_n values we fall back - # to the flattening path. Outside the SM100 datacenter family the FP8 - # paged MQA logits kernel has the same [1, 2] constraint (deepgemm - # smxx_fp8_fp4_paged_mqa_logits.hpp:233), so flatten there too. + # to the flattening path. When fp4 indexer cache is disabled, the + # native (non-flattening) path handles all next_n values. self.use_flattening = ( self.use_fp4_indexer_cache - or not current_platform.is_device_capability_family(100) - ) and next_n not in self.natively_supported_next_n_fp4 + and next_n not in self.natively_supported_next_n_fp4 + ) sm_count = num_compute_units(self.device.index) self.num_sms = sm_count @@ -520,7 +541,7 @@ def build( prefill_query_lens_cpu = torch.diff( query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1] ) - max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_bytes = sparse_indexer_max_logits_bytes() # Upper bound is exact for prefill rows (the `[num_decodes:]` # slice below). assert common_attn_metadata.seq_lens_cpu_upper_bound is not None @@ -610,8 +631,7 @@ def build( if seq_lens.dim() == 1: seq_lens = seq_lens.unsqueeze(-1) - # DeepGEMM is required for the paged MQA logits on CUDA devices - if current_platform.is_cuda() and has_deep_gemm(): + if _uses_deep_gemm_scheduler_metadata(): self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( seq_lens, self.kv_cache_spec.storage_block_size, diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py new file mode 100644 index 000000000000..ff8c3a5c4b56 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Platform controls for the portable Triton sparse MLA path.""" + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE = 512 +_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE = 256 + +logger = init_logger(__name__) + + +def _is_sm12x_device(device: torch.device) -> bool: + if not torch.cuda.is_available(): + return False + index = device.index if device.index is not None else torch.cuda.current_device() + return torch.cuda.get_device_capability(index)[0] == 12 + + +def is_triton_sparse_mla_enabled_for_platform() -> bool: + return current_platform.is_device_capability_family(120) + + +def is_triton_sparse_mla_enabled(device: torch.device) -> bool: + return _is_sm12x_device(device) + + +def disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) -> None: + if not is_triton_sparse_mla_enabled_for_platform(): + return + + from vllm.config.compilation import CompilationMode, CUDAGraphMode + + compilation_config = vllm_config.compilation_config + if ( + compilation_config.mode == CompilationMode.NONE + and compilation_config.cudagraph_mode == CUDAGraphMode.NONE + ): + return + + logger.warning_once( + "Disabling vLLM compile and CUDA graphs for the 
DeepSeek V4 Triton " + "sparse MLA path because the current Triton sparse MLA path is not " + "compile/graph-safe yet." + ) + compilation_config.mode = CompilationMode.NONE + compilation_config.compile_sizes = [] + compilation_config.compile_ranges_endpoints = [] + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + compilation_config.cudagraph_capture_sizes = [] + compilation_config.max_cudagraph_capture_size = 0 + + +def triton_sparse_mla_topk_chunk_size() -> int: + return _TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE + + +def triton_sparse_mla_query_chunk_size() -> int: + return _TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py new file mode 100644 index 000000000000..36c41e4b3bad --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -0,0 +1,838 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Portable sparse MLA Triton kernels.""" + +import torch + +from vllm.triton_utils import tl, triton + + +def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int: + """Choose the SM12x sparse MLA head grouping for decode kernels. + + Single-token decode is latency sensitive and does best with one head per + program. Once there are enough query tokens, grouping heads lets the kernel + reuse each dequantized KV row across multiple heads. + """ + + if num_decode_tokens <= 4: + return 1 + if num_decode_tokens < 16: + return 2 + return 4 + + +@triton.jit +def _accumulate_indexed_attention_chunk_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + for candidate_idx in range(0, num_candidates): + kv_index = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_idx * stride_indices_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (kv_index >= 0) + + if is_valid: + kv = tl.load( + kv_flat_ptr + + kv_index.to(tl.int64) * stride_kv_t + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom 
= running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_indexed_sparse_mla_attention_chunk( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = indices.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_indexed_attention_chunk_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = 
tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + valid_len = tl.load(lens_ptr + token_idx) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = 
q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = 
x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int, + num_candidates: int, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_paged_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _finish_attention_state_with_sink_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + sink_ptr, + output_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + 
stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + sink = tl.load(sink_ptr + head_idx) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + subset_weight = running_denom * subset_scale + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = subset_weight + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc_values = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc_values = tl.where(has_tokens, acc_values, 0.0) + output = acc_values * subset_scale * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +@triton.jit +def _finish_two_attention_states_with_sink_kernel( + max0_ptr, + denom0_ptr, + acc0_ptr, + max1_ptr, + denom1_ptr, + acc1_ptr, + sink_ptr, + output_ptr, + stride_state0_t: tl.constexpr, + stride_state0_h: tl.constexpr, + stride_acc0_t: tl.constexpr, + stride_acc0_h: tl.constexpr, + stride_acc0_d: tl.constexpr, + stride_state1_t: tl.constexpr, + stride_state1_h: tl.constexpr, + stride_acc1_t: tl.constexpr, + stride_acc1_h: tl.constexpr, + stride_acc1_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state0_offset = token_idx * stride_state0_t + head_idx * stride_state0_h + state1_offset = token_idx * stride_state1_t + head_idx * stride_state1_h + max0 = tl.load(max0_ptr + state0_offset) + denom0 = tl.load(denom0_ptr + state0_offset) + max1 = tl.load(max1_ptr + state1_offset) + denom1 = tl.load(denom1_ptr + state1_offset) + sink = tl.load(sink_ptr + head_idx) + + has0 = denom0 > 0.0 + has1 = denom1 > 0.0 + has_sink = sink > -float("inf") + valid_max0 = tl.where(has0, max0, -float("inf")) + valid_max1 = tl.where(has1, max1, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(tl.maximum(valid_max0, valid_max1), valid_sink) + has_any = has0 | has1 | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_max0 = tl.where(has0, 
max0, safe_merge_max) + safe_max1 = tl.where(has1, max1, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + scale0 = tl.where(has0, tl.exp(safe_max0 - safe_merge_max), 0.0) + scale1 = tl.where(has1, tl.exp(safe_max1 - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = denom0 * scale0 + denom1 * scale1 + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc0 = tl.load( + acc0_ptr + + token_idx * stride_acc0_t + + head_idx * stride_acc0_h + + offsets * stride_acc0_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc1 = tl.load( + acc1_ptr + + token_idx * stride_acc1_t + + head_idx * stride_acc1_h + + offsets * stride_acc1_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc0 = tl.where(has0, acc0, 0.0) + acc1 = tl.where(has1, acc1, 0.0) + output = (acc0 * scale0 + acc1 * scale1) * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +def finish_two_sparse_mla_attention_states_with_sink( + max_score0: torch.Tensor, + denom0: torch.Tensor, + acc0: torch.Tensor, + max_score1: torch.Tensor, + denom1: torch.Tensor, + acc1: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score0.shape == denom0.shape + assert max_score1.shape == denom1.shape + assert max_score0.shape == max_score1.shape + assert acc0.shape == acc1.shape + assert acc0.shape[:2] == max_score0.shape + assert output.shape[0] == acc0.shape[0] + assert output.shape[1] >= acc0.shape[1] + assert output.shape[2] == acc0.shape[2] + assert attn_sink.shape[0] >= acc0.shape[1] + assert max_score0.dtype == torch.float32 + assert denom0.dtype == torch.float32 + assert acc0.dtype == torch.float32 + assert max_score1.dtype == torch.float32 + assert denom1.dtype == torch.float32 + assert acc1.dtype == torch.float32 + assert max_score0.is_cuda and denom0.is_cuda and acc0.is_cuda + assert max_score1.is_cuda and denom1.is_cuda and acc1.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc0.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_two_attention_states_with_sink_kernel[grid]( + max_score0, + denom0, + acc0, + max_score1, + denom1, + acc1, + attn_sink, + output, + max_score0.stride(0), + max_score0.stride(1), + acc0.stride(0), + acc0.stride(1), + acc0.stride(2), + max_score1.stride(0), + max_score1.stride(1), + acc1.stride(0), + acc1.stride(1), + acc1.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +def finish_sparse_mla_attention_with_sink( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape[0] == acc.shape[0] + assert output.shape[1] >= acc.shape[1] + assert output.shape[2] == acc.shape[2] + assert attn_sink.shape[0] >= acc.shape[1] + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = 
(num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_with_sink_kernel[grid]( + max_score, + denom, + acc, + attn_sink, + output, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 28564e6a97d3..bfb5f88e10b8 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -16,9 +16,14 @@ CommonAttentionMetadata, MultipleOf, ) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled, + is_triton_sparse_mla_enabled_for_platform, +) from vllm.v1.attention.backends.utils import split_decodes_and_prefills from vllm.v1.attention.ops.flashmla import FlashMLASchedMeta, get_mla_metadata from vllm.v1.kv_cache_interface import ( + AttentionSpec, KVCacheSpec, MLAAttentionSpec, SlidingWindowMLASpec, @@ -162,6 +167,8 @@ class DeepseekSparseSWAMetadata: # Pre-computed prefill metadata shared across all DeepseekV4 attention layers. prefill_seq_lens: torch.Tensor | None = None prefill_gather_lens: torch.Tensor | None = None + prefill_seq_lens_cpu: torch.Tensor | None = None + prefill_gather_lens_cpu: torch.Tensor | None = None # Per-layer-type FlashMLA tile-scheduler metadata. One FlashMLASchedMeta # per present DeepseekV4 layer type, shared across all ~60 layers of that type @@ -195,6 +202,19 @@ class DeepseekSparseSWAMetadataBuilder(AttentionMetadataBuilder): reorder_batch_threshold: int = 1 _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + @classmethod + def get_cudagraph_support( + cls, + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + if ( + getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" + and is_triton_sparse_mla_enabled_for_platform() + ): + return AttentionCGSupport.NEVER + return cls._cudagraph_support + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec) @@ -313,6 +333,8 @@ def build( num_prefills, seq_lens, query_start_loc, + query_start_loc_cpu, + common_attn_metadata.seq_lens_cpu_upper_bound, ) # Per-layer-type tile-scheduler plan holders. Empty FlashMLASchedMeta @@ -363,6 +385,8 @@ def build_tile_scheduler( } if num_decode_tokens == 0 or current_platform.is_rocm(): return out + if is_triton_sparse_mla_enabled(self.device): + return out for layer_type in self._layer_types: # get_mla_metadata() is the official FlashMLA entry point that # returns a fresh empty FlashMLASchedMeta; using it keeps this @@ -377,6 +401,8 @@ def _build_deepseek_v4_metadata( num_prefills: int, seq_lens: torch.Tensor, query_start_loc: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + seq_lens_cpu_upper_bound: torch.Tensor | None, ) -> dict[str, torch.Tensor | None]: """Pre-compute DeepseekV4 prefill metadata during the metadata build phase. 
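(Editorial example, not part of the patch: the CPU-side gather lengths added in the next hunk follow ``gather_len = query_len + min(prefix_len, window_size - 1)``; a small check with hypothetical numbers:)

import torch

window_size = 128                                  # hypothetical SWA window
seq_lens_cpu = torch.tensor([500, 60])
query_lens_cpu = torch.tensor([20, 60])
prefix_lens_cpu = seq_lens_cpu - query_lens_cpu    # -> [480, 0]
gather_lens_cpu = query_lens_cpu + torch.minimum(
    prefix_lens_cpu, torch.full_like(prefix_lens_cpu, window_size - 1)
)                                                  # -> [147, 60]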
@@ -403,8 +429,27 @@ def _build_deepseek_v4_metadata( BLOCK_SIZE=triton.next_power_of_2(num_prefills), ) + assert seq_lens_cpu_upper_bound is not None + seq_lens_cpu = seq_lens_cpu_upper_bound + prefill_seq_lens_cpu = seq_lens_cpu[ + num_decodes : num_decodes + num_prefills + ] + query_lens_cpu = ( + query_start_loc_cpu[ + num_decodes + 1 : num_decodes + num_prefills + 1 + ] + - query_start_loc_cpu[num_decodes : num_decodes + num_prefills] + ) + prefix_lens_cpu = prefill_seq_lens_cpu - query_lens_cpu + prefill_gather_lens_cpu = query_lens_cpu + torch.minimum( + prefix_lens_cpu, + torch.full_like(prefix_lens_cpu, self.window_size - 1), + ) + result["prefill_seq_lens"] = seq_lens[num_decodes:] result["prefill_gather_lens"] = pfx_gather_lens + result["prefill_seq_lens_cpu"] = prefill_seq_lens_cpu + result["prefill_gather_lens_cpu"] = prefill_gather_lens_cpu return result diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py index 959a79f292a5..18b8d2aec7c0 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -6,6 +6,7 @@ compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, quantize_and_insert_k_cache, + sparse_prefill_combined_topk_size, ) from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant @@ -20,4 +21,5 @@ "fused_inv_rope_fp8_quant", "fused_q_kv_rmsnorm", "quantize_and_insert_k_cache", + "sparse_prefill_combined_topk_size", ] diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 69d20c107e11..a85cfb01c389 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -355,6 +355,8 @@ def compute_global_topk_indices_and_lens( block_table: torch.Tensor, block_size: int, is_valid_token: torch.Tensor, + global_topk_indices: torch.Tensor | None = None, + topk_lens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Map local topk indices to global KV cache slots and count valid entries. @@ -364,8 +366,20 @@ def compute_global_topk_indices_and_lens( 3. 
Masking padding tokens to length 0 """ num_tokens = topk_indices.shape[0] - global_topk_indices = torch.empty_like(topk_indices) - topk_lens = torch.empty(num_tokens, dtype=torch.int32, device=topk_indices.device) + if global_topk_indices is None: + global_topk_indices = torch.empty_like(topk_indices) + else: + assert global_topk_indices.shape == topk_indices.shape + assert global_topk_indices.dtype == topk_indices.dtype + assert global_topk_indices.device == topk_indices.device + if topk_lens is None: + topk_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + else: + assert topk_lens.shape == (num_tokens,) + assert topk_lens.dtype == torch.int32 + assert topk_lens.device == topk_indices.device _compute_global_topk_indices_and_lens_kernel[(num_tokens,)]( global_topk_indices, global_topk_indices.stride(0), @@ -412,7 +426,7 @@ def _compute_global_topk_indices_and_lens_kernel( mask=mask, other=-1, ) - is_valid = local_idx >= 0 + is_valid = (local_idx >= 0) & is_valid_token block_indices = local_idx // block_size block_numbers = tl.load( @@ -442,6 +456,14 @@ def _compute_global_topk_indices_and_lens_kernel( _SPARSE_PREFILL_TOPK_ALIGNMENT = 128 +def sparse_prefill_combined_topk_size(topk: int, window_size: int) -> int: + return ( + (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1) + // _SPARSE_PREFILL_TOPK_ALIGNMENT + * _SPARSE_PREFILL_TOPK_ALIGNMENT + ) + + def combine_topk_swa_indices( topk_indices: torch.Tensor, query_start_loc: torch.Tensor, @@ -452,23 +474,35 @@ def combine_topk_swa_indices( topk: int, M: int, N: int, + combined_indices: torch.Tensor | None = None, + combined_lens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: num_tokens = topk_indices.shape[0] num_reqs = seq_lens.shape[0] - combined_topk = ( - (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1) - // _SPARSE_PREFILL_TOPK_ALIGNMENT - * _SPARSE_PREFILL_TOPK_ALIGNMENT - ) - combined_indices = torch.full( - (num_tokens, combined_topk), - fill_value=-1, - dtype=torch.int32, - device=topk_indices.device, - ) - combined_lens = torch.empty( - num_tokens, dtype=torch.int32, device=topk_indices.device - ) + combined_topk = sparse_prefill_combined_topk_size(topk, window_size) + if combined_indices is None: + combined_indices = torch.full( + (num_tokens, combined_topk), + fill_value=-1, + dtype=torch.int32, + device=topk_indices.device, + ) + else: + assert combined_indices.shape[0] >= num_tokens + assert combined_indices.shape[1] >= combined_topk + assert combined_indices.dtype == torch.int32 + assert combined_indices.device == topk_indices.device + combined_indices = combined_indices[:num_tokens, :combined_topk] + combined_indices.fill_(-1) + if combined_lens is None: + combined_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + else: + assert combined_lens.shape[0] >= num_tokens + assert combined_lens.dtype == torch.int32 + assert combined_lens.device == topk_indices.device + combined_lens = combined_lens[:num_tokens] NUM_WORKERS = 128 _combine_topk_swa_indices_kernel[(num_reqs, NUM_WORKERS)]( diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py new file mode 100644 index 000000000000..a9f52e767b20 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py @@ -0,0 +1,297 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SM12x Triton FP8 einsum kernels for DeepSeek V4.""" + +import torch 
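(Editorial note on the preceding cache_utils.py hunk, not part of the patch: ``sparse_prefill_combined_topk_size`` rounds ``topk + window_size`` up to the 128-entry alignment, so a preallocated ``combined_indices`` buffer can be reused across metadata builds; two hypothetical values:)

from vllm.v1.attention.ops.deepseek_v4_ops import sparse_prefill_combined_topk_size

assert sparse_prefill_combined_topk_size(2048, 512) == 2560   # already aligned
assert sparse_prefill_combined_topk_size(2048, 127) == 2176   # ceil(2175 / 128) * 128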
+ +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton +from vllm.utils.deep_gemm import fp8_einsum +from vllm.utils.torch_utils import direct_register_custom_op + + +def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: + exp_bits = scale.view(torch.uint8).to(torch.int32) + fp32_bits = exp_bits << 23 + return fp32_bits.view(torch.float32) + + +@triton.jit +def _deepseek_v4_sm12x_fp8_einsum_kernel( + a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + out_ptr, + num_tokens: tl.constexpr, + num_groups: tl.constexpr, + out_rank: tl.constexpr, + hidden_size: tl.constexpr, + a_stride_token: tl.constexpr, + a_stride_group: tl.constexpr, + a_stride_hidden: tl.constexpr, + a_scale_stride_token: tl.constexpr, + a_scale_stride_group: tl.constexpr, + a_scale_stride_hidden: tl.constexpr, + b_stride_group: tl.constexpr, + b_stride_out: tl.constexpr, + b_stride_hidden: tl.constexpr, + b_scale_stride_group: tl.constexpr, + b_scale_stride_out: tl.constexpr, + b_scale_stride_hidden: tl.constexpr, + out_stride_token: tl.constexpr, + out_stride_group: tl.constexpr, + out_stride_rank: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, + BLOCK_OUT: tl.constexpr, + BLOCK_HIDDEN: tl.constexpr, +) -> None: + token_block = tl.program_id(0) + out_block = tl.program_id(1) + group = tl.program_id(2) + + token_offsets = token_block * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) + out_offsets = out_block * BLOCK_OUT + tl.arange(0, BLOCK_OUT) + hidden_offsets = tl.arange(0, BLOCK_HIDDEN) + accum = tl.zeros((BLOCK_TOKENS, BLOCK_OUT), dtype=tl.float32) + + for hidden_start in range(0, hidden_size, BLOCK_HIDDEN): + hidden = hidden_start + hidden_offsets + a = tl.load( + a_ptr + + token_offsets[:, None] * a_stride_token + + group * a_stride_group + + hidden[None, :] * a_stride_hidden, + mask=(token_offsets[:, None] < num_tokens) + & (hidden[None, :] < hidden_size), + other=0.0, + ) + b = tl.load( + b_ptr + + group * b_stride_group + + out_offsets[None, :] * b_stride_out + + hidden[:, None] * b_stride_hidden, + mask=(out_offsets[None, :] < out_rank) & (hidden[:, None] < hidden_size), + other=0.0, + ) + raw = tl.dot(a, b, out_dtype=tl.float32) + hidden_scale_block = hidden_start // BLOCK_HIDDEN + a_scale = tl.load( + a_scale_ptr + + token_offsets * a_scale_stride_token + + group * a_scale_stride_group + + hidden_scale_block * a_scale_stride_hidden, + mask=token_offsets < num_tokens, + other=0.0, + ) + b_scale = tl.load( + b_scale_ptr + + group * b_scale_stride_group + + (out_offsets // 128) * b_scale_stride_out + + hidden_scale_block * b_scale_stride_hidden, + mask=out_offsets < out_rank, + other=0.0, + ) + accum += raw * a_scale[:, None] * b_scale[None, :] + + tl.store( + out_ptr + + token_offsets[:, None] * out_stride_token + + group * out_stride_group + + out_offsets[None, :] * out_stride_rank, + accum, + mask=(token_offsets[:, None] < num_tokens) & (out_offsets[None, :] < out_rank), + ) + + +def deepseek_v4_sm12x_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + """Compute ``bhr,hdr->bhd`` with FP32 block scales on SM12x. + + ``a`` is the transposed output of ``fused_inv_rope_fp8_quant`` with shape + ``[tokens, groups, hidden]``. ``b`` is ``wo_a`` reshaped to + ``[groups, out_rank, hidden]``. 
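(Editorial reference, not part of the patch: assuming ``a_scale`` is ``[tokens, groups, hidden // 128]`` and ``b_scale`` has already been viewed to ``[groups, out_rank // 128, hidden // 128]`` in fp32, the kernel below should match this plain fp32 einsum:)

a_f32 = a.float() * a_scale.repeat_interleave(128, dim=2)
b_f32 = b.float() * b_scale.repeat_interleave(128, dim=1).repeat_interleave(128, dim=2)
ref = torch.einsum("bhr,hdr->bhd", a_f32, b_f32)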
+ """ + num_tokens, num_groups, hidden_size = a.shape + b_groups, out_rank, b_hidden_size = b.shape + assert b_groups == num_groups + assert b_hidden_size == hidden_size + assert out.shape == (num_tokens, num_groups, out_rank) + assert hidden_size % 128 == 0 + assert out_rank % 128 == 0 + assert a.dtype == torch.float8_e4m3fn + assert b.dtype == torch.float8_e4m3fn + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if a_scale.dtype == e8m0_dtype: + a_scale = _upcast_e8m0_to_fp32(a_scale) + if b_scale.dtype == e8m0_dtype: + b_scale = _upcast_e8m0_to_fp32(b_scale) + assert a_scale.dtype == torch.float32 + assert b_scale.dtype == torch.float32 + + if num_tokens == 0: + return + + block_tokens = 16 + block_out = 128 + block_hidden = 128 + grid = ( + triton.cdiv(num_tokens, block_tokens), + triton.cdiv(out_rank, block_out), + num_groups, + ) + _deepseek_v4_sm12x_fp8_einsum_kernel[grid]( + a, + a_scale, + b, + b_scale, + out, + num_tokens, + num_groups, + out_rank, + hidden_size, + a.stride(0), + a.stride(1), + a.stride(2), + a_scale.stride(0), + a_scale.stride(1), + a_scale.stride(2), + b.stride(0), + b.stride(1), + b.stride(2), + b_scale.stride(0), + b_scale.stride(1), + b_scale.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_TOKENS=block_tokens, + BLOCK_OUT=block_out, + BLOCK_HIDDEN=block_hidden, + num_warps=4, + num_stages=3, + ) + + +def deepseek_v4_fp8_einsum_config( + capability_major: int, +) -> tuple[tuple[int, int, int], bool]: + if capability_major == 10: + return (1, 1, 128), True + return (1, 128, 128), False + + +def _use_deepseek_v4_sm12x_triton_fp8_einsum( + equation: str, + recipe: list[int], + b_scale: torch.Tensor, +) -> bool: + capability = current_platform.get_device_capability() + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + return ( + capability is not None + and capability.major == 12 + and equation == "bhr,hdr->bhd" + and tuple(recipe) == (1, 128, 128) + and b_scale.dtype in (torch.float32, e8m0_dtype) + ) + + +def deepseek_v4_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: list[int], +) -> None: + if equation == "bhr,hdr->bhd" and b.dim() == 2: + num_groups = out.shape[1] + out_rank = out.shape[2] + hidden_size = a.shape[2] + if b.shape[0] % out_rank != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum weight rows must be divisible by " + f"out_rank={out_rank}, got {b.shape[0]}" + ) + b_groups = b.shape[0] // out_rank + group_start = 0 + if b_groups != num_groups: + if b_groups % num_groups != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum weight groups must match the " + "TP-local output groups or be an integer multiple of " + f"them, got weight_groups={b_groups}, " + f"output_groups={num_groups}" + ) + group_partitions = b_groups // num_groups + group_start = ( + get_tensor_model_parallel_rank() % group_partitions + ) * num_groups + b = b.view(b_groups, out_rank, hidden_size) + if group_start != 0 or b_groups != num_groups: + b = b.narrow(0, group_start, num_groups) + + if b_scale.dim() == 2: + scale_mn = recipe[1] + scale_k_pack = 4 if b_scale.dtype == torch.int32 else 1 + scale_k = recipe[2] * scale_k_pack + scale_out_blocks = (out_rank + scale_mn - 1) // scale_mn + scale_hidden_blocks = (hidden_size + scale_k - 1) // scale_k + if b_scale.shape[0] % scale_out_blocks != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum scale rows must be divisible by " + f"scale_out_blocks={scale_out_blocks}, " + f"got {b_scale.shape[0]}" + ) + 
scale_groups = b_scale.shape[0] // scale_out_blocks + if scale_groups not in (num_groups, b_groups): + raise RuntimeError( + "DeepSeek V4 fp8 einsum scale groups must match the " + "TP-local output groups or weight groups, got " + f"scale_groups={scale_groups}, output_groups={num_groups}, " + f"weight_groups={b_groups}" + ) + b_scale = b_scale.view( + scale_groups, + scale_out_blocks, + scale_hidden_blocks, + ) + if scale_groups == b_groups and scale_groups != num_groups: + b_scale = b_scale.narrow(0, group_start, num_groups) + elif b_scale.dim() == 3 and b_scale.shape[0] == b_groups: + if b_groups != num_groups: + b_scale = b_scale.narrow(0, group_start, num_groups) + + if _use_deepseek_v4_sm12x_triton_fp8_einsum(equation, recipe, b_scale): + deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, out) + return + + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) + + +def deepseek_v4_fp8_einsum_fake( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: list[int], +) -> None: + return None + + +direct_register_custom_op( + op_name="deepseek_v4_fp8_einsum", + op_func=deepseek_v4_fp8_einsum, + mutates_args=["out"], + fake_impl=deepseek_v4_fp8_einsum_fake, +) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/sm12x_deep_gemm_fallbacks.py b/vllm/v1/attention/ops/deepseek_v4_ops/sm12x_deep_gemm_fallbacks.py new file mode 100644 index 000000000000..676db9e20466 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/sm12x_deep_gemm_fallbacks.py @@ -0,0 +1,508 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SM12x fallback implementations for DeepGEMM-only interfaces.""" + +import torch + +from vllm.platforms import current_platform + +_SM120_MQA_LOGITS_MAX_SCORE_BYTES = 64 * 1024 * 1024 +_SM120_PAGED_MQA_TOPK_CHUNK_SIZE = 8192 + + +def _fp8_mqa_logits_head_chunk_size( + seq_len: int, + seq_len_kv: int, + num_heads: int, +) -> int: + # The SM120 torch path is used on long prefill paths where materializing + # [head_chunk, M, N] scores can otherwise allocate multiple GiB. Keep the + # transient score tensor bounded, while still using larger head chunks for + # short prompts where they are faster. 
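+    # e.g. an 8192x8192 prefill needs 8192*8192*4 B = 256 MiB of fp32 scores
+    # per head, so max_heads computes to 0 and the chunk clamps to one head,
+    # while a 512x4096 prompt (8 MiB per head) keeps the full 8-head chunk.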
+ score_elems_per_head = max(1, seq_len * seq_len_kv) + max_heads = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_head * 4) + return max(1, min(8, num_heads, max_heads)) + + +def _fp8_mqa_logits_k_chunk_size( + seq_len: int, + seq_len_kv: int, + head_chunk_size: int, +) -> int: + score_elems_per_key = max(1, seq_len * head_chunk_size) + max_keys = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_key * 4) + return max(1, min(seq_len_kv, max_keys)) + + +def _fp8_mqa_logits_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA logits torch path only supports FP8 Q") + + k_values, k_scales = kv + k_f32 = k_values.to(torch.float32) + k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32)) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_values.shape + seq_len_kv = k_f32.shape[0] + logits = torch.zeros( + (seq_len, seq_len_kv), device=q_values.device, dtype=torch.float32 + ) + head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads) + + for head_start in range(0, num_heads, head_chunk_size): + head_end = min(head_start + head_chunk_size, num_heads) + q_chunk = q_values[:, head_start:head_end, :].to(torch.float32) + q_chunk = q_chunk.transpose(0, 1).contiguous() + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + k_chunk_size = _fp8_mqa_logits_k_chunk_size( + seq_len, seq_len_kv, head_end - head_start + ) + for k_start in range(0, seq_len_kv, k_chunk_size): + k_end = min(k_start + k_chunk_size, seq_len_kv) + scores = torch.matmul(q_chunk, k_t[:, k_start:k_end]) + scores.relu_() + scores.mul_(head_weights) + logits[:, k_start:k_end].add_( + scores[0] if scores.shape[0] == 1 else scores.sum(dim=0) + ) + + if clean_logits: + offsets = torch.arange(seq_len_kv, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + logits = logits.masked_fill(~valid, float("-inf")) + + return logits + + +def _fp8_mqa_logits_topk_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_tokens: int, + out: torch.Tensor | None = None, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA top-k torch path only supports FP8 Q") + + k_values, k_scales = kv + k_f32 = k_values.to(torch.float32) + k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32)) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_values.shape + seq_len_kv = k_f32.shape[0] + if out is None: + out = torch.empty( + (seq_len, topk_tokens), device=q_values.device, dtype=torch.int32 + ) + else: + assert out.shape == (seq_len, topk_tokens) + assert out.dtype == torch.int32 + out.fill_(-1) + + best_values = torch.full( + (seq_len, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads) + k_chunk_size = _fp8_mqa_logits_k_chunk_size(seq_len, seq_len_kv, head_chunk_size) + max_chunk_topk = min(topk_tokens, k_chunk_size) + chunk_values_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + 
chunk_indices_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (seq_len, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + for k_start in range(0, seq_len_kv, k_chunk_size): + k_end = min(k_start + k_chunk_size, seq_len_kv) + chunk_logits = torch.zeros( + (seq_len, k_end - k_start), + device=q_values.device, + dtype=torch.float32, + ) + for head_start in range(0, num_heads, head_chunk_size): + head_end = min(head_start + head_chunk_size, num_heads) + q_chunk = q_values[:, head_start:head_end, :].to(torch.float32) + q_chunk = q_chunk.transpose(0, 1).contiguous() + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + scores = torch.matmul(q_chunk, k_t[:, k_start:k_end]) + scores.relu_() + scores.mul_(head_weights) + chunk_logits.add_(scores[0] if scores.shape[0] == 1 else scores.sum(dim=0)) + + offsets = torch.arange(k_start, k_end, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + chunk_logits.masked_fill_(~valid, float("-inf")) + + chunk_topk = min(topk_tokens, k_end - k_start) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(k_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(out) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=out) + best_values, next_best_values = next_best_values, best_values + out.masked_fill_(~torch.isfinite(best_values), -1) + + return out + + +def fp8_fp4_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 MQA top-k indices without materializing full logits.""" + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + _fp8_mqa_logits_topk_torch( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices.shape[1], + out=topk_indices, + ) + return True + + +def _fp8_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q 
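+    # Dispatch to the fused Triton kernel when in-kernel masking applies
+    # (clean_logits), Q carries no separate scale, and the layouts match;
+    # otherwise fall back to the chunked torch implementation above.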
+ if clean_logits and q_scale is None and q_values.dim() == 3 and kv[0].dim() == 2: + from vllm.v1.attention.ops.deepseek_v4_ops.sm12x_mqa import ( + fp8_mqa_logits_triton, + ) + + return fp8_mqa_logits_triton(q_values, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + return _fp8_mqa_logits_torch( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) + + +def _fp8_paged_mqa_logits_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 paged MQA torch path only supports FP8 Q") + + batch_size, next_n, num_heads, head_dim = q_values.shape + head_dim_with_scale = kv_cache.shape[-1] + assert head_dim_with_scale > head_dim + assert weights.shape == (batch_size * next_n, num_heads) + assert context_lens.shape == (batch_size, next_n) + + from vllm.v1.attention.ops.deepseek_v4_ops.sm12x_mqa import ( + _view_packed_fp8_paged_mqa_kv_cache, + ) + + kv_values, kv_scales = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_kv, _, _ = kv_values.shape + logits = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + + q_f32 = q_values.float() + score_bytes = _SM120_MQA_LOGITS_MAX_SCORE_BYTES + max_tokens_per_chunk = max(1, score_bytes // max(1, num_heads * 4)) + token_offsets_cache: dict[int, torch.Tensor] = {} + + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + context_len = int(context_lens[batch_idx, next_idx].item()) + if context_len <= 0: + continue + + q_row = q_f32[batch_idx, next_idx] + row_weights = weights[row] + for token_start in range(0, context_len, max_tokens_per_chunk): + token_end = min(context_len, token_start + max_tokens_per_chunk) + chunk_len = token_end - token_start + token_offsets = token_offsets_cache.get(chunk_len) + if token_offsets is None or token_offsets.device != q_values.device: + token_offsets = torch.arange( + chunk_len, device=q_values.device, dtype=torch.long + ) + token_offsets_cache[chunk_len] = token_offsets + token_ids = token_start + token_offsets + logical_blocks = token_ids // block_kv + token_in_block = token_ids - logical_blocks * block_kv + physical_blocks = block_tables[batch_idx, logical_blocks] + kv_chunk = kv_values[physical_blocks, token_in_block, 0].float() + scale_chunk = kv_scales[physical_blocks, token_in_block, 0].squeeze(-1) + kv_chunk.mul_(scale_chunk[:, None]) + scores = torch.matmul(q_row, kv_chunk.T) + scores.relu_() + scores.mul_(row_weights[:, None]) + logits[row, token_start:token_end] = scores.sum(dim=0) + + return logits + + +def _fp8_paged_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + q_values, q_scale = q + if ( + q_scale is None + and q_values.dim() == 4 + and kv_cache.dtype == torch.uint8 + and kv_cache.shape[-1] == q_values.shape[-1] + 4 + ): + from vllm.v1.attention.ops.deepseek_v4_ops.sm12x_mqa import ( + fp8_paged_mqa_logits_triton, + ) + + return fp8_paged_mqa_logits_triton( + q_values, kv_cache, weights, context_lens, block_tables, max_model_len + ) + return _fp8_paged_mqa_logits_torch( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) + + +def 
fp8_fp4_paged_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 paged MQA top-k indices without full logits.""" + q_values, q_scale = q + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q_scale is None + and q_values.dim() == 4 + and kv_cache.dtype == torch.uint8 + and kv_cache.shape[-1] == q_values.shape[-1] + 4 + ): + return False + + num_rows = q_values.shape[0] * q_values.shape[1] + topk_tokens = topk_indices.shape[1] + assert topk_indices.shape == (num_rows, topk_tokens) + assert topk_indices.dtype == torch.int32 + topk_indices.fill_(-1) + if num_rows == 0 or topk_tokens == 0 or max_model_len == 0: + return True + + best_values = torch.full( + (num_rows, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + chunk_size = max(1, _SM120_PAGED_MQA_TOPK_CHUNK_SIZE) + max_chunk_topk = min(topk_tokens, chunk_size) + chunk_values_buf = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + chunk_indices_buf = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (num_rows, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (num_rows, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (num_rows, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + from vllm.v1.attention.ops.deepseek_v4_ops.sm12x_mqa import ( + fp8_paged_mqa_logits_triton, + ) + + for token_start in range(0, max_model_len, chunk_size): + token_count = min(chunk_size, max_model_len - token_start) + chunk_logits = fp8_paged_mqa_logits_triton( + q_values, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + token_start=token_start, + token_count=token_count, + ) + chunk_topk = min(topk_tokens, token_count) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(token_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(topk_indices) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=topk_indices) + best_values, next_best_values = next_best_values, best_values + topk_indices.masked_fill_(~torch.isfinite(best_values), -1) + + return True + + +def _tf32_hc_prenorm_gemm_torch( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + 
num_split: int, +) -> torch.Tensor: + """Portable SM12x HyperConnection prenorm GEMM fallback. + + DeepGEMM's split ABI only requires that downstream consumers recover the + full result by summing over the split dimension. Keep the implementation + simple by writing the full product to split zero and clearing the rest. + """ + del num_split + product = x.float() @ fn.float().T + norm = x.float().square().sum(dim=-1) + + if out.dim() == 3: + out.zero_() + sqrsum.zero_() + out[0].copy_(product) + sqrsum[0].copy_(norm) + else: + out.copy_(product) + sqrsum.copy_(norm) + return out + + +def _tf32_hc_prenorm_gemm_sm12x( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + if out.dim() == 3 and sqrsum.dim() == 2: + from vllm.v1.attention.ops.deepseek_v4_ops.sm12x_mqa import ( + tf32_hc_prenorm_gemm_triton, + ) + + tf32_hc_prenorm_gemm_triton(x, fn, out, sqrsum, num_split) + return out + + return _tf32_hc_prenorm_gemm_torch(x, fn, out, sqrsum, num_split) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/sm12x_mqa.py b/vllm/v1/attention/ops/deepseek_v4_ops/sm12x_mqa.py new file mode 100644 index 000000000000..85dab1f6d9a1 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/sm12x_mqa.py @@ -0,0 +1,481 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Triton fallback kernels used by the local DeepSeek V4 path.""" + +import torch + +from vllm.triton_utils import tl, triton + + +def _view_packed_fp8_paged_mqa_kv_cache( + kv_cache: torch.Tensor, + head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Return FP8 values and fp32 scales from indexer cache block storage.""" + if kv_cache.dtype != torch.uint8: + raise TypeError(f"Expected uint8 kv_cache, got {kv_cache.dtype}") + if kv_cache.dim() == 3: + num_blocks, block_size, head_dim_with_scale = kv_cache.shape + num_kv_heads = 1 + elif kv_cache.dim() == 4: + num_blocks, block_size, num_kv_heads, head_dim_with_scale = kv_cache.shape + else: + raise ValueError( + f"Expected 3D or 4D kv_cache, got {kv_cache.dim()} dimensions" + ) + if num_kv_heads != 1: + raise ValueError(f"Expected one KV head, got {num_kv_heads}") + + scale_bytes = head_dim_with_scale - head_dim + if scale_bytes <= 0 or scale_bytes % torch.float32.itemsize != 0: + raise ValueError( + "Expected kv_cache last dimension to contain FP8 values followed " + f"by fp32 scale bytes; got head_dim={head_dim}, " + f"last_dim={head_dim_with_scale}" + ) + + block_stride = kv_cache.stride(0) + base_storage_offset = kv_cache.storage_offset() + scale_elems = scale_bytes // torch.float32.itemsize + kv_values = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, head_dim), + stride=(block_stride, head_dim, head_dim, 1), + storage_offset=base_storage_offset, + ).view(torch.float8_e4m3fn) + kv_scale = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, scale_bytes), + stride=(block_stride, scale_bytes, scale_bytes, 1), + storage_offset=base_storage_offset + block_size * head_dim, + ).view(torch.float32) + return kv_values, kv_scale[..., :scale_elems] + + +@triton.jit +def _fp8_mqa_logits_kernel( + q_ptr, + k_ptr, + scale_ptr, + weights_ptr, + cu_seqlen_ks_ptr, + cu_seqlen_ke_ptr, + logits_ptr, + num_q: tl.constexpr, + seq_len_kv: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kn: tl.constexpr, + stride_kd: tl.constexpr, + 
stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + valid_m = offs_m < num_q + valid_n = offs_n < seq_len_kv + seq_start = tl.load(cu_seqlen_ks_ptr + offs_m, mask=valid_m, other=0) + seq_end = tl.load(cu_seqlen_ke_ptr + offs_m, mask=valid_m, other=0) + seq_mask = (offs_n[None, :] >= seq_start[:, None]) & ( + offs_n[None, :] < seq_end[:, None] + ) + + logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for h in tl.range(0, num_heads): + scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + offs_m[:, None] * stride_qm + + h * stride_qh + + d[None, :] * stride_qd, + mask=valid_m[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + k_ptr + offs_n[:, None] * stride_kn + d[None, :] * stride_kd, + mask=valid_n[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.dot(q, tl.trans(k), input_precision="tf32") + scale = tl.load(scale_ptr + offs_n, mask=valid_n, other=0.0) + weighted = tl.maximum(scores * scale[None, :], 0.0) + weight = tl.load( + weights_ptr + offs_m * stride_wm + h * stride_wh, + mask=valid_m, + other=0.0, + ) + logits += weighted * weight[:, None] + + store_mask = valid_m[:, None] & valid_n[None, :] + logits = tl.where(seq_mask & store_mask, logits, float("-inf")) + tl.store( + logits_ptr + offs_m[:, None] * stride_lm + offs_n[None, :] * stride_ln, + logits, + mask=store_mask, + ) + + +def fp8_mqa_logits_triton( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + k_fp8, scale = kv + num_q, num_heads, head_dim = q.shape + seq_len_kv = k_fp8.shape[0] + logits = torch.empty( + (num_q, seq_len_kv), + device=q.device, + dtype=torch.float32, + ) + if num_q == 0 or seq_len_kv == 0: + return logits + + grid = (triton.cdiv(num_q, 8), triton.cdiv(seq_len_kv, 64)) + _fp8_mqa_logits_kernel[grid]( + q, + k_fp8, + scale, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + logits, + num_q, + seq_len_kv, + num_heads, + head_dim, + q.stride(0), + q.stride(1), + q.stride(2), + k_fp8.stride(0), + k_fp8.stride(1), + weights.stride(0), + weights.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_M=8, + BLOCK_N=64, + BLOCK_D=64, + num_warps=4, + ) + return logits + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + q_ptr, + kv_ptr, + scale_ptr, + weights_ptr, + context_lens_ptr, + block_tables_ptr, + logits_ptr, + token_start, + num_rows: tl.constexpr, + logits_width: tl.constexpr, + next_n: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + block_size: tl.constexpr, + stride_qb: tl.constexpr, + stride_qn: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kvb: tl.constexpr, + stride_kvs: tl.constexpr, + stride_kvd: tl.constexpr, + stride_sb: tl.constexpr, + stride_ss: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_clb: tl.constexpr, + stride_cln: tl.constexpr, + stride_btb: tl.constexpr, + stride_btk: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: 
tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = token_start + offs_local_n + offs_d = tl.arange(0, BLOCK_D) + + valid_m = offs_m < num_rows + valid_n = offs_local_n < logits_width + batch = offs_m // next_n + q_pos = offs_m - batch * next_n + context_len = tl.load( + context_lens_ptr + batch * stride_clb + q_pos * stride_cln, + mask=valid_m, + other=0, + ) + context_mask = valid_n[None, :] & (offs_n[None, :] < context_len[:, None]) + + block_rank = offs_n // block_size + block_offset = offs_n - block_rank * block_size + block_idx = tl.load( + block_tables_ptr + + batch[:, None] * stride_btb + + block_rank[None, :] * stride_btk, + mask=valid_m[:, None] & valid_n[None, :], + other=0, + ) + + logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + scale = tl.load( + scale_ptr + block_idx * stride_sb + block_offset[None, :] * stride_ss, + mask=context_mask, + other=0.0, + ) + for h in tl.range(0, num_heads): + scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + batch[:, None] * stride_qb + + q_pos[:, None] * stride_qn + + h * stride_qh + + d[None, :] * stride_qd, + mask=valid_m[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + kv_ptr + + block_idx[:, :, None] * stride_kvb + + block_offset[None, :, None] * stride_kvs + + d[None, None, :] * stride_kvd, + mask=context_mask[:, :, None] & (d[None, None, :] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.sum(q[:, None, :] * k, axis=2) + weighted = tl.maximum(scores * scale, 0.0) + weight = tl.load( + weights_ptr + offs_m * stride_wm + h * stride_wh, + mask=valid_m, + other=0.0, + ) + logits += weighted * weight[:, None] + + store_mask = valid_m[:, None] & valid_n[None, :] + logits = tl.where(context_mask & store_mask, logits, float("-inf")) + tl.store( + logits_ptr + offs_m[:, None] * stride_lm + offs_local_n[None, :] * stride_ln, + logits, + mask=store_mask, + ) + + +def fp8_paged_mqa_logits_triton( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + token_start: int = 0, + token_count: int | None = None, +) -> torch.Tensor: + batch_size, next_n, num_heads, head_dim = q.size() + kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_size, _, _ = kv_values.size() + num_rows = batch_size * next_n + if token_count is None: + token_count = max_model_len - token_start + assert token_start >= 0 + assert token_count >= 0 + assert token_start + token_count <= max_model_len + logits = torch.empty( + (num_rows, token_count), + device=q.device, + dtype=torch.float32, + ) + if num_rows == 0 or token_count == 0: + return logits + + context_lens_2d = context_lens.reshape(batch_size, -1) + if context_lens_2d.shape[1] == 1 and next_n != 1: + context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous() + grid = (triton.cdiv(num_rows, 4), triton.cdiv(token_count, 64)) + _fp8_paged_mqa_logits_kernel[grid]( + q, + kv_values, + kv_scale, + weights, + context_lens_2d, + block_tables, + logits, + token_start, + num_rows, + token_count, + next_n, + num_heads, + head_dim, + block_size, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + kv_values.stride(0), + kv_values.stride(1), + kv_values.stride(3), + kv_scale.stride(0), + 
kv_scale.stride(1), + weights.stride(0), + weights.stride(1), + context_lens_2d.stride(0), + context_lens_2d.stride(1), + block_tables.stride(0), + block_tables.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_M=4, + BLOCK_N=64, + BLOCK_D=64, + num_warps=4, + ) + return logits + + +@triton.jit +def _tf32_hc_prenorm_gemm_kernel( + x_ptr, + fn_ptr, + out_ptr, + sqrsum_ptr, + M: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + stride_xm: tl.constexpr, + stride_xk: tl.constexpr, + stride_fnn: tl.constexpr, + stride_fnk: tl.constexpr, + stride_outs: tl.constexpr, + stride_outm: tl.constexpr, + stride_outn: tl.constexpr, + stride_sqs: tl.constexpr, + stride_sqm: tl.constexpr, + NUM_SPLIT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pid_s = tl.program_id(2) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + split_k = tl.cdiv(K, NUM_SPLIT) + split_begin = pid_s * split_k + split_end = tl.minimum(split_begin + split_k, K) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + sq = tl.zeros((BLOCK_M,), dtype=tl.float32) + + for k0 in tl.range(0, split_k, BLOCK_K): + k = split_begin + k0 + offs_k + k_mask = k < split_end + x = tl.load( + x_ptr + offs_m[:, None] * stride_xm + k[None, :] * stride_xk, + mask=(offs_m[:, None] < M) & k_mask[None, :], + other=0.0, + ).to(tl.float32) + fn = tl.load( + fn_ptr + offs_n[None, :] * stride_fnn + k[:, None] * stride_fnk, + mask=(offs_n[None, :] < N) & k_mask[:, None], + other=0.0, + ).to(tl.float32) + + acc += tl.dot(x, fn, input_precision="tf32", out_dtype=tl.float32) + sq += tl.sum(x * x, axis=1) + + tl.store( + out_ptr + + pid_s * stride_outs + + offs_m[:, None] * stride_outm + + offs_n[None, :] * stride_outn, + acc, + mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), + ) + + if pid_n == 0: + tl.store( + sqrsum_ptr + pid_s * stride_sqs + offs_m * stride_sqm, + sq, + mask=offs_m < M, + ) + + +def tf32_hc_prenorm_gemm_triton( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> None: + assert x.dim() == 2 + assert fn.dim() == 2 + assert out.dim() == 3 + assert sqrsum.dim() == 2 + + m, k = x.shape + n = fn.shape[0] + assert fn.shape[1] == k + assert out.shape == (num_split, m, n) + assert sqrsum.shape == (num_split, m) + + if m == 0: + return + + block_m = 16 + block_n = triton.next_power_of_2(n) + block_n = min(max(block_n, 16), 32) + block_k = 64 + grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n), num_split) + _tf32_hc_prenorm_gemm_kernel[grid]( + x, + fn, + out, + sqrsum, + m, + k, + n, + x.stride(0), + x.stride(1), + fn.stride(0), + fn.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + sqrsum.stride(0), + sqrsum.stride(1), + num_split, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=4, + ) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 65993e804153..4cb7bd57b8aa 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -250,6 +250,21 @@ def remove_skipped_blocks( for manager in self.single_type_managers: manager.remove_skipped_blocks(request_id, total_computed_tokens) + def release_protected_prompt_blocks( + self, + target_free_blocks: int | None = None, + block_ids_to_skip: set[int] | None = None, + ) -> None: + for manager in self.single_type_managers: + if ( + 
target_free_blocks is not None + and self.block_pool.get_num_free_blocks() >= target_free_blocks + ): + return + manager.release_protected_prompt_blocks( + target_free_blocks, block_ids_to_skip + ) + def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: """ Get the blocks for the request. @@ -475,6 +490,8 @@ def verify_and_split_kv_cache_groups(self) -> None: # block cache hit yet. block_sizes = [spec.block_size for spec, _, _ in attention_groups] self.lcm_block_size = lcm(*block_sizes) + for manager in self.single_type_managers: + manager.cache_alignment_tokens = self.lcm_block_size # Attention-group indices (into ``self.attention_groups``) that # contain at least one EAGLE/MTP KV cache group. diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 431776870cf4..5470029b3e0c 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -331,6 +331,9 @@ def allocate_slots( num_local_computed_tokens + num_external_computed_tokens, self.max_model_len, ) + block_ids_to_skip_releasing = self._block_ids_to_skip_releasing( + new_computed_block_list + ) if full_sequence_must_fit: # First check and fail if the full request sequence won't fit. @@ -345,7 +348,9 @@ def allocate_slots( num_tokens_main_model=full_num_tokens, apply_admission_cap=True, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + if not self._has_enough_free_blocks( + num_blocks_to_allocate, block_ids_to_skip_releasing + ): return None num_tokens_main_model = total_computed_tokens + num_new_tokens @@ -373,7 +378,9 @@ def allocate_slots( num_tokens_main_model=num_tokens_main_model, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + if not self._has_enough_free_blocks( + num_blocks_to_allocate, block_ids_to_skip_releasing + ): # Cannot allocate new blocks return None @@ -446,6 +453,29 @@ def evict_blocks(self, block_ids: set[int]) -> None: """ self.block_pool.evict_blocks(block_ids) + @staticmethod + def _block_ids_to_skip_releasing( + blocks: tuple[Sequence[KVCacheBlock], ...], + ) -> set[int]: + return { + block.block_id + for group_blocks in blocks + for block in group_blocks + if not block.is_null + } + + def _has_enough_free_blocks( + self, + num_blocks: int, + block_ids_to_skip_releasing: set[int] | None = None, + ) -> bool: + if num_blocks <= self.block_pool.get_num_free_blocks(): + return True + self.coordinator.release_protected_prompt_blocks( + num_blocks, block_ids_to_skip_releasing + ) + return num_blocks <= self.block_pool.get_num_free_blocks() + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalidate prefix caching after the weights are updated, @@ -455,6 +485,7 @@ def reset_prefix_cache(self) -> bool: bool: True if the prefix cache is successfully reset, False otherwise. 
""" + self.coordinator.release_protected_prompt_blocks() if not self.block_pool.reset_prefix_cache(): return False if self.log_stats: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e8d3a6f75688..1dcca597574b 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from abc import ABC, abstractmethod -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import Sequence from vllm.utils.math_utils import cdiv @@ -42,6 +42,7 @@ def __init__( dcp_world_size: int = 1, pcp_world_size: int = 1, max_admission_blocks_per_request: int | None = None, + max_model_len: int | None = None, ) -> None: """ Initializes the SingleTypeKVCacheManager. @@ -65,6 +66,8 @@ def __init__( self.block_pool = block_pool self.enable_caching = enable_caching self._max_admission_blocks_per_request = max_admission_blocks_per_request + self.max_model_len = max_model_len + self.cache_alignment_tokens = self.block_size self.new_block_ids: list[int] = [] # Mapping from request ID to blocks to track the blocks allocated @@ -80,6 +83,8 @@ def __init__( self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block + self._protected_prompt_block_ids: set[int] = set() + self._protected_prompt_block_queue: deque[int] = deque() @classmethod def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]): @@ -274,6 +279,79 @@ def take_new_block_ids(self) -> list[int]: self.new_block_ids = [] return ids + def _max_protected_prompt_blocks(self) -> int | None: + if self.max_model_len is None: + return None + return 2 * cdiv(max(1, self.max_model_len), self.block_size) + + def _protect_prompt_blocks(self, blocks: Sequence[KVCacheBlock]) -> None: + if not self.enable_caching: + return + + protected: list[KVCacheBlock] = [] + for block in blocks: + if ( + block.is_null + or block.block_hash is None + or block.block_id in self._protected_prompt_block_ids + ): + continue + protected.append(block) + self._protected_prompt_block_ids.add(block.block_id) + self._protected_prompt_block_queue.append(block.block_id) + + if not protected: + return + + # Keep an extra reference for prompt blocks that must survive after + # their request releases its normal runtime reference. Later request + # reuse increments/decrements the runtime reference as usual. 
+ self.block_pool.touch(protected) + self._trim_protected_prompt_blocks() + + def _trim_protected_prompt_blocks(self) -> None: + max_blocks = self._max_protected_prompt_blocks() + if max_blocks is None: + return + + while len(self._protected_prompt_block_ids) > max_blocks: + if not self._release_one_protected_prompt_block(): + return + + def _release_one_protected_prompt_block( + self, block_ids_to_skip: set[int] | None = None + ) -> bool: + attempts = len(self._protected_prompt_block_queue) + while attempts: + block_id = self._protected_prompt_block_queue.popleft() + attempts -= 1 + if block_id not in self._protected_prompt_block_ids: + continue + if block_ids_to_skip is not None and block_id in block_ids_to_skip: + self._protected_prompt_block_queue.append(block_id) + continue + + self._protected_prompt_block_ids.remove(block_id) + block = self.block_pool.blocks[block_id] + if block.ref_cnt > 0: + self.block_pool.free_blocks([block]) + return True + return False + + def release_protected_prompt_blocks( + self, + target_free_blocks: int | None = None, + block_ids_to_skip: set[int] | None = None, + ) -> None: + while self._protected_prompt_block_ids: + if ( + target_free_blocks is not None + and self.block_pool.get_num_free_blocks() >= target_free_blocks + ): + return + if not self._release_one_protected_prompt_block(block_ids_to_skip): + return + def cache_blocks(self, request: Request, num_tokens: int) -> None: """ Cache the blocks for the request. @@ -504,6 +582,54 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return num_common_blocks +class MLAAttentionManager(FullAttentionManager): + """KV cache manager for DeepSeek V4 compressed MLA cache.""" + + def _should_protect_prompt_blocks(self) -> bool: + return ( + self.kv_cache_spec.model_version == "deepseek_v4" + or self.kv_cache_spec.cache_dtype_str == "fp8_ds_mla" + or self.kv_cache_spec.compress_ratio > 1 + ) + + def cache_blocks(self, request: Request, num_tokens: int) -> None: + super().cache_blocks(request, num_tokens) + if ( + not self._should_protect_prompt_blocks() + or num_tokens < request.num_prompt_tokens + or request.num_prompt_tokens <= 1 + ): + return + + max_cache_hit_length = request.num_prompt_tokens - 1 + aligned_cache_hit_length = ( + max_cache_hit_length + // self.cache_alignment_tokens + * self.cache_alignment_tokens + ) + num_hit_blocks = aligned_cache_hit_length // self.block_size + if num_hit_blocks == 0: + return + + self._protect_prompt_blocks( + self.req_to_blocks[request.request_id][:num_hit_blocks] + ) + + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: + blocks = self.req_to_blocks[running_request_id] + num_common_blocks = 0 + expected_ref_cnt = len(self.req_to_blocks) + for block in blocks: + ref_cnt = block.ref_cnt + if block.block_id in self._protected_prompt_block_ids: + ref_cnt -= 1 + if ref_cnt == expected_ref_cnt: + num_common_blocks += 1 + else: + break + return num_common_blocks + + class SlidingWindowManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) @@ -641,6 +767,42 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return 0 +class SlidingWindowMLAManager(SlidingWindowManager): + """KV cache manager for DeepSeek V4's sliding-window MLA cache. + + During decode, the live sliding window can move past the prompt boundary. 
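+    Without extra protection, the base sliding-window manager would release
+    those out-of-window prompt blocks as soon as the window slides past them.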
+ The blocks around the hybrid-aligned prompt boundary are still the suffix + needed for a future prefix-cache hit of the same prompt. + """ + + def cache_blocks(self, request: Request, num_tokens: int) -> None: + super().cache_blocks(request, num_tokens) + if not self.enable_caching or num_tokens < request.num_prompt_tokens: + return + if request.num_prompt_tokens <= 1: + return + + max_cache_hit_length = request.num_prompt_tokens - 1 + aligned_cache_hit_length = ( + max_cache_hit_length + // self.cache_alignment_tokens + * self.cache_alignment_tokens + ) + if aligned_cache_hit_length <= 0: + return + + aligned_num_hit_blocks = aligned_cache_hit_length // self.block_size + last_full_prompt_block = max_cache_hit_length // self.block_size + contiguous_blocks = cdiv(self.sliding_window - 1, self.block_size) + first_protected_block = max(0, aligned_num_hit_blocks - contiguous_blocks) + last_protected_block = max(aligned_num_hit_blocks, last_full_prompt_block) + blocks = self.req_to_blocks[request.request_id] + protected_blocks = blocks[ + first_protected_block : min(last_protected_block, len(blocks)) + ] + self._protect_prompt_blocks(protected_blocks) + + class ChunkedLocalAttentionManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) @@ -1124,6 +1286,7 @@ def __init__( kv_cache_group_id: int, dcp_world_size: int = 1, pcp_world_size: int = 1, + max_model_len: int | None = None, ): super().__init__( kv_cache_spec, @@ -1132,6 +1295,7 @@ def __init__( kv_cache_group_id, dcp_world_size, pcp_world_size, + max_model_len=max_model_len, ) sink_len = kv_cache_spec.sink_len assert sink_len is not None and sink_len > 0 and sink_len % self.block_size == 0 @@ -1142,9 +1306,9 @@ def __init__( spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, TQFullAttentionSpec: FullAttentionManager, - MLAAttentionSpec: FullAttentionManager, + MLAAttentionSpec: MLAAttentionManager, SlidingWindowSpec: SlidingWindowManager, - SlidingWindowMLASpec: SlidingWindowManager, + SlidingWindowMLASpec: SlidingWindowMLAManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, CrossAttentionSpec: CrossAttentionManager, @@ -1159,6 +1323,7 @@ def get_manager_for_kv_cache_spec( **kwargs, ) -> SingleTypeKVCacheManager: manager_class = spec_manager_map[type(kv_cache_spec)] + kwargs["max_model_len"] = max_model_len # SlidingWindow / ChunkedLocalAttention managers recycle blocks across # chunks; the runtime admission cap must match the recycling-aware bound # the startup pool sizer uses (single source of truth: the spec method).
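For completeness, the chunked running top-k merge that both `_fp8_mqa_logits_topk_torch` and `fp8_fp4_paged_mqa_topk_indices` rely on can be sanity-checked on CPU. The sketch below is illustrative only (made-up sizes, not part of the patch); it shows that merging per-chunk `torch.topk` results against the running winners reproduces a single top-k over the full logits row.

```python
# Standalone CPU sketch of the chunked running top-k merge used by the
# SM12x fallbacks; sizes here are made up for illustration.
import torch

torch.manual_seed(0)
num_rows, seq_len_kv, topk, chunk = 4, 1000, 16, 256
logits = torch.randn(num_rows, seq_len_kv)
logits[:, 700:] = float("-inf")  # emulate keys outside [ks, ke)

best_vals = torch.full((num_rows, topk), float("-inf"))
best_idx = torch.full((num_rows, topk), -1, dtype=torch.int32)

for start in range(0, seq_len_kv, chunk):
    end = min(start + chunk, seq_len_kv)
    k = min(topk, end - start)
    vals, idx = torch.topk(logits[:, start:end], k, dim=1)
    idx = idx.to(torch.int32) + start
    # Merge this chunk's winners into the running top-k.
    cand_vals = torch.cat([best_vals, vals], dim=1)
    cand_idx = torch.cat([best_idx, idx], dim=1)
    best_vals, sel = torch.topk(cand_vals, topk, dim=1)
    best_idx = torch.gather(cand_idx, 1, sel)

best_idx = best_idx.masked_fill(~torch.isfinite(best_vals), -1)

# The merged values match a single top-k over the full logits row.
ref_vals, _ = torch.topk(logits, topk, dim=1)
assert torch.equal(best_vals, ref_vals)
```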