diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py
index 0cea46d6284f..01f030836527 100644
--- a/tests/kernels/attention/test_deepgemm_attention.py
+++ b/tests/kernels/attention/test_deepgemm_attention.py
@@ -10,10 +10,12 @@
_ceil_to_ue8m0,
calc_diff,
fp8_fp4_mqa_logits,
- fp8_fp4_paged_mqa_logits,
get_num_sms,
get_paged_mqa_logits_metadata,
)
+from vllm.utils.deep_gemm import (
+ fp8_fp4_paged_mqa_logits as fp8_paged_mqa_logits,
+)
from vllm.utils.import_utils import has_deep_gemm
from vllm.utils.math_utils import cdiv
@@ -90,10 +92,64 @@ def _ref_fp8_mqa_logits(
return logits
+def _supports_deepgemm_optimized_mqa_logits() -> bool:
+ return current_platform.is_cuda() and (
+ current_platform.is_device_capability(90)
+ or current_platform.is_device_capability_family(100)
+ )
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
+@pytest.mark.skipif(
+ not current_platform.is_device_capability_family(120), reason="SM120 only"
+)
+def test_sm120_fp8_mqa_logits_torch_path():
+ torch.manual_seed(0)
+
+ seq_len, seq_len_kv, num_heads, head_dim = 9, 17, 32, 32
+ q = torch.randn(
+ seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16
+ )
+ kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16)
+ weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32)
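+    # Row i attends to KV positions in [i % 3, min(i + 4, seq_len_kv)).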
+ cu_seqlen_ks = (torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3)
+ cu_seqlen_ke = torch.minimum(
+ torch.arange(seq_len, device="cuda", dtype=torch.int32) + 4,
+ torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32),
+ )
+
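+    # Quantize KV per token: amax/448 scales (clamped away from zero) so the
+    # kernel sees FP8 values plus a float32 per-token scale.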
+ q_fp8 = q.to(torch.float8_e4m3fn)
+ kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4)
+ kv_scale = (kv_amax / 448.0).squeeze(1).contiguous()
+ kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn)
+
+ logits = fp8_fp4_mqa_logits(
+ (q_fp8, None),
+ (kv_fp8, kv_scale),
+ weights,
+ cu_seqlen_ks,
+ cu_seqlen_ke,
+ clean_logits=True,
+ )
+
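+    # Torch reference: dequantize KV, take ReLU'd per-head scores, apply the
+    # per-head weights, and mask keys outside each row's [ks, ke) window.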
+ kv_dequant = kv_fp8.float() * kv_scale[:, None]
+ score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant)
+ ref_logits = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0)
+ offsets = torch.arange(seq_len_kv, device="cuda")
+ valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & (
+ offsets[None, :] < cu_seqlen_ke[:, None]
+ )
+ ref_logits = ref_logits.masked_fill(~valid, float("-inf"))
+
+ assert torch.equal(torch.isneginf(logits), torch.isneginf(ref_logits))
+ finite = torch.isfinite(ref_logits)
+ assert (logits[finite] - ref_logits[finite]).abs().max() < 1e-4
+
+
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
- not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+ not _supports_deepgemm_optimized_mqa_logits(), reason="SM90 and SM100 only"
)
@pytest.mark.parametrize("clean_logits", [True, False])
def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
@@ -150,7 +206,7 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool):
assert diff < 1e-3, f"{diff=}"
-def _ref_fp8_fp4_paged_mqa_logits(
+def _ref_fp8_paged_mqa_logits(
q: torch.Tensor,
kv_cache: torch.Tensor,
weights: torch.Tensor,
@@ -203,12 +259,10 @@ def _ref_fp8_fp4_paged_mqa_logits(
@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only")
@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available")
@pytest.mark.skipif(
- not current_platform.has_device_capability(90), reason="SM90 and SM100 only"
+ not _supports_deepgemm_optimized_mqa_logits(), reason="SM90 and SM100 only"
)
-def test_deepgemm_fp8_fp4_paged_mqa_logits():
- # NOTE: clean_logits=True is incompatible with the 2D context_lens
- # required by csrc/apis/attention.hpp; only the False path is exercised.
- clean_logits = False
+@pytest.mark.parametrize("clean_logits", [True, False])
+def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool):
torch.manual_seed(0)
random.seed(0)
@@ -260,29 +314,24 @@ def test_deepgemm_fp8_fp4_paged_mqa_logits():
q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
- # deep_gemm paged MQA logits requires 2D context_lens of
- # shape (B, next_n) (csrc/apis/attention.hpp:332-335);
- # see indexer.py:607-608. For each batch/next_n token, the
- # effective context length is context_lens[b] - next_n + j + 1.
- next_n_arange = torch.arange(next_n, device="cuda", dtype=torch.int32)
- context_lens_2d = (
- context_lens.unsqueeze(-1) - next_n + 1 + next_n_arange
- ).contiguous()
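+    # The DeepGEMM paged MQA logits kernel expects context lengths of shape
+    # (B, next_n); replicate each request's length across its next_n slots.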
+ deepgemm_context_lens = (
+ context_lens[:, None].expand(-1, next_n).contiguous()
+ )
schedule_metadata = get_paged_mqa_logits_metadata(
- context_lens_2d, blocksize, get_num_sms()
+ deepgemm_context_lens, blocksize, get_num_sms()
)
- logits = fp8_fp4_paged_mqa_logits(
+ logits = fp8_paged_mqa_logits(
(q_fp8, None),
kv_cache_fp8,
weights,
- context_lens_2d,
+ deepgemm_context_lens,
block_tables,
schedule_metadata,
max_model_len,
clean_logits=clean_logits,
)
- ref_logits = _ref_fp8_fp4_paged_mqa_logits(
+ ref_logits = _ref_fp8_paged_mqa_logits(
q,
kv_cache,
weights,
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index ebc3256b548f..a5d2c33b47eb 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -29,10 +29,15 @@
)
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
+ FusedMoEConfig,
+ FusedMoEParallelConfig,
+ RoutingMethodType,
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
+ mxfp4_w4a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+ MarlinExperts,
batched_fused_marlin_moe,
fused_marlin_moe,
)
@@ -1007,6 +1012,61 @@ def test_fused_marlin_moe(
torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0)
+def test_marlin_experts_apply_forwards_gemm1_clamp_limit(monkeypatch):
+ captured: dict[str, float | None] = {}
+
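+    # Stub out the Marlin kernel: record the clamp_limit it receives and
+    # zero the output buffer so apply() can complete.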
+ def fake_fused_marlin_moe(**kwargs):
+ captured["clamp_limit"] = kwargs.get("clamp_limit")
+ kwargs["output"].zero_()
+
+ monkeypatch.setattr(
+ "vllm.model_executor.layers.fused_moe.fused_marlin_moe."
+ "fused_marlin_moe",
+ fake_fused_marlin_moe,
+ )
+
+ clamp_limit = 10.0
+ moe_config = FusedMoEConfig(
+ num_experts=2,
+ experts_per_token=1,
+ hidden_dim=16,
+ intermediate_size_per_partition=8,
+ num_local_experts=2,
+ num_logical_experts=2,
+ activation=MoEActivation.SILU,
+ device="cpu",
+ routing_method=RoutingMethodType.Default,
+ moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
+ in_dtype=torch.bfloat16,
+ )
+ quant_config = mxfp4_w4a16_moe_quant_config(
+ w1_scale=torch.empty(0),
+ w2_scale=torch.empty(0),
+ gemm1_clamp_limit=clamp_limit,
+ )
+ experts = MarlinExperts(moe_config=moe_config, quant_config=quant_config)
+
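+    # The tensor shapes below are placeholders; only clamp_limit forwarding
+    # into the fused kernel call is being checked.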
+ experts.apply(
+ output=torch.empty((1, 16), dtype=torch.bfloat16),
+ hidden_states=torch.empty((1, 16), dtype=torch.bfloat16),
+ w1=torch.empty((2, 1, 1), dtype=torch.int32),
+ w2=torch.empty((2, 1, 1), dtype=torch.int32),
+ topk_weights=torch.ones((1, 1), dtype=torch.float32),
+ topk_ids=torch.zeros((1, 1), dtype=torch.int32),
+ activation=MoEActivation.SILU,
+ global_num_experts=2,
+ expert_map=None,
+ a1q_scale=None,
+ a2_scale=None,
+ workspace13=torch.empty(0),
+ workspace2=torch.empty(0),
+ expert_tokens_meta=None,
+ apply_router_weight_on_input=False,
+ )
+
+ assert captured["clamp_limit"] == clamp_limit
+
+
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
@pytest.mark.parametrize("m", [1, 256])
diff --git a/tests/models/test_deepseek_v4_mega_moe.py b/tests/models/test_deepseek_v4_mega_moe.py
index 304f044868a3..7da906f58446 100644
--- a/tests/models/test_deepseek_v4_mega_moe.py
+++ b/tests/models/test_deepseek_v4_mega_moe.py
@@ -9,6 +9,7 @@
from vllm.model_executor.models.deepseek_v4 import (
DeepseekV4MegaMoEExperts,
_stage_deepseek_v4_mega_moe_inputs,
+ _use_deepseek_v4_mega_moe,
make_deepseek_v4_expert_params_mapping,
)
from vllm.platforms import current_platform
@@ -19,6 +20,52 @@
)
+def _make_mega_moe_config(
+ *,
+ enable_expert_parallel: bool = True,
+ moe_backend: str = "auto",
+):
+ return SimpleNamespace(
+ parallel_config=SimpleNamespace(
+ enable_expert_parallel=enable_expert_parallel
+ ),
+ kernel_config=SimpleNamespace(moe_backend=moe_backend),
+ )
+
+
+def test_deepseek_v4_mega_moe_selection_preserves_kernel_config(monkeypatch):
+ from vllm import envs
+
+ monkeypatch.delenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", raising=False)
+ envs.disable_envs_cache()
+
+ assert _use_deepseek_v4_mega_moe(
+ _make_mega_moe_config(moe_backend="deep_gemm_mega_moe")
+ )
+ assert not _use_deepseek_v4_mega_moe(_make_mega_moe_config())
+ with pytest.raises(NotImplementedError, match="requires expert parallel"):
+ _use_deepseek_v4_mega_moe(
+ _make_mega_moe_config(
+ enable_expert_parallel=False,
+ moe_backend="deep_gemm_mega_moe",
+ )
+ )
+
+
+def test_deepseek_v4_mega_moe_selection_env_override(monkeypatch):
+ from vllm import envs
+
+ monkeypatch.setenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "1")
+ envs.disable_envs_cache()
+ assert _use_deepseek_v4_mega_moe(_make_mega_moe_config())
+
+ monkeypatch.setenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "0")
+ envs.disable_envs_cache()
+ assert not _use_deepseek_v4_mega_moe(
+ _make_mega_moe_config(moe_backend="deep_gemm_mega_moe")
+ )
+
+
def test_deepseek_v4_mega_moe_expert_mapping():
mapping = make_deepseek_v4_expert_params_mapping(2)
@@ -46,7 +93,8 @@ def test_deepseek_v4_mega_moe_ue8m0_uint8_to_float():
def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership():
vllm_config = SimpleNamespace(
- scheduler_config=SimpleNamespace(max_num_batched_tokens=4)
+ scheduler_config=SimpleNamespace(max_num_batched_tokens=4),
+ compilation_config=SimpleNamespace(static_forward_context={}),
)
experts = DeepseekV4MegaMoEExperts(
vllm_config,
@@ -111,7 +159,10 @@ def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership():
reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.",
)
def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact():
- from vllm.third_party.deep_gemm.utils import per_token_cast_to_fp8
+ per_token_cast_to_fp8 = pytest.importorskip(
+ "deep_gemm.utils",
+ reason="DeepGEMM helper package is required for FP8 staging parity.",
+ ).per_token_cast_to_fp8
device = torch.device("cuda")
num_tokens = 7
diff --git a/tests/models/test_deepseek_v4_pp.py b/tests/models/test_deepseek_v4_pp.py
new file mode 100644
index 000000000000..7c0ae5dfd725
--- /dev/null
+++ b/tests/models/test_deepseek_v4_pp.py
@@ -0,0 +1,9 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm.model_executor.models.deepseek_v4 import DeepseekV4ForCausalLM
+from vllm.model_executor.models.interfaces import supports_pp
+
+
+def test_deepseek_v4_declares_pipeline_parallel_support():
+ assert supports_pp(DeepseekV4ForCausalLM)
diff --git a/tests/quantization/test_fp8_scale_parameter.py b/tests/quantization/test_fp8_scale_parameter.py
new file mode 100644
index 000000000000..c95cdbb98be2
--- /dev/null
+++ b/tests/quantization/test_fp8_scale_parameter.py
@@ -0,0 +1,33 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+import vllm.model_executor.parameter as parameter
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+ create_fp8_scale_parameter,
+)
+from vllm.model_executor.parameter import BlockQuantScaleParameter
+
+
+@pytest.mark.skipif(
+ not hasattr(torch, "float8_e8m0fnu"),
+ reason="torch does not expose float8_e8m0fnu",
+)
+def test_create_fp8_scale_parameter_initializes_e8m0(monkeypatch):
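+    # Pin the TP helpers so the parameter can be built without initializing a
+    # distributed group.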
+ monkeypatch.setattr(parameter, "get_tensor_model_parallel_rank", lambda: 0)
+ monkeypatch.setattr(parameter, "get_tensor_model_parallel_world_size", lambda: 1)
+
+ scale = create_fp8_scale_parameter(
+ BlockQuantScaleParameter,
+ output_partition_sizes=[128],
+ input_size_per_partition=128,
+ block_size=[128, 128],
+ weight_loader=None,
+ scale_dtype=torch.float8_e8m0fnu,
+ )
+
+ assert scale.dtype == torch.float8_e8m0fnu
+ raw_scale = scale.data.view(torch.uint8)
+ assert torch.equal(raw_scale, torch.zeros_like(raw_scale))
diff --git a/tests/quantization/test_mxfp4.py b/tests/quantization/test_mxfp4.py
new file mode 100644
index 000000000000..6e6792e60cae
--- /dev/null
+++ b/tests/quantization/test_mxfp4.py
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+def test_mxfp4_e8m0_scale_loading_preserves_raw_bytes():
+ from types import SimpleNamespace
+
+ import pytest
+ import torch
+
+ from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+
+ e8m0_dtype = getattr(torch, "float8_e8m0fnu", None)
+ if e8m0_dtype is None:
+ pytest.skip("torch does not expose float8_e8m0fnu")
+
+ layer = object.__new__(FusedMoE)
+ layer.moe_config = SimpleNamespace(is_act_and_mul=True)
+
+ expert_data = torch.zeros((4, 2), dtype=torch.uint8)
+ loaded_scale = torch.tensor(
+ [[0.0078125, 0.015625], [0.5, 1.0]],
+ dtype=e8m0_dtype,
+ )
+
+ layer._load_w13(
+ expert_data=expert_data,
+ shard_dim=0,
+ shard_id="w1",
+ loaded_weight=loaded_scale,
+ tp_rank=0,
+ )
+
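+    # The w1 shard should land in the first rows of the fused w13 buffer with
+    # its raw e8m0 bytes untouched.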
+ torch.testing.assert_close(
+ expert_data[:2],
+ loaded_scale.view(torch.uint8),
+ rtol=0,
+ atol=0,
+ )
diff --git a/tests/reasoning/test_deepseekv3_reasoning_parser.py b/tests/reasoning/test_deepseekv3_reasoning_parser.py
index f5b37194f927..a013cf1a8775 100644
--- a/tests/reasoning/test_deepseekv3_reasoning_parser.py
+++ b/tests/reasoning/test_deepseekv3_reasoning_parser.py
@@ -4,16 +4,35 @@
import pytest
from transformers import AutoTokenizer
+from vllm.config.reasoning import ReasoningConfig
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning import ReasoningParserManager
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
-from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser
+from vllm.reasoning.deepseek_v3_reasoning_parser import (
+ DeepSeekV3ReasoningParser,
+ DeepSeekV3ReasoningWithThinkingParser,
+)
from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-V3.1"
+class FakeReasoningTokenizer:
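+    """Stub tokenizer mapping the reasoning start/end markers to fixed ids."""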
+
+ def get_vocab(self) -> dict[str, int]:
+ return {"": 100, "": 101}
+
+ def encode(
+ self,
+ text: str,
+ add_special_tokens: bool = False,
+ **kwargs,
+ ) -> list[int]:
+ assert add_special_tokens is False
+ return [self.get_vocab()[text]]
+
+
@pytest.fixture(scope="module")
def tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
@@ -37,7 +56,22 @@ def test_parser_selection(tokenizer, thinking, expected_parser_type):
def test_deepseek_v4_reasoning_parser_alias():
parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4")
- assert parser_cls is DeepSeekV3ReasoningParser
+ assert parser_cls is DeepSeekV3ReasoningWithThinkingParser
+
+
+def test_deepseek_v4_auto_reasoning_config_initializes_budget_tokens(monkeypatch):
+ monkeypatch.setattr(
+ "vllm.config.reasoning.cached_tokenizer_from_config",
+ lambda model_config: FakeReasoningTokenizer(),
+ )
+
+ config = ReasoningConfig(reasoning_parser="deepseek_v4")
+
+ config.initialize_token_ids(model_config=None)
+
+ assert config.enabled is True
+ assert config.reasoning_start_token_ids == [100]
+ assert config.reasoning_end_token_ids == [101]
def test_identity_reasoning_parser_basic(tokenizer):
diff --git a/tests/tokenizers_/test_deepseek_v4.py b/tests/tokenizers_/test_deepseek_v4.py
index 358732eabf40..0a3253957add 100644
--- a/tests/tokenizers_/test_deepseek_v4.py
+++ b/tests/tokenizers_/test_deepseek_v4.py
@@ -8,8 +8,14 @@
import pytest
from vllm.entrypoints.chat_utils import parse_chat_messages
+from vllm.entrypoints.openai.chat_completion.protocol import (
+ ChatCompletionRequest,
+ ChatMessage,
+)
+from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.renderers.registry import RENDERER_REGISTRY
from vllm.tokenizers.deepseek_v4 import get_deepseek_v4_tokenizer
+from vllm.tokenizers.deepseek_v4_encoding import encode_arguments_to_dsml
from vllm.tokenizers.registry import TokenizerRegistry
FIXTURES_DIR = Path(__file__).parent / "fixtures" / "deepseek_v4"
@@ -96,6 +102,130 @@ def test_deepseek_v4_enables_thinking_with_compatible_kwargs(kwargs):
assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>")
+def test_deepseek_v4_honors_official_thinking_request_field():
+ request = ChatCompletionRequest.model_validate(
+ {
+ "model": "deepseek-ai/DeepSeek-V4-Flash",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "thinking": {"type": "enabled"},
+ }
+ )
+ chat_kwargs = request.apply_chat_template_kwargs(
+ request.build_chat_params(None, "auto").chat_template_kwargs
+ )
+
+ prompt = _tokenizer().apply_chat_template(
+ request.messages,
+ tokenize=False,
+ **chat_kwargs,
+ )
+
+ assert chat_kwargs["thinking"] is True
+ assert chat_kwargs["enable_thinking"] is True
+ assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>")
+
+
+def test_deepseek_v4_defaults_to_official_thinking_for_openai_request():
+ request = ChatCompletionRequest.model_validate(
+ {
+ "model": "deepseek-ai/DeepSeek-V4-Flash",
+ "messages": [{"role": "user", "content": "Hello"}],
+ }
+ )
+ chat_kwargs = request.apply_chat_template_kwargs(
+ request.build_chat_params(None, "auto").chat_template_kwargs
+ )
+
+ assert chat_kwargs["thinking"] is True
+ assert chat_kwargs["enable_thinking"] is True
+
+
+def test_deepseek_v4_preserves_official_reasoning_content_alias():
+ messages = [
+ {"role": "user", "content": "Q1"},
+ {"role": "assistant", "reasoning_content": "because", "content": "A1"},
+ {"role": "user", "content": "Q2"},
+ ]
+
+ conversation, _, _ = parse_chat_messages(
+ messages,
+ _model_config(),
+ content_format="string",
+ )
+
+ assert conversation[1]["reasoning"] == "because"
+ assert conversation[1]["reasoning_content"] == "because"
+
+
+def test_deepseek_v4_response_messages_expose_reasoning_content_alias():
+ message = ChatMessage(role="assistant", reasoning="because", content="answer")
+ delta = DeltaMessage(reasoning="because")
+
+ assert message.reasoning_content == "because"
+ assert delta.reasoning_content == "because"
+ assert (
+ ChatMessage(
+ role="assistant",
+ reasoning_content="because",
+ content="answer",
+ ).reasoning
+ == "because"
+ )
+
+
+def test_deepseek_v4_preserves_official_prefix_assistant_message():
+ messages = [
+ {"role": "user", "content": "Please write quick sort code"},
+ {"role": "assistant", "content": "```python\n", "prefix": True},
+ ]
+
+ conversation, _, _ = parse_chat_messages(
+ messages,
+ _model_config(),
+ content_format="string",
+ )
+ prompt = _tokenizer().apply_chat_template(
+ conversation=conversation,
+ messages=messages,
+ tokenize=False,
+ )
+
+ assert conversation[1]["prefix"] is True
+ assert conversation[1]["wo_eos"] is True
+ assert prompt.endswith("<|Assistant|>```python\n")
+ assert not prompt.endswith("<|end▁of▁sentence|>")
+
+
+def test_deepseek_v4_thinking_ignores_sampling_controls():
+ request = ChatCompletionRequest.model_validate(
+ {
+ "model": "deepseek-ai/DeepSeek-V4-Flash",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "thinking": {"type": "enabled"},
+ "temperature": 0.2,
+ "top_p": 0.3,
+ "top_k": 4,
+ "presence_penalty": 1.5,
+ "frequency_penalty": 1.25,
+ }
+ )
+ chat_kwargs = request.apply_chat_template_kwargs(
+ request.build_chat_params(None, "auto").chat_template_kwargs
+ )
+
+ sampling_params = request.to_sampling_params(
+ 16,
+ {},
+ chat_template_kwargs=chat_kwargs,
+ )
+
+ assert sampling_params.temperature == 1.0
+ assert sampling_params.top_p == 1.0
+ assert sampling_params.top_k == 0
+ assert sampling_params.presence_penalty == 0.0
+ assert sampling_params.frequency_penalty == 0.0
+
+
def test_deepseek_v4_uses_v4_tool_prompt_from_request_tools():
tools = [
{
@@ -183,6 +313,66 @@ def test_deepseek_v4_renders_parsed_history_tool_arguments():
assert 'parameter name="arguments"' not in prompt
+@pytest.mark.parametrize(
+ ("tool_call", "expected_parameter"),
+ [
+ ({"name": "refresh", "arguments": None}, None),
+ ({"name": "refresh"}, None),
+ ({"name": "refresh", "arguments": ""}, None),
+ (
+ {"name": "refresh", "arguments": '{"target": "cache"}'},
+ '<|DSML|parameter name="target" string="true">cache',
+ ),
+ (
+ {"name": "refresh", "arguments": {"target": "cache"}},
+ '<|DSML|parameter name="target" string="true">cache',
+ ),
+ ],
+)
+def test_deepseek_v4_encodes_empty_history_tool_arguments(
+ tool_call, expected_parameter
+):
+ prompt = encode_arguments_to_dsml(tool_call)
+
+ if expected_parameter is None:
+ assert prompt == ""
+ else:
+ assert expected_parameter in prompt
+
+
+def test_deepseek_v4_renders_openai_history_tool_call_with_null_arguments():
+ messages = [
+ {"role": "user", "content": "Refresh state"},
+ {
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "id": "call_1",
+ "type": "function",
+ "function": {
+ "name": "refresh",
+ "arguments": None,
+ },
+ }
+ ],
+ },
+ ]
+ conversation, _, _ = parse_chat_messages(
+ messages,
+ _model_config(),
+ content_format="string",
+ )
+
+ prompt = _tokenizer().apply_chat_template(
+ conversation=conversation,
+ messages=messages,
+ tokenize=False,
+ )
+
+ assert '<|DSML|invoke name="refresh">' in prompt
+ assert "<|DSML|parameter" not in prompt
+
+
@pytest.mark.parametrize("reasoning_effort", ["minimal", "low", "medium", "high"])
def test_deepseek_v4_accepts_openai_reasoning_effort_values(reasoning_effort):
prompt = _tokenizer().apply_chat_template(
@@ -288,3 +478,136 @@ def test_deepseek_v4_matches_reference_golden_fixtures(case_id, kwargs):
expected = (FIXTURES_DIR / f"test_output_{case_id}.txt").read_text()
assert prompt == expected
+
+
+@pytest.mark.parametrize(
+ "model",
+ [
+ "deepseek-ai/DeepSeek-V4-Flash",
+ "deepseek-ai/DeepSeek-V4-Pro",
+ ],
+)
+def test_deepseek_v4_official_api_defaults_to_thinking_for_v4_family(model):
+ request = ChatCompletionRequest.model_validate(
+ {
+ "model": model,
+ "messages": [{"role": "user", "content": "Hello"}],
+ }
+ )
+ chat_kwargs = request.apply_chat_template_kwargs(
+ request.build_chat_params(None, "auto").chat_template_kwargs
+ )
+
+ assert chat_kwargs["thinking"] is True
+ assert chat_kwargs["enable_thinking"] is True
+
+
+def test_deepseek_v4_official_api_uses_model_config_for_family_detection():
+ request = ChatCompletionRequest.model_validate(
+ {
+ "model": "local-ds4-alias",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "temperature": 0.2,
+ }
+ )
+ model_config = SimpleNamespace(
+ hf_config=SimpleNamespace(model_type="deepseek_v4", architectures=[]),
+ )
+ chat_kwargs = request.apply_chat_template_kwargs(
+ request.build_chat_params(None, "auto").chat_template_kwargs,
+ model_config=model_config,
+ )
+
+ sampling_params = request.to_sampling_params(
+ 16,
+ {},
+ chat_template_kwargs=chat_kwargs,
+ model_config=model_config,
+ )
+
+ assert chat_kwargs["thinking"] is True
+ assert chat_kwargs["enable_thinking"] is True
+ assert sampling_params.temperature == 1.0
+
+
+def test_deepseek_v4_official_api_sampling_override_can_be_disabled():
+ request = ChatCompletionRequest.model_validate(
+ {
+ "model": "deepseek-ai/DeepSeek-V4-Flash",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "thinking": {"type": "enabled"},
+ "deepseek_v4_sampling_override": False,
+ "temperature": 0.2,
+ "top_p": 0.3,
+ "top_k": 4,
+ "min_p": 0.05,
+ "presence_penalty": 1.5,
+ "frequency_penalty": 1.25,
+ }
+ )
+ chat_kwargs = request.apply_chat_template_kwargs(
+ request.build_chat_params(None, "auto").chat_template_kwargs
+ )
+
+ sampling_params = request.to_sampling_params(
+ 16,
+ {},
+ chat_template_kwargs=chat_kwargs,
+ )
+
+ assert sampling_params.temperature == 0.2
+ assert sampling_params.top_p == 0.3
+ assert sampling_params.top_k == 4
+ assert sampling_params.min_p == 0.05
+ assert sampling_params.presence_penalty == 1.5
+ assert sampling_params.frequency_penalty == 1.25
+
+
+def test_deepseek_v4_official_api_sampling_override_is_v4_only():
+ request = ChatCompletionRequest.model_validate(
+ {
+ "model": "deepseek-ai/DeepSeek-R1",
+ "messages": [{"role": "user", "content": "Hello"}],
+ "thinking": {"type": "enabled"},
+ "temperature": 0.2,
+ "top_p": 0.3,
+ "top_k": 4,
+ "min_p": 0.05,
+ "presence_penalty": 1.5,
+ "frequency_penalty": 1.25,
+ }
+ )
+ chat_kwargs = request.apply_chat_template_kwargs(
+ request.build_chat_params(None, "auto").chat_template_kwargs
+ )
+
+ sampling_params = request.to_sampling_params(
+ 16,
+ {},
+ chat_template_kwargs=chat_kwargs,
+ )
+
+ assert "thinking" not in chat_kwargs
+ assert "enable_thinking" not in chat_kwargs
+ assert sampling_params.temperature == 0.2
+ assert sampling_params.top_p == 0.3
+ assert sampling_params.top_k == 4
+ assert sampling_params.min_p == 0.05
+ assert sampling_params.presence_penalty == 1.5
+ assert sampling_params.frequency_penalty == 1.25
diff --git a/tests/tools/test_compare_vllm_http_logprobs_oracle.py b/tests/tools/test_compare_vllm_http_logprobs_oracle.py
new file mode 100644
index 000000000000..729c76659d53
--- /dev/null
+++ b/tests/tools/test_compare_vllm_http_logprobs_oracle.py
@@ -0,0 +1,115 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import importlib.util
+from pathlib import Path
+
+SCRIPT_PATH = (
+ Path(__file__).parents[2] / "tools" / "compare_vllm_http_logprobs_oracle.py"
+)
+spec = importlib.util.spec_from_file_location(
+ "compare_vllm_http_logprobs_oracle", SCRIPT_PATH
+)
+assert spec is not None
+oracle_compare = importlib.util.module_from_spec(spec)
+assert spec.loader is not None
+spec.loader.exec_module(oracle_compare)
+
+
+def _response(tokens, top_logprobs):
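+    # Minimal completions-style payload carrying the logprob fields the
+    # comparator script reads.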
+ token_ids = [
+ int(token.split(":", 1)[1])
+ for token in tokens
+ if isinstance(token, str) and token.startswith("token_id:")
+ ]
+ return {
+ "choices": [
+ {
+ "text": "",
+ "logprobs": {
+ "tokens": tokens,
+ "token_logprobs": [-0.1 * (i + 1) for i in range(len(tokens))],
+ "top_logprobs": top_logprobs,
+ },
+ "token_ids": token_ids,
+ "prompt_token_ids": [1, 2, 3],
+ }
+ ],
+ "usage": {"prompt_tokens": 3, "completion_tokens": len(tokens)},
+ }
+
+
+def test_compare_response_accepts_identical_top_logprobs():
+ top_logprobs = [
+ {"token_id:10": -0.1, "token_id:20": -1.0},
+ {"token_id:11": -0.2, "token_id:21": -1.2},
+ ]
+ report = oracle_compare.compare_response(
+ "case0",
+ _response(["token_id:10", "token_id:11"], top_logprobs),
+ _response(["token_id:10", "token_id:11"], top_logprobs),
+ top_n=2,
+ )
+
+ assert report["tokens_match"] is True
+ assert report["first_token_mismatch"] is None
+ assert report["top1_matches"] == 2
+ assert report["topk_overlap_mean"] == 1.0
+ assert report["max_common_logprob_abs_error"] == 0.0
+
+
+def test_compare_response_reports_first_generated_token_divergence():
+ oracle = _response(
+ ["token_id:10", "token_id:11"],
+ [
+ {"token_id:10": -0.1, "token_id:20": -1.0},
+ {"token_id:11": -0.2, "token_id:21": -1.2},
+ ],
+ )
+ actual = _response(
+ ["token_id:10", "token_id:99"],
+ [
+ {"token_id:10": -0.1, "token_id:20": -1.0},
+ {"token_id:99": -0.2, "token_id:21": -1.2},
+ ],
+ )
+
+ report = oracle_compare.compare_response("case0", oracle, actual, top_n=2)
+
+ assert report["tokens_match"] is False
+ assert report["first_token_mismatch"] == {
+ "step": 1,
+ "oracle": "token_id:11",
+ "actual": "token_id:99",
+ }
+ assert report["matching_prefix_tokens"] == 1
+ assert report["top1_matches"] == 1
+
+
+def test_compare_response_can_decode_oracle_token_id_keys():
+ normalizer = oracle_compare.TokenNormalizer(
+ lambda token_id: {10: '","', 11: "title", 20: " What"}[token_id]
+ )
+ oracle = _response(
+ ["token_id:10", "token_id:11"],
+ [
+ {"token_id:10": -0.1, "token_id:20": -1.0},
+ {"token_id:11": -0.2, "token_id:20": -1.2},
+ ],
+ )
+ actual = _response(
+ ['","', "title"],
+ [
+ {'","': -0.11, " What": -1.1},
+ {"title": -0.19, " What": -1.3},
+ ],
+ )
+
+ report = oracle_compare.compare_response(
+ "case0", oracle, actual, top_n=2, normalizer=normalizer
+ )
+
+ assert report["tokens_match"] is True
+ assert report["top1_matches"] == 2
+ assert report["topk_overlap_mean"] == 1.0
+ assert report["max_common_logprob_abs_error"] == 0.10000000000000009
diff --git a/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py
new file mode 100644
index 000000000000..dbdc6f4e689c
--- /dev/null
+++ b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py
@@ -0,0 +1,3162 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Correctness tests for the DeepSeek V4 Triton sparse MLA path and reference oracle."""
+
+from types import SimpleNamespace
+
+import pytest
+import torch
+
+from vllm.config.compilation import (
+ CompilationConfig,
+ CompilationMode,
+ CUDAGraphMode,
+)
+from vllm.model_executor.layers import (
+ deepseek_v4_attention as deepseek_v4_attention_module,
+)
+from vllm.model_executor.layers.deepseek_v4_attention import (
+ _allocate_deepseek_v4_wo_a_output,
+ _deepseek_v4_fp8_einsum_config,
+ _sparse_mla_prefill_workspace_bounds,
+ deepseek_v4_fp8_einsum,
+)
+from vllm.utils.deep_gemm import fp8_einsum
+from vllm.v1.attention.backend import AttentionCGSupport
+from vllm.v1.attention.backends.mla.flashmla_sparse import (
+ FlashMLASparseMetadataBuilder,
+)
+from vllm.v1.attention.backends.mla.sparse_mla_env import (
+ disable_triton_sparse_mla_cudagraphs_if_enabled,
+ triton_sparse_mla_topk_chunk_size,
+)
+from vllm.v1.attention.backends.mla.sparse_mla_kernels import (
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk,
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead,
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk,
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead,
+ accumulate_gathered_sparse_mla_attention_chunk,
+ accumulate_indexed_sparse_mla_attention_chunk,
+ build_combined_sparse_mla_decode_valid_mask,
+ finish_gathered_sparse_mla_attention,
+ finish_materialized_sparse_mla_scores_with_sink,
+ finish_sparse_mla_attention_with_sink,
+ finish_two_sparse_mla_attention_states_with_sink,
+ fp8ds_global_paged_sparse_mla_attention_with_sink_multihead,
+ fp8ds_paged_sparse_mla_attention_with_sink_multihead,
+ matmul_sparse_mla_attention_with_sink,
+ merge_sparse_mla_subset_with_sink,
+ merge_two_sparse_mla_subsets_with_sink,
+ sparse_mla_decode_head_block_size,
+)
+from vllm.v1.attention.backends.mla.sparse_mla_reference import (
+ accumulate_reference_attention_chunk,
+ finish_reference_attention_no_sink,
+ merge_reference_attention_with_sink,
+ new_reference_attention_state,
+ reference_attention_no_sink,
+ reference_sparse_mla_prefill,
+ sink_aware_reference_attention,
+)
+from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadataBuilder
+from vllm.v1.attention.ops.deepseek_v4_ops import (
+ dequantize_and_gather_k_cache,
+ dequantize_combined_sparse_mla_decode_kv,
+ dequantize_global_slots_k_cache,
+)
+from vllm.v1.attention.ops.deepseek_v4_ops import fp8_einsum as fp8_einsum_module
+from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import (
+ deepseek_v4_sm12x_fp8_einsum,
+)
+from vllm.v1.kv_cache_interface import MLAAttentionSpec, SlidingWindowMLASpec
+
+_FP8_DIM = 448
+_ROPE_DIM = 64
+_SCALE_DIM = 8
+_TOKEN_DATA_SIZE = _FP8_DIM + _ROPE_DIM * 2
+
+
+class _FakeWorkspaceManager:
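+    """Workspace stub that returns freshly-allocated tensors for each request."""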
+ def get_simultaneous(self, *specs):
+ return tuple(torch.empty(shape, dtype=dtype) for shape, dtype in specs)
+
+
+def _assert_fp8_einsum_close(actual: torch.Tensor, expected: torch.Tensor) -> None:
+    # The Triton path and the DeepGEMM reference both accumulate in FP32, but
+    # their reduction orders differ, so the results are not bit-identical
+    # before the final BF16 store.
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=5e-2, atol=3e-4)
+
+
+def test_deepseek_v4_fp8_einsum_is_piecewise_split_op() -> None:
+ assert "vllm::deepseek_v4_fp8_einsum" in CompilationConfig._attention_ops
+
+
+def test_wo_a_output_allocation_uses_empty_during_compile(monkeypatch) -> None:
+ class FailingWorkspaceManager:
+ def get_simultaneous(self, *args, **kwargs):
+ raise AssertionError("compiled allocation must not grow workspace")
+
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "current_workspace_manager",
+ lambda: FailingWorkspaceManager(),
+ )
+ monkeypatch.setattr(torch.compiler, "is_compiling", lambda: True)
+
+ output = _allocate_deepseek_v4_wo_a_output(
+ 2,
+ 3,
+ 5,
+ torch.bfloat16,
+ torch.device("cpu"),
+ )
+
+ assert output.shape == (2, 3, 5)
+ assert output.dtype == torch.bfloat16
+
+
+def test_wo_a_output_allocation_uses_workspace_outside_compile(monkeypatch) -> None:
+ captured = {}
+
+ class FakeWorkspaceManager:
+ def get_simultaneous(self, *shapes_and_dtypes):
+ captured["request"] = shapes_and_dtypes
+ return [
+ torch.empty(shape, dtype=dtype) for shape, dtype in shapes_and_dtypes
+ ]
+
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "current_workspace_manager",
+ lambda: FakeWorkspaceManager(),
+ )
+ monkeypatch.setattr(torch.compiler, "is_compiling", lambda: False)
+
+ output = _allocate_deepseek_v4_wo_a_output(
+ 2,
+ 3,
+ 5,
+ torch.bfloat16,
+ torch.device("cpu"),
+ )
+
+ assert captured["request"] == (((2, 3, 5), torch.bfloat16),)
+ assert output.shape == (2, 3, 5)
+ assert output.dtype == torch.bfloat16
+
+
+def test_dummy_attention_impl_reserves_prefill_workspace(monkeypatch) -> None:
+ class FakeMLAAttn:
+ def __init__(self) -> None:
+ self.reserved = False
+
+ def _reserve_prefill_workspace(self) -> None:
+ self.reserved = True
+
+ def __call__(self, *args, **kwargs) -> None:
+ raise AssertionError("dummy run must not execute real attention")
+
+ mla_attn = FakeMLAAttn()
+ layer = object.__new__(
+ deepseek_v4_attention_module.DeepseekV4MultiHeadLatentAttentionWrapper
+ )
+ layer.q_lora_rank = 2
+ layer.head_dim = 4
+ layer.n_local_heads = 2
+ layer.padded_heads = 64
+ layer.indexer = None
+ layer.compressor = None
+ layer.wq_b = lambda qr: torch.ones(qr.shape[0], 8)
+ layer.q_norm = SimpleNamespace(weight=SimpleNamespace(data=torch.empty(0)))
+ layer.kv_norm = SimpleNamespace(weight=SimpleNamespace(data=torch.empty(0)))
+ layer.eps = 1e-6
+ layer.mla_attn = mla_attn
+ layer.attn_gemm_parallel_execute = lambda hidden_states: (
+ torch.zeros(hidden_states.shape[0], 6),
+ None,
+ None,
+ None,
+ )
+ layer._fused_qnorm_rope_kv_insert = lambda *args, **kwargs: None
+
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "get_forward_context",
+ lambda: SimpleNamespace(attn_metadata=None),
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "fused_q_kv_rmsnorm",
+ lambda qr, kv, *args, **kwargs: (qr, kv),
+ )
+
+ out = torch.ones((3, 64, 4))
+ layer.attention_impl(
+ hidden_states=torch.zeros((3, 6)),
+ positions=torch.arange(3),
+ out=out,
+ )
+
+ assert mla_attn.reserved is True
+ assert torch.count_nonzero(out) == 0
+
+
+def test_prefill_workspace_reservation_specs_match_forward_prefill_bounds(
+ monkeypatch,
+) -> None:
+ attn = SimpleNamespace(
+ max_model_len=16_384,
+ max_num_batched_tokens=8192,
+ compress_ratio=4,
+ window_size=128,
+ head_dim=512,
+ num_heads=64,
+ topk_indices_buffer=torch.empty((8192, 2048), dtype=torch.int32),
+ indexer=None,
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "is_triton_sparse_mla_enabled_for_platform",
+ lambda: True,
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "triton_sparse_mla_query_chunk_size",
+ lambda: 256,
+ )
+
+ attention_cls = deepseek_v4_attention_module.DeepseekV4MLAAttention
+ specs = attention_cls._prefill_workspace_reservation_specs(attn)
+
+ assert specs == (
+ ((4, 12_415, 512), torch.bfloat16),
+ ((8192, 2176), torch.int32),
+ ((8192,), torch.int32),
+ ((256, 64), torch.float32),
+ ((256, 64), torch.float32),
+ ((256, 64, 512), torch.float32),
+ )
+
+
+def test_triton_sparse_mla_default_topk_chunk_size(monkeypatch) -> None:
+ monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False)
+
+ assert triton_sparse_mla_topk_chunk_size() == 512
+
+
+def test_sparse_mla_prefill_workspace_bounds_use_active_prefill_lengths() -> None:
+ seq_lens_cpu = torch.tensor([15_000, 2_048], dtype=torch.int32)
+ gather_lens_cpu = torch.tensor([15_000, 2_048], dtype=torch.int32)
+
+ compressed_region_size, row_stride = _sparse_mla_prefill_workspace_bounds(
+ seq_lens_cpu=seq_lens_cpu,
+ gather_lens_cpu=gather_lens_cpu,
+ compress_ratio=4,
+ swa_only=False,
+ )
+
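+    # 15_000 active tokens with compress_ratio=4: 15_000 / 4 = 3_750 compressed
+    # slots, and 15_000 + 3_750 = 18_750 for the per-row stride.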
+ assert compressed_region_size == 3_750
+ assert row_stride == 18_750
+
+
+def test_sparse_mla_prefill_workspace_bounds_for_swa_only() -> None:
+ seq_lens_cpu = torch.tensor([15_000], dtype=torch.int32)
+ gather_lens_cpu = torch.tensor([15_000], dtype=torch.int32)
+
+ compressed_region_size, row_stride = _sparse_mla_prefill_workspace_bounds(
+ seq_lens_cpu=seq_lens_cpu,
+ gather_lens_cpu=gather_lens_cpu,
+ compress_ratio=1,
+ swa_only=True,
+ )
+
+ assert compressed_region_size == 0
+ assert row_stride == 15_000
+
+
+@pytest.mark.parametrize(
+ ("num_decode_tokens", "expected_head_block_size"),
+ [
+ (0, 1),
+ (1, 1),
+ (4, 1),
+ (5, 2),
+ (8, 2),
+ (15, 2),
+ (16, 4),
+ (32, 4),
+ ],
+)
+def test_triton_sparse_mla_decode_head_block_size(
+ num_decode_tokens: int,
+ expected_head_block_size: int,
+ monkeypatch,
+) -> None:
+ monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", raising=False)
+
+ assert (
+ sparse_mla_decode_head_block_size(num_decode_tokens) == expected_head_block_size
+ )
+
+
+@pytest.mark.parametrize("configured_head_block_size", ["1", "2", "4"])
+def test_triton_sparse_mla_decode_head_block_size_env_override(
+ configured_head_block_size: str,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv(
+ "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE",
+ configured_head_block_size,
+ )
+
+ assert sparse_mla_decode_head_block_size(1) == int(configured_head_block_size)
+ assert sparse_mla_decode_head_block_size(32) == int(configured_head_block_size)
+
+
+@pytest.mark.parametrize("configured_head_block_size", ["0", "3", "invalid"])
+def test_triton_sparse_mla_decode_head_block_size_ignores_invalid_env_override(
+ configured_head_block_size: str,
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv(
+ "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE",
+ configured_head_block_size,
+ )
+
+ assert sparse_mla_decode_head_block_size(8) == 2
+
+
+def test_swa_mtp_decode_triton_uses_global_swa_slots(monkeypatch) -> None:
+ captured: dict[str, torch.Tensor] = {}
+
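+    # The decode must go through the global-slot accumulate path with the
+    # precomputed SWA indices; the paged multihead kernel must not be hit.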
+ def fail_paged_attention_with_sink_multihead(**kwargs) -> None:
+ raise AssertionError("MTP SWA decode must use explicit SWA indices")
+
+ def fake_accumulate_global_slots(**kwargs) -> None:
+ captured["slot_ids"] = kwargs["slot_ids"]
+ captured["lens"] = kwargs["lens"]
+
+ def fake_finish_with_sink(*args, **kwargs) -> None:
+ kwargs["output"].zero_()
+
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "current_workspace_manager",
+ lambda: _FakeWorkspaceManager(),
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "fp8ds_paged_sparse_mla_attention_with_sink_multihead",
+ fail_paged_attention_with_sink_multihead,
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead",
+ fake_accumulate_global_slots,
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "finish_sparse_mla_attention_with_sink",
+ fake_finish_with_sink,
+ )
+
+ attention = SimpleNamespace(
+ num_heads=2,
+ scale=0.1,
+ attn_sink=torch.zeros(2, dtype=torch.float32),
+ )
+ swa_indices = torch.arange(48, dtype=torch.int32).reshape(6, 1, 8)
+ swa_lens = torch.tensor([2, 3, 4, 2, 3, 4], dtype=torch.int32)
+ metadata = SimpleNamespace(
+ num_decodes=2,
+ num_decode_tokens=6,
+ decode_swa_lens=swa_lens,
+ decode_swa_indices=swa_indices,
+ seq_lens=torch.tensor([11, 22], dtype=torch.int32),
+ block_table=torch.empty((2, 4), dtype=torch.int32),
+ block_size=256,
+ token_to_req_indices=torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32),
+ )
+
+ deepseek_v4_attention_module.DeepseekV4MLAAttention._forward_sparse_mla_swa_decode_triton(
+ attention,
+ q=torch.empty((6, 1, 2, 512), dtype=torch.bfloat16),
+ swa_k_cache=torch.empty((1, 256, 584), dtype=torch.uint8),
+ swa_metadata=metadata,
+ output=torch.empty((6, 2, 512), dtype=torch.bfloat16),
+ )
+
+ torch.testing.assert_close(captured["slot_ids"], swa_indices)
+ torch.testing.assert_close(captured["lens"], swa_lens)
+
+
+def test_compressed_mtp_decode_triton_uses_global_swa_slots(monkeypatch) -> None:
+ captured: list[torch.Tensor] = []
+
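+    # Both the top-k compressed slots and the SWA slots must be fed to the
+    # global-slot accumulate path; neither paged fallback may run.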
+ def fail_matmul_decode(**kwargs) -> None:
+ raise AssertionError("MTP compressed decode must not stage paged SWA")
+
+ def fail_direct_global_paged(**kwargs) -> None:
+ raise AssertionError("MTP compressed decode must not use paged SWA window")
+
+ def fake_accumulate_global_slots(**kwargs) -> None:
+ captured.append(kwargs["slot_ids"])
+
+ def fake_finish_two_states(*args, **kwargs) -> None:
+ kwargs["output"].zero_()
+
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "current_workspace_manager",
+ lambda: _FakeWorkspaceManager(),
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "dequantize_combined_sparse_mla_decode_kv",
+ fail_matmul_decode,
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "fp8ds_global_paged_sparse_mla_attention_with_sink_multihead",
+ fail_direct_global_paged,
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead",
+ fake_accumulate_global_slots,
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "finish_two_sparse_mla_attention_states_with_sink",
+ fake_finish_two_states,
+ )
+
+ attention = SimpleNamespace(
+ num_heads=2,
+ scale=0.1,
+ attn_sink=torch.zeros(2, dtype=torch.float32),
+ compress_ratio=4,
+ )
+ swa_indices = torch.arange(48, dtype=torch.int32).reshape(6, 1, 8)
+ topk_slot_ids = torch.arange(24, dtype=torch.int32).reshape(6, 1, 4)
+ swa_metadata = SimpleNamespace(
+ num_decodes=2,
+ num_decode_tokens=6,
+ decode_swa_lens=torch.full((6,), 3, dtype=torch.int32),
+ decode_swa_indices=swa_indices,
+ seq_lens=torch.tensor([11, 22], dtype=torch.int32),
+ block_table=torch.empty((2, 4), dtype=torch.int32),
+ block_size=256,
+ token_to_req_indices=torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32),
+ )
+
+ deepseek_v4_attention_module.DeepseekV4MLAAttention._forward_sparse_mla_compressed_decode_triton(
+ attention,
+ q=torch.empty((6, 1, 2, 512), dtype=torch.bfloat16),
+ compressed_k_cache=torch.empty((1, 64, 584), dtype=torch.uint8),
+ swa_k_cache=torch.empty((1, 256, 584), dtype=torch.uint8),
+ topk_indices=topk_slot_ids,
+ topk_lens=torch.full((6,), 4, dtype=torch.int32),
+ swa_metadata=swa_metadata,
+ attn_metadata=SimpleNamespace(block_size=256),
+ output=torch.empty((6, 2, 512), dtype=torch.bfloat16),
+ )
+
+ assert len(captured) == 2
+ torch.testing.assert_close(captured[0], topk_slot_ids[:, 0])
+ torch.testing.assert_close(captured[1], swa_indices)
+
+
+@pytest.mark.parametrize(
+ ("capability_major", "expected_recipe", "expected_tma_aligned"),
+ [
+ (9, (1, 128, 128), False),
+ (10, (1, 1, 128), True),
+ (12, (1, 128, 128), False),
+ ],
+)
+def test_deepseek_v4_fp8_einsum_config_for_sm12x(
+ capability_major: int,
+ expected_recipe: tuple[int, int, int],
+ expected_tma_aligned: bool,
+) -> None:
+ assert _deepseek_v4_fp8_einsum_config(capability_major) == (
+ expected_recipe,
+ expected_tma_aligned,
+ )
+
+
+def test_deepseek_v4_fp8_einsum_uses_sm12x_names() -> None:
+ assert hasattr(fp8_einsum_module, "_deepseek_v4_sm12x_fp8_einsum_kernel")
+ assert hasattr(fp8_einsum_module, "deepseek_v4_sm12x_fp8_einsum")
+ assert not hasattr(fp8_einsum_module, "_deepseek_v4_sm12_fp8_einsum_kernel")
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize("use_e8m0_scale", [False, True])
+def test_deepseek_v4_sm12x_triton_fp8_einsum_matches_deepgemm_reference(
+ use_e8m0_scale: bool,
+) -> None:
+ if use_e8m0_scale and not hasattr(torch, "float8_e8m0fnu"):
+ pytest.skip("torch does not expose float8_e8m0fnu")
+ torch.manual_seed(0)
+ num_tokens = 17
+ num_groups = 4
+ hidden_size = 4096
+ out_rank = 1024
+ recipe = (1, 128, 128)
+
+ a_backing = torch.randn(
+ (num_groups, num_tokens, hidden_size),
+ device="cuda",
+ dtype=torch.bfloat16,
+ ).to(torch.float8_e4m3fn)
+ a = a_backing.transpose(0, 1)
+ a_scale_backing = torch.empty(
+ (num_groups, num_tokens, hidden_size // 128),
+ device="cuda",
+ dtype=torch.float32,
+ ).uniform_(0.01, 0.02)
+ a_scale = a_scale_backing.transpose(0, 1)
+ b_flat = torch.randn(
+ (num_groups * out_rank, hidden_size),
+ device="cuda",
+ dtype=torch.bfloat16,
+ ).to(torch.float8_e4m3fn)
+ b = b_flat.view(num_groups, out_rank, hidden_size)
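+    # e8m0 scales can only encode powers of two, so the e8m0 branch samples
+    # b_scale from exact power-of-two values.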
+ if use_e8m0_scale:
+ scale_choices = torch.tensor(
+ [0.00390625, 0.0078125, 0.015625, 0.03125],
+ device="cuda",
+ dtype=torch.float32,
+ )
+ scale_indices = torch.randint(
+ 0,
+ len(scale_choices),
+ (num_groups * (out_rank // 128), hidden_size // 128),
+ device="cuda",
+ )
+ b_scale_flat = scale_choices[scale_indices].to(torch.float8_e8m0fnu)
+ b_scale_ref_flat = b_scale_flat.to(torch.float32)
+ else:
+ b_scale_flat = torch.empty(
+ (num_groups * (out_rank // 128), hidden_size // 128),
+ device="cuda",
+ dtype=torch.float32,
+ ).uniform_(0.01, 0.02)
+ b_scale_ref_flat = b_scale_flat
+ b_scale_ref = b_scale_ref_flat.view(num_groups, out_rank // 128, hidden_size // 128)
+ expected = torch.empty(
+ (num_tokens, num_groups, out_rank),
+ device="cuda",
+ dtype=torch.bfloat16,
+ )
+ actual = torch.empty_like(expected)
+
+ fp8_einsum(
+ "bhr,hdr->bhd",
+ (a, a_scale),
+ (b, b_scale_ref),
+ expected,
+ recipe=recipe,
+ )
+ deepseek_v4_fp8_einsum(
+ a,
+ a_scale,
+ b_flat,
+ b_scale_flat,
+ actual,
+ "bhr,hdr->bhd",
+ list(recipe),
+ )
+
+ _assert_fp8_einsum_close(actual, expected)
+
+
+def test_deepseek_v4_fp8_einsum_slices_full_group_weight_for_tp(
+ monkeypatch,
+) -> None:
+ captured: dict[str, torch.Tensor] = {}
+ num_tokens = 2
+ local_groups = 4
+ full_groups = 8
+ out_rank = 512
+ hidden_size = 4096
+ recipe = [1, 128, 128]
+
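+    # With 8 full groups, 4 local groups, and TP rank 1, the kernel should be
+    # handed the second half of the flattened weight and scale tensors.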
+ def fake_sm12_fp8_einsum(
+ a: torch.Tensor,
+ a_scale: torch.Tensor,
+ b: torch.Tensor,
+ b_scale: torch.Tensor,
+ out: torch.Tensor,
+ ) -> None:
+ captured["b"] = b
+ captured["b_scale"] = b_scale
+
+ monkeypatch.setattr(
+ deepseek_v4_attention_module.current_platform,
+ "get_device_capability",
+ lambda: SimpleNamespace(major=12),
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "get_tensor_model_parallel_rank",
+ lambda: 1,
+ raising=False,
+ )
+ monkeypatch.setattr(
+ deepseek_v4_attention_module,
+ "deepseek_v4_sm12x_fp8_einsum",
+ fake_sm12_fp8_einsum,
+ )
+
+ a = torch.empty(
+ (num_tokens, local_groups, hidden_size),
+ dtype=torch.float8_e4m3fn,
+ )
+ a_scale = torch.empty(
+ (num_tokens, local_groups, hidden_size // 128),
+ dtype=torch.float32,
+ )
+ b = torch.arange(
+ full_groups * out_rank * hidden_size,
+ dtype=torch.uint8,
+ ).view(torch.float8_e4m3fn)
+ b = b.view(full_groups * out_rank, hidden_size)
+ b_scale = torch.arange(
+ full_groups * (out_rank // 128) * (hidden_size // 128),
+ dtype=torch.float32,
+ ).view(full_groups * (out_rank // 128), hidden_size // 128)
+ out = torch.empty((num_tokens, local_groups, out_rank), dtype=torch.bfloat16)
+
+ deepseek_v4_fp8_einsum(
+ a,
+ a_scale,
+ b,
+ b_scale,
+ out,
+ "bhr,hdr->bhd",
+ recipe,
+ )
+
+ expected_b = b.view(full_groups, out_rank, hidden_size)[local_groups:]
+ expected_b_scale = b_scale.view(full_groups, out_rank // 128, hidden_size // 128)[
+ local_groups:
+ ]
+ assert torch.equal(captured["b"].view(torch.uint8), expected_b.view(torch.uint8))
+ assert torch.equal(captured["b_scale"], expected_b_scale)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_deepseek_v4_sm12x_triton_fp8_einsum_primitive_matches_reference() -> None:
+ torch.manual_seed(0)
+ num_tokens = 17
+ num_groups = 4
+ hidden_size = 4096
+ out_rank = 1024
+ recipe = (1, 128, 128)
+
+ a_backing = torch.randn(
+ (num_groups, num_tokens, hidden_size),
+ device="cuda",
+ dtype=torch.bfloat16,
+ ).to(torch.float8_e4m3fn)
+ a = a_backing.transpose(0, 1)
+ a_scale_backing = torch.empty(
+ (num_groups, num_tokens, hidden_size // 128),
+ device="cuda",
+ dtype=torch.float32,
+ ).uniform_(0.01, 0.02)
+ a_scale = a_scale_backing.transpose(0, 1)
+ b_flat = torch.randn(
+ (num_groups * out_rank, hidden_size),
+ device="cuda",
+ dtype=torch.bfloat16,
+ ).to(torch.float8_e4m3fn)
+ b = b_flat.view(num_groups, out_rank, hidden_size)
+ b_scale_flat = torch.empty(
+ (num_groups * (out_rank // 128), hidden_size // 128),
+ device="cuda",
+ dtype=torch.float32,
+ ).uniform_(0.01, 0.02)
+ b_scale = b_scale_flat.view(num_groups, out_rank // 128, hidden_size // 128)
+ expected = torch.empty(
+ (num_tokens, num_groups, out_rank),
+ device="cuda",
+ dtype=torch.bfloat16,
+ )
+ actual = torch.empty_like(expected)
+
+ fp8_einsum("bhr,hdr->bhd", (a, a_scale), (b, b_scale), expected, recipe=recipe)
+ deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, actual)
+
+ _assert_fp8_einsum_close(actual, expected)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize("num_groups", [1, 2, 4])
+def test_deepseek_v4_sm12x_triton_fp8_einsum_supports_tp_local_group_counts(
+ num_groups: int,
+) -> None:
+ torch.manual_seed(18 + num_groups)
+ num_tokens = 5
+ hidden_size = 4096
+ out_rank = 1024
+ recipe = (1, 128, 128)
+
+ a_backing = torch.randn(
+ (num_groups, num_tokens, hidden_size),
+ device="cuda",
+ dtype=torch.bfloat16,
+ ).to(torch.float8_e4m3fn)
+ a = a_backing.transpose(0, 1)
+ a_scale_backing = torch.empty(
+ (num_groups, num_tokens, hidden_size // 128),
+ device="cuda",
+ dtype=torch.float32,
+ ).uniform_(0.01, 0.02)
+ a_scale = a_scale_backing.transpose(0, 1)
+ b_flat = torch.randn(
+ (num_groups * out_rank, hidden_size),
+ device="cuda",
+ dtype=torch.bfloat16,
+ ).to(torch.float8_e4m3fn)
+ b = b_flat.view(num_groups, out_rank, hidden_size)
+ b_scale_flat = torch.empty(
+ (num_groups * (out_rank // 128), hidden_size // 128),
+ device="cuda",
+ dtype=torch.float32,
+ ).uniform_(0.01, 0.02)
+ b_scale = b_scale_flat.view(num_groups, out_rank // 128, hidden_size // 128)
+ expected = torch.empty(
+ (num_tokens, num_groups, out_rank),
+ device="cuda",
+ dtype=torch.bfloat16,
+ )
+ actual = torch.empty_like(expected)
+
+ fp8_einsum("bhr,hdr->bhd", (a, a_scale), (b, b_scale), expected, recipe=recipe)
+ deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, actual)
+
+ _assert_fp8_einsum_close(actual, expected)
+
+
+def _masked_scores(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ valid_tokens: torch.Tensor,
+ scale: float,
+) -> torch.Tensor:
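+    """Scaled q @ k^T scores with invalid key positions masked to -inf."""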
+ q_bhd = q[:, 0].float() if q.dim() == 4 else q.float()
+ scores = torch.einsum("bhd,btd->bht", q_bhd, kv.float()) * scale
+ return scores.masked_fill(~valid_tokens[:, None, :], float("-inf"))
+
+
+def _golden_no_sink_attention(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ valid_tokens: torch.Tensor,
+ scale: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
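+    """Dense attention over valid keys; returns (output, logsumexp), zeroing
+    rows with no valid key."""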
+ scores = _masked_scores(q, kv, valid_tokens, scale)
+ lse = torch.logsumexp(scores, dim=-1)
+ weights = torch.exp(scores - lse[:, :, None])
+ weights = torch.where(
+ valid_tokens[:, None, :],
+ weights,
+ torch.zeros((), dtype=weights.dtype, device=weights.device),
+ )
+ weights = torch.nan_to_num(weights)
+ output = torch.einsum("bht,btd->bhd", weights, kv.float())
+ valid = valid_tokens.any(dim=-1)
+ output = torch.where(
+ valid[:, None, None],
+ output,
+ torch.zeros((), dtype=output.dtype, device=output.device),
+ )
+ return output, lse
+
+
+def _golden_sink_attention(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ valid_tokens: torch.Tensor,
+ scale: float,
+ attn_sink: torch.Tensor,
+) -> torch.Tensor:
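+    """Dense attention with a per-head sink logit in the softmax denominator."""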
+ scores = _masked_scores(q, kv, valid_tokens, scale)
+ sink = attn_sink[None, :].float()
+ score_max = scores.amax(dim=-1)
+ merge_max = torch.maximum(score_max, sink)
+
+ weights = torch.exp(scores - merge_max[:, :, None])
+ weights = torch.where(
+ valid_tokens[:, None, :],
+ weights,
+ torch.zeros((), dtype=weights.dtype, device=weights.device),
+ )
+ weights = torch.nan_to_num(weights)
+
+ sink_weight = torch.exp(sink - merge_max)
+ sink_weight = torch.nan_to_num(sink_weight)
+ denom = weights.sum(dim=-1) + sink_weight
+ numerator = torch.einsum("bht,btd->bhd", weights, kv.float())
+ return numerator / denom[:, :, None]
+
+
+def _chunked_no_sink_attention(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ valid_tokens: torch.Tensor,
+ scale: float,
+ chunk_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
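+    """Stream KV in chunks through the reference API; should match the dense path."""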
+ q_bhd, max_score, denom, acc = new_reference_attention_state(q)
+ for chunk_start in range(0, kv.shape[1], chunk_size):
+ chunk_end = min(chunk_start + chunk_size, kv.shape[1])
+ max_score, denom, acc = accumulate_reference_attention_chunk(
+ q_bhd=q_bhd,
+ kv=kv[:, chunk_start:chunk_end],
+ valid_tokens=valid_tokens[:, chunk_start:chunk_end],
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ scale=scale,
+ )
+ return finish_reference_attention_no_sink(max_score, denom, acc)
+
+
+def _write_fp8_ds_mla_token(
+ k_cache: torch.Tensor,
+ slot: int,
+ block_size: int,
+) -> torch.Tensor:
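+    """Write one token into the paged fp8_ds_mla cache and return its value.
+
+    Layout per token: 448 FP8 "nope" bytes followed by 64 BF16 rope values,
+    with per-64-dim exponent scales packed at the tail of each block. The
+    return value is the expected dequantized 512-dim BF16 vector.
+    """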
+ block_idx = slot // block_size
+ block_offset = slot % block_size
+
+ values = (
+ (torch.arange(_FP8_DIM, device=k_cache.device, dtype=torch.float32) % 17) - 8
+ ) / 16.0
+ values = values + float(slot) / 32.0
+ scale_exponents = torch.tensor(
+ [-2, -1, 0, 1, 2, -2, 1],
+ device=k_cache.device,
+ dtype=torch.float32,
+ )
+ scales = torch.exp2(scale_exponents)
+ scale_per_dim = scales.repeat_interleave(64)
+
+ fp8_values = (values / scale_per_dim).to(torch.float8_e4m3fn)
+ expected_nope = fp8_values.float() * scale_per_dim
+ rope = (
+ torch.linspace(-1.0, 1.0, _ROPE_DIM, device=k_cache.device) + float(slot) / 16.0
+ ).to(torch.bfloat16)
+
+ flat_block = k_cache[block_idx].view(-1)
+ token_data_start = block_offset * _TOKEN_DATA_SIZE
+ token_scale_start = block_size * _TOKEN_DATA_SIZE + block_offset * _SCALE_DIM
+ flat_block[token_data_start : token_data_start + _FP8_DIM] = fp8_values.view(
+ torch.uint8
+ )
+ flat_block[token_data_start + _FP8_DIM : token_data_start + _TOKEN_DATA_SIZE] = (
+ rope.view(torch.uint8)
+ )
+
+ encoded_scales = (scale_exponents.to(torch.int32) + 127).to(torch.uint8)
+ flat_block[token_scale_start : token_scale_start + encoded_scales.numel()] = (
+ encoded_scales
+ )
+ flat_block[
+ token_scale_start + encoded_scales.numel() : token_scale_start + _SCALE_DIM
+ ] = 127
+
+ return torch.cat([expected_nope, rope.float()]).to(torch.bfloat16)
+
+
+def _materialize_global_fp8_ds_mla_slots(
+ k_cache: torch.Tensor,
+ slot_ids: torch.Tensor,
+ block_size: int,
+) -> torch.Tensor:
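+    """Write every non-negative slot id into the cache and collect the
+    expected dequantized values; negative ids stay zero-filled."""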
+ gathered = torch.zeros(
+ *slot_ids.shape,
+ 512,
+ dtype=torch.bfloat16,
+ device=k_cache.device,
+ )
+ for token_idx, row in enumerate(slot_ids.detach().cpu().tolist()):
+ for candidate_idx, slot in enumerate(row):
+ if slot >= 0:
+ gathered[token_idx, candidate_idx] = _write_fp8_ds_mla_token(
+ k_cache,
+ slot,
+ block_size,
+ )
+ return gathered
+
+
+def _materialize_paged_fp8_ds_mla_window(
+ k_cache: torch.Tensor,
+ seq_lens: torch.Tensor,
+ gather_lens: torch.Tensor,
+ block_table: torch.Tensor,
+ block_size: int,
+) -> torch.Tensor:
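+    """Populate the trailing gather_len positions of each sequence via the
+    block table and return their expected dequantized values."""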
+ seq_lens_cpu = seq_lens.detach().cpu().tolist()
+ gather_lens_cpu = gather_lens.detach().cpu().tolist()
+ block_table_cpu = block_table.detach().cpu().tolist()
+ max_gather_len = max(gather_lens_cpu)
+ gathered = torch.zeros(
+ len(seq_lens_cpu),
+ max_gather_len,
+ 512,
+ dtype=torch.bfloat16,
+ device=k_cache.device,
+ )
+ for token_idx, (seq_len, gather_len) in enumerate(
+ zip(seq_lens_cpu, gather_lens_cpu)
+ ):
+ start_pos = seq_len - gather_len
+ for gather_idx in range(gather_len):
+ logical_pos = start_pos + gather_idx
+ logical_block = logical_pos // block_size
+ block_offset = logical_pos % block_size
+ physical_block = block_table_cpu[token_idx][logical_block]
+ physical_slot = physical_block * block_size + block_offset
+ gathered[token_idx, gather_idx] = _write_fp8_ds_mla_token(
+ k_cache,
+ physical_slot,
+ block_size,
+ )
+ return gathered
+
+
+def test_reference_attention_no_sink_matches_logsumexp() -> None:
+ torch.manual_seed(0)
+ scale = 0.25
+ q = torch.randn(3, 4, 5)
+ kv = torch.randn(3, 6, 5)
+ valid_tokens = torch.tensor(
+ [
+ [True, True, False, True, False, False],
+ [False, False, False, False, False, False],
+ [True, False, True, True, True, False],
+ ],
+ dtype=torch.bool,
+ )
+ output, lse = reference_attention_no_sink(q, kv, valid_tokens, scale)
+ expected_output, expected_lse = _golden_no_sink_attention(
+ q,
+ kv,
+ valid_tokens,
+ scale,
+ )
+
+ torch.testing.assert_close(output, expected_output, rtol=1e-6, atol=1e-6)
+ torch.testing.assert_close(lse, expected_lse, rtol=1e-6, atol=1e-6)
+
+
+def test_reference_attention_ignores_nan_kv_for_invalid_tokens() -> None:
+ torch.manual_seed(24)
+ q = torch.randn(2, 1, 3, 8)
+ kv = torch.randn(2, 4, 8)
+ kv[:, 2:] = float("nan")
+ valid_tokens = torch.tensor(
+ [[True, True, False, False], [True, False, False, False]],
+ dtype=torch.bool,
+ )
+
+ output, lse = reference_attention_no_sink(
+ q=q,
+ kv=kv,
+ valid_tokens=valid_tokens,
+ scale=0.125,
+ )
+
+ assert torch.isfinite(output).all()
+ assert torch.isfinite(lse).all()
+
+
+def test_sink_aware_reference_attention_matches_dense_golden() -> None:
+ torch.manual_seed(1)
+ scale = 0.125
+ q = torch.randn(3, 1, 4, 5)
+ kv = torch.randn(3, 6, 5)
+ valid_tokens = torch.tensor(
+ [
+ [True, True, False, True, False, False],
+ [False, False, False, False, False, False],
+ [False, True, True, False, True, True],
+ ],
+ dtype=torch.bool,
+ )
+ sink = torch.tensor([-1.0, 0.25, 1.5, -0.5])
+ output = torch.empty(3, 4, 5)
+ sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, output)
+ expected = _golden_sink_attention(q, kv, valid_tokens, scale, sink)
+
+ torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6)
+
+
+def test_lse_merge_with_sink_matches_concatenated_attention() -> None:
+ torch.manual_seed(2)
+ scale = 0.2
+ q = torch.randn(4, 3, 7)
+ compressed_kv = torch.randn(4, 5, 7)
+ swa_kv = torch.randn(4, 3, 7)
+ compressed_kv[:, 1] = compressed_kv[:, 0]
+ swa_kv[:, 2] = compressed_kv[:, 0]
+ compressed_valid = torch.tensor(
+ [
+ [True, True, False, True, False],
+ [False, False, False, False, False],
+ [True, False, True, True, False],
+ [False, False, False, False, False],
+ ],
+ dtype=torch.bool,
+ )
+ swa_valid = torch.tensor(
+ [
+ [True, False, True],
+ [True, True, False],
+ [False, False, False],
+ [False, False, False],
+ ],
+ dtype=torch.bool,
+ )
+ sink = torch.tensor([-0.25, 0.75, 1.25])
+ output = torch.empty(4, 3, 7)
+ comp_output, comp_lse = reference_attention_no_sink(
+ q,
+ compressed_kv,
+ compressed_valid,
+ scale,
+ )
+ swa_output, swa_lse = reference_attention_no_sink(q, swa_kv, swa_valid, scale)
+ merge_reference_attention_with_sink(
+ subset_outputs=[comp_output, swa_output],
+ subset_lses=[comp_lse, swa_lse],
+ attn_sink=sink,
+ output=output,
+ )
+
+ expected = _golden_sink_attention(
+ q,
+ torch.cat([compressed_kv, swa_kv], dim=1),
+ torch.cat([compressed_valid, swa_valid], dim=1),
+ scale,
+ sink,
+ )
+ torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6)
+ assert torch.equal(output[3], torch.zeros_like(output[3]))
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_lse_merge_with_sink_matches_reference() -> None:
+ torch.manual_seed(5)
+ comp_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32)
+ swa_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32)
+ comp_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32)
+ swa_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32)
+ comp_lse[1, 2] = float("-inf")
+ swa_lse[2, 1] = float("-inf")
+ sink = torch.tensor([-0.5, 0.25, 1.0, -1.5], device="cuda")
+
+ output = torch.empty(3, 4, 9, device="cuda", dtype=torch.bfloat16)
+ expected = torch.empty_like(output)
+ merge_two_sparse_mla_subsets_with_sink(
+ subset0_output=comp_output,
+ subset0_lse=comp_lse,
+ subset1_output=swa_output,
+ subset1_lse=swa_lse,
+ attn_sink=sink,
+ output=output,
+ )
+ merge_reference_attention_with_sink(
+ subset_outputs=[comp_output, swa_output],
+ subset_lses=[comp_lse, swa_lse],
+ attn_sink=sink,
+ output=expected,
+ )
+
+ torch.testing.assert_close(output.float(), expected.float(), rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_single_lse_merge_with_sink_matches_reference() -> None:
+ torch.manual_seed(14)
+ subset_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32)
+ subset_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32)
+ subset_lse[1, 2] = float("-inf")
+ sink = torch.tensor([-0.5, 0.25, 1.0, -1.5], device="cuda")
+
+ output = torch.empty(3, 4, 9, device="cuda", dtype=torch.bfloat16)
+ expected = torch.empty_like(output)
+ merge_sparse_mla_subset_with_sink(
+ subset_output=subset_output,
+ subset_lse=subset_lse,
+ attn_sink=sink,
+ output=output,
+ )
+ merge_reference_attention_with_sink(
+ subset_outputs=[subset_output],
+ subset_lses=[subset_lse],
+ attn_sink=sink,
+ output=expected,
+ )
+
+ torch.testing.assert_close(output.float(), expected.float(), rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_finish_with_sink_matches_finish_then_merge_reference() -> None:
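+    # The fused finish-with-sink path must match finishing the state first and
+    # then merging with the sink, and must leave output heads beyond the
+    # state's head count (heads 3 and 4 here) at their original fill value.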
+ torch.manual_seed(18)
+ max_score = torch.randn(4, 3, device="cuda", dtype=torch.float32)
+ denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1
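+    # Zero out one (token, head) state to exercise the empty-state handling
+    # (denominator 0, max score -inf).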
+ denom[1, 2] = 0.0
+ max_score[1, 2] = float("-inf")
+ acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32)
+ sink = torch.tensor(
+ [-0.5, 0.25, 1.0, -float("inf"), -float("inf")],
+ device="cuda",
+ dtype=torch.float32,
+ )
+
+ output = torch.full((4, 5, 17), -7.0, device="cuda", dtype=torch.bfloat16)
+ finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, output)
+
+ subset_output = torch.empty_like(acc)
+ subset_lse = torch.empty_like(max_score)
+ finish_gathered_sparse_mla_attention(
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ output=subset_output,
+ lse=subset_lse,
+ )
+ expected = torch.empty(4, 3, 17, device="cuda", dtype=torch.bfloat16)
+ merge_reference_attention_with_sink(
+ subset_outputs=[subset_output],
+ subset_lses=[subset_lse],
+ attn_sink=sink[:3],
+ output=expected,
+ )
+
+ torch.testing.assert_close(
+ output[:, :3].float(), expected.float(), rtol=1e-2, atol=1e-2
+ )
+ torch.testing.assert_close(
+ output[:, 3:].float(),
+ torch.full_like(output[:, 3:].float(), -7.0),
+ rtol=0,
+ atol=0,
+ )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_finish_with_sink_returns_zero_when_no_tokens_or_sink() -> None:
+ max_score = torch.full((2, 3), float("-inf"), device="cuda")
+ denom = torch.zeros((2, 3), device="cuda")
+ acc = torch.full((2, 3, 17), float("nan"), device="cuda")
+ sink = torch.full((3,), float("-inf"), device="cuda")
+
+ single_output = torch.full((2, 3, 17), 7.0, device="cuda", dtype=torch.bfloat16)
+ finish_sparse_mla_attention_with_sink(
+ max_score,
+ denom,
+ acc,
+ sink,
+ output=single_output,
+ )
+ torch.testing.assert_close(
+ single_output.float(),
+ torch.zeros_like(single_output.float()),
+ rtol=0,
+ atol=0,
+ )
+
+ two_output = torch.full((2, 3, 17), 7.0, device="cuda", dtype=torch.bfloat16)
+ finish_two_sparse_mla_attention_states_with_sink(
+ max_score,
+ denom,
+ acc,
+ max_score,
+ denom,
+ acc,
+ sink,
+ output=two_output,
+ )
+ torch.testing.assert_close(
+ two_output.float(),
+ torch.zeros_like(two_output.float()),
+ rtol=0,
+ atol=0,
+ )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_finish_two_states_with_sink_matches_finish_then_merge() -> None:
+ torch.manual_seed(22)
+ comp_max = torch.randn(4, 3, device="cuda", dtype=torch.float32)
+ comp_denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1
+ comp_acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32)
+ swa_max = torch.randn(4, 3, device="cuda", dtype=torch.float32)
+ swa_denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1
+ swa_acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32)
+ sink = torch.tensor(
+ [-0.5, 0.25, 1.0, -float("inf"), -float("inf")],
+ device="cuda",
+ dtype=torch.float32,
+ )
+
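+    # Empty states: the first subset is empty at (0, 1), the second at (2, 0),
+    # and both subsets are empty at (3, 2).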
+ comp_denom[0, 1] = 0.0
+ comp_max[0, 1] = float("-inf")
+ swa_denom[2, 0] = 0.0
+ swa_max[2, 0] = float("-inf")
+ comp_denom[3, 2] = 0.0
+ comp_max[3, 2] = float("-inf")
+ swa_denom[3, 2] = 0.0
+ swa_max[3, 2] = float("-inf")
+
+ output = torch.full((4, 5, 17), -7.0, device="cuda", dtype=torch.bfloat16)
+ finish_two_sparse_mla_attention_states_with_sink(
+ comp_max,
+ comp_denom,
+ comp_acc,
+ swa_max,
+ swa_denom,
+ swa_acc,
+ sink,
+ output,
+ )
+
+ comp_output = torch.empty_like(comp_acc)
+ comp_lse = torch.empty_like(comp_max)
+ swa_output = torch.empty_like(swa_acc)
+ swa_lse = torch.empty_like(swa_max)
+ finish_gathered_sparse_mla_attention(
+ comp_max,
+ comp_denom,
+ comp_acc,
+ comp_output,
+ comp_lse,
+ )
+ finish_gathered_sparse_mla_attention(
+ swa_max,
+ swa_denom,
+ swa_acc,
+ swa_output,
+ swa_lse,
+ )
+ expected = torch.empty(4, 3, 17, device="cuda", dtype=torch.bfloat16)
+ merge_two_sparse_mla_subsets_with_sink(
+ subset0_output=comp_output,
+ subset0_lse=comp_lse,
+ subset1_output=swa_output,
+ subset1_lse=swa_lse,
+ attn_sink=sink[:3],
+ output=expected,
+ )
+
+ torch.testing.assert_close(
+ output[:, :3].float(), expected.float(), rtol=1e-2, atol=1e-2
+ )
+ torch.testing.assert_close(
+ output[:, 3:].float(),
+ torch.full_like(output[:, 3:].float(), -7.0),
+ rtol=0,
+ atol=0,
+ )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize("head_dim", [16, 512])
+def test_triton_gathered_attention_chunk_matches_reference(head_dim: int) -> None:
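+    # Accumulate candidates in two chunks, masking via -1 slot ids and
+    # per-token lens; only the first three query heads are accumulated,
+    # matching the three-head state buffers and the reference q_active.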
+ torch.manual_seed(6)
+ scale = 0.125
+ q = torch.randn(2, 1, 5, head_dim, device="cuda", dtype=torch.bfloat16)
+ q_active = q[:, :, :3]
+ kv = torch.randn(2, 5, head_dim, device="cuda", dtype=torch.bfloat16)
+ slot_ids = torch.tensor(
+ [
+ [0, 1, -1, 3, 4],
+ [5, -1, 7, 8, -1],
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda")
+ max_score = torch.full((2, 3), float("-inf"), device="cuda")
+ denom = torch.zeros((2, 3), device="cuda")
+ acc = torch.zeros((2, 3, head_dim), device="cuda")
+
+ accumulate_gathered_sparse_mla_attention_chunk(
+ q=q,
+ kv=kv[:, :2],
+ slot_ids=slot_ids[:, :2],
+ lens=lens,
+ candidate_offset=0,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ )
+ accumulate_gathered_sparse_mla_attention_chunk(
+ q=q,
+ kv=kv[:, 2:],
+ slot_ids=slot_ids[:, 2:],
+ lens=lens,
+ candidate_offset=2,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ )
+
+ output = torch.empty_like(acc)
+ lse = torch.empty_like(max_score)
+ finish_gathered_sparse_mla_attention(
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ output=output,
+ lse=lse,
+ )
+
+ offsets = torch.arange(slot_ids.shape[1], device="cuda")
+ valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0)
+ expected_output, expected_lse = reference_attention_no_sink(
+ q_active,
+ kv,
+ valid_tokens,
+ scale,
+ )
+ torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2)
+ torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_gathered_attention_chunk_matches_reference_without_slot_ids() -> None:
+ torch.manual_seed(8)
+ scale = 0.2
+ q = torch.randn(3, 1, 2, 32, device="cuda", dtype=torch.bfloat16)
+ kv = torch.randn(3, 6, 32, device="cuda", dtype=torch.bfloat16)
+ lens = torch.tensor([6, 3, 0], dtype=torch.int32, device="cuda")
+ max_score = torch.full((3, 2), float("-inf"), device="cuda")
+ denom = torch.zeros((3, 2), device="cuda")
+ acc = torch.zeros((3, 2, 32), device="cuda")
+
+ accumulate_gathered_sparse_mla_attention_chunk(
+ q=q,
+ kv=kv,
+ slot_ids=None,
+ lens=lens,
+ candidate_offset=0,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ )
+
+ output = torch.empty_like(acc)
+ lse = torch.empty_like(max_score)
+ finish_gathered_sparse_mla_attention(
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ output=output,
+ lse=lse,
+ )
+
+ offsets = torch.arange(kv.shape[1], device="cuda")
+ valid_tokens = offsets[None, :] < lens[:, None]
+ expected_output, expected_lse = reference_attention_no_sink(
+ q,
+ kv,
+ valid_tokens,
+ scale,
+ )
+ torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2)
+ torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_dequantize_global_slots_k_cache_fp8_ds_mla_layout() -> None:
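+    # Dequantize global slots from an fp8_ds_mla cache and compare against the
+    # per-slot expected rows; a 3D index tensor (via unsqueeze) must also work.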
+ block_size = 4
+ num_blocks = 2
+ k_cache = torch.zeros(
+ num_blocks,
+ block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ expected_by_slot = {
+ slot: _write_fp8_ds_mla_token(k_cache, slot, block_size) for slot in (0, 3, 4)
+ }
+ slot_ids = torch.tensor(
+ [
+ [0, 3, -1, 4],
+ [4, 0, 3, -1],
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+
+ output = torch.empty(2, 4, 512, dtype=torch.bfloat16, device="cuda")
+ dequantize_global_slots_k_cache(output, k_cache, slot_ids, block_size)
+
+ expected = torch.zeros_like(output)
+ for token_idx in range(slot_ids.shape[0]):
+ for topk_idx in range(slot_ids.shape[1]):
+ slot = int(slot_ids[token_idx, topk_idx].item())
+ if slot >= 0:
+ expected[token_idx, topk_idx] = expected_by_slot[slot]
+
+ torch.testing.assert_close(output.float(), expected.float(), rtol=0, atol=0)
+
+ output_from_3d_indices = torch.empty_like(output)
+ dequantize_global_slots_k_cache(
+ output_from_3d_indices,
+ k_cache,
+ slot_ids.unsqueeze(1),
+ block_size,
+ )
+ torch.testing.assert_close(
+ output_from_3d_indices.float(),
+ expected.float(),
+ rtol=0,
+ atol=0,
+ )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_dequantize_combined_sparse_mla_decode_kv_writes_direct_views() -> None:
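+    # The combined decode-KV dequantizer writes compressed top-k slots into the
+    # leading columns and the SWA window into the remaining columns of one
+    # buffer; both halves must match the standalone dequantize helpers.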
+ compressed_block_size = 4
+ swa_block_size = 4
+ compressed_cache = torch.zeros(
+ 2,
+ compressed_block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ swa_cache = torch.zeros(
+ 3,
+ swa_block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ for slot in (0, 3, 4):
+ _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size)
+ for slot in (0, 1, 2, 3, 4):
+ _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size)
+
+ compressed_slot_ids = torch.tensor(
+ [[0, 3, -1], [4, 0, 3]],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ seq_lens = torch.tensor([5, 7], dtype=torch.int32, device="cuda")
+ swa_lens = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
+ block_table = torch.tensor(
+ [[0, 1, 2], [2, 0, 1]],
+ dtype=torch.int32,
+ device="cuda",
+ )
+
+ combined = torch.full(
+ (2, 6, 512),
+ -7,
+ dtype=torch.bfloat16,
+ device="cuda",
+ )
+ dequantize_combined_sparse_mla_decode_kv(
+ combined,
+ compressed_cache,
+ compressed_slot_ids,
+ compressed_block_size,
+ swa_cache,
+ seq_lens,
+ swa_lens,
+ block_table,
+ swa_block_size,
+ )
+
+ expected_comp = torch.empty(2, 3, 512, dtype=torch.bfloat16, device="cuda")
+ expected_swa = torch.full(
+ (2, 3, 512),
+ -7,
+ dtype=torch.bfloat16,
+ device="cuda",
+ )
+ dequantize_global_slots_k_cache(
+ expected_comp,
+ compressed_cache,
+ compressed_slot_ids,
+ compressed_block_size,
+ )
+ dequantize_and_gather_k_cache(
+ expected_swa,
+ swa_cache,
+ seq_lens=seq_lens,
+ gather_lens=swa_lens,
+ block_table=block_table,
+ block_size=swa_block_size,
+ offset=0,
+ )
+ expected = torch.full_like(combined, -7)
+ expected[:, :3].copy_(expected_comp)
+ expected[:, 3:].copy_(expected_swa)
+
+ torch.testing.assert_close(combined.float(), expected.float(), rtol=0, atol=0)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_fp8ds_global_slots_attention_chunk_matches_reference() -> None:
+ torch.manual_seed(10)
+ block_size = 4
+ num_blocks = 3
+ k_cache = torch.zeros(
+ num_blocks,
+ block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ expected_by_slot = {
+ slot: _write_fp8_ds_mla_token(k_cache, slot, block_size)
+ for slot in (0, 1, 3, 4, 7, 8)
+ }
+ slot_ids = torch.tensor(
+ [
+ [0, 3, -1, 8, 1],
+ [7, -1, 4, 0, 8],
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda")
+ q = torch.randn(2, 1, 3, 512, device="cuda", dtype=torch.bfloat16)
+ scale = 0.0625
+
+ max_score = torch.full((2, 3), float("-inf"), device="cuda")
+ denom = torch.zeros((2, 3), device="cuda")
+ acc = torch.zeros((2, 3, 512), device="cuda")
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk(
+ q=q,
+ k_cache=k_cache,
+ slot_ids=slot_ids[:, :2],
+ lens=lens,
+ block_size=block_size,
+ candidate_offset=0,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ )
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk(
+ q=q,
+ k_cache=k_cache,
+ slot_ids=slot_ids[:, 2:],
+ lens=lens,
+ block_size=block_size,
+ candidate_offset=2,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ )
+
+ output = torch.empty_like(acc)
+ lse = torch.empty_like(max_score)
+ finish_gathered_sparse_mla_attention(
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ output=output,
+ lse=lse,
+ )
+
+ gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16)
+ for token_idx in range(slot_ids.shape[0]):
+ for topk_idx in range(slot_ids.shape[1]):
+ slot = int(slot_ids[token_idx, topk_idx].item())
+ if slot >= 0:
+ gathered[token_idx, topk_idx] = expected_by_slot[slot]
+ offsets = torch.arange(slot_ids.shape[1], device="cuda")
+ valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0)
+ expected_output, expected_lse = reference_attention_no_sink(
+ q,
+ gathered,
+ valid_tokens,
+ scale,
+ )
+
+ torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2)
+ torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize("head_block_size", [1, 2, 4])
+def test_triton_fp8ds_global_slots_multihead_attention_matches_reference(
+ head_block_size: int,
+) -> None:
+ torch.manual_seed(19)
+ block_size = 4
+ num_blocks = 3
+ k_cache = torch.zeros(
+ num_blocks,
+ block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ expected_by_slot = {
+ slot: _write_fp8_ds_mla_token(k_cache, slot, block_size)
+ for slot in (0, 1, 3, 4, 7, 8)
+ }
+ slot_ids = torch.tensor(
+ [
+ [0, 3, -1, 8, 1],
+ [7, -1, 4, 0, 8],
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda")
+ q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16)
+ q_active = q[:, :, :5]
+ scale = 0.0625
+
+ max_score = torch.full((2, 5), float("-inf"), device="cuda")
+ denom = torch.zeros((2, 5), device="cuda")
+ acc = torch.zeros((2, 5, 512), device="cuda")
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=k_cache,
+ slot_ids=slot_ids[:, :2],
+ lens=lens,
+ block_size=block_size,
+ candidate_offset=0,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ head_block_size=head_block_size,
+ )
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=k_cache,
+ slot_ids=slot_ids[:, 2:],
+ lens=lens,
+ block_size=block_size,
+ candidate_offset=2,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ head_block_size=head_block_size,
+ )
+
+ output = torch.empty_like(acc)
+ lse = torch.empty_like(max_score)
+ finish_gathered_sparse_mla_attention(
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ output=output,
+ lse=lse,
+ )
+
+ gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16)
+ for token_idx in range(slot_ids.shape[0]):
+ for topk_idx in range(slot_ids.shape[1]):
+ slot = int(slot_ids[token_idx, topk_idx].item())
+ if slot >= 0:
+ gathered[token_idx, topk_idx] = expected_by_slot[slot]
+ offsets = torch.arange(slot_ids.shape[1], device="cuda")
+ valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0)
+ expected_output, expected_lse = reference_attention_no_sink(
+ q_active,
+ gathered,
+ valid_tokens,
+ scale,
+ )
+
+ torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2)
+ torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_fp8ds_paged_attention_chunk_matches_reference() -> None:
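+    # Gather the trailing window of each sequence through the block table in
+    # two candidate chunks and check the finished state against the bf16
+    # reference over the gathered tokens.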
+ torch.manual_seed(12)
+ block_size = 4
+ k_cache = torch.zeros(
+ 3,
+ block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ block_table = torch.tensor(
+ [
+ [1, 0, 2],
+ [2, 1, 0],
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ seq_lens = torch.tensor([6, 9], dtype=torch.int32, device="cuda")
+ gather_lens = torch.tensor([3, 4], dtype=torch.int32, device="cuda")
+ q = torch.randn(2, 1, 3, 512, device="cuda", dtype=torch.bfloat16)
+ scale = 0.0625
+
+ gathered = torch.zeros(2, 4, 512, device="cuda", dtype=torch.bfloat16)
+ expected_by_slot: dict[int, torch.Tensor] = {}
+ for token_idx in range(seq_lens.shape[0]):
+ start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item())
+ for gather_idx in range(int(gather_lens[token_idx].item())):
+ pos = start_pos + gather_idx
+ block_idx = pos // block_size
+ block_offset = pos % block_size
+ physical_block = int(block_table[token_idx, block_idx].item())
+ slot = physical_block * block_size + block_offset
+ expected_by_slot.setdefault(
+ slot,
+ _write_fp8_ds_mla_token(k_cache, slot, block_size),
+ )
+ gathered[token_idx, gather_idx] = expected_by_slot[slot]
+
+ max_score = torch.full((2, 3), float("-inf"), device="cuda")
+ denom = torch.zeros((2, 3), device="cuda")
+ acc = torch.zeros((2, 3, 512), device="cuda")
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk(
+ q=q,
+ k_cache=k_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=block_size,
+ candidate_offset=0,
+ num_candidates=2,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ )
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk(
+ q=q,
+ k_cache=k_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=block_size,
+ candidate_offset=2,
+ num_candidates=2,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ )
+
+ output = torch.empty_like(acc)
+ lse = torch.empty_like(max_score)
+ finish_gathered_sparse_mla_attention(
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ output=output,
+ lse=lse,
+ )
+
+ offsets = torch.arange(gathered.shape[1], device="cuda")
+ valid_tokens = offsets[None, :] < gather_lens[:, None]
+ expected_output, expected_lse = reference_attention_no_sink(
+ q,
+ gathered,
+ valid_tokens,
+ scale,
+ )
+
+ torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2)
+ torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize("head_block_size", [1, 2, 4])
+def test_triton_fp8ds_paged_multihead_attention_matches_singlehead_and_reference(
+ head_block_size: int,
+) -> None:
+ torch.manual_seed(23)
+ block_size = 4
+ k_cache = torch.zeros(
+ 4,
+ block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ block_table = torch.tensor(
+ [
+ [1, 0, 2, 3],
+ [2, 3, 1, 0],
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda")
+ gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
+ q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16)
+ q_active = q[:, :, :5]
+ scale = 0.0625
+
+ gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16)
+ expected_by_slot: dict[int, torch.Tensor] = {}
+ for token_idx in range(seq_lens.shape[0]):
+ start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item())
+ for gather_idx in range(int(gather_lens[token_idx].item())):
+ pos = start_pos + gather_idx
+ block_idx = pos // block_size
+ block_offset = pos % block_size
+ physical_block = int(block_table[token_idx, block_idx].item())
+ slot = physical_block * block_size + block_offset
+ expected_by_slot.setdefault(
+ slot,
+ _write_fp8_ds_mla_token(k_cache, slot, block_size),
+ )
+ gathered[token_idx, gather_idx] = expected_by_slot[slot]
+
+ single_max = torch.full((2, 5), float("-inf"), device="cuda")
+ single_denom = torch.zeros((2, 5), device="cuda")
+ single_acc = torch.zeros((2, 5, 512), device="cuda")
+ multi_max = torch.full_like(single_max, float("-inf"))
+ multi_denom = torch.zeros_like(single_denom)
+ multi_acc = torch.zeros_like(single_acc)
+
+ for candidate_offset, num_candidates in ((0, 2), (2, 3)):
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk(
+ q=q,
+ k_cache=k_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=block_size,
+ candidate_offset=candidate_offset,
+ num_candidates=num_candidates,
+ scale=scale,
+ max_score=single_max,
+ denom=single_denom,
+ acc=single_acc,
+ )
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=k_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=block_size,
+ candidate_offset=candidate_offset,
+ num_candidates=num_candidates,
+ scale=scale,
+ max_score=multi_max,
+ denom=multi_denom,
+ acc=multi_acc,
+ head_block_size=head_block_size,
+ )
+
+ torch.testing.assert_close(multi_max, single_max, rtol=2e-2, atol=2e-2)
+ torch.testing.assert_close(multi_denom, single_denom, rtol=2e-2, atol=2e-2)
+ torch.testing.assert_close(multi_acc, single_acc, rtol=2e-2, atol=2e-2)
+
+ output = torch.empty_like(multi_acc)
+ lse = torch.empty_like(multi_max)
+ finish_gathered_sparse_mla_attention(
+ max_score=multi_max,
+ denom=multi_denom,
+ acc=multi_acc,
+ output=output,
+ lse=lse,
+ )
+ offsets = torch.arange(gathered.shape[1], device="cuda")
+ valid_tokens = offsets[None, :] < gather_lens[:, None]
+ expected_output, expected_lse = reference_attention_no_sink(
+ q_active,
+ gathered,
+ valid_tokens,
+ scale,
+ )
+
+ torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2)
+ torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_fp8ds_paged_attention_with_sink_matches_reference() -> None:
+ torch.manual_seed(15)
+ block_size = 4
+ k_cache = torch.zeros(
+ 3,
+ block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ block_table = torch.tensor([[1, 0, 2]], dtype=torch.int32, device="cuda")
+ seq_lens = torch.tensor([7], dtype=torch.int32, device="cuda")
+ gather_lens = torch.tensor([4], dtype=torch.int32, device="cuda")
+ q = torch.randn(1, 1, 3, 512, device="cuda", dtype=torch.bfloat16)
+ sink = torch.tensor([-0.25, 0.5, 1.25], device="cuda")
+ scale = 0.0625
+
+ gathered = torch.zeros(1, 4, 512, device="cuda", dtype=torch.bfloat16)
+ expected_by_slot: dict[int, torch.Tensor] = {}
+ start_pos = int(seq_lens[0].item() - gather_lens[0].item())
+ for gather_idx in range(int(gather_lens[0].item())):
+ pos = start_pos + gather_idx
+ physical_block = int(block_table[0, pos // block_size].item())
+ slot = physical_block * block_size + pos % block_size
+ expected_by_slot.setdefault(
+ slot,
+ _write_fp8_ds_mla_token(k_cache, slot, block_size),
+ )
+ gathered[0, gather_idx] = expected_by_slot[slot]
+
+ max_score = torch.full((1, 3), float("-inf"), device="cuda")
+ denom = torch.zeros((1, 3), device="cuda")
+ acc = torch.zeros((1, 3, 512), device="cuda")
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk(
+ q=q,
+ k_cache=k_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=block_size,
+ candidate_offset=0,
+ num_candidates=4,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ )
+ subset_output = torch.empty_like(acc)
+ subset_lse = torch.empty_like(max_score)
+ finish_gathered_sparse_mla_attention(
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ output=subset_output,
+ lse=subset_lse,
+ )
+
+ output = torch.empty(1, 3, 512, device="cuda", dtype=torch.bfloat16)
+ merge_sparse_mla_subset_with_sink(
+ subset_output=subset_output,
+ subset_lse=subset_lse,
+ attn_sink=sink,
+ output=output,
+ )
+ valid_tokens = torch.ones(1, 4, device="cuda", dtype=torch.bool)
+ expected = _golden_sink_attention(q, gathered, valid_tokens, scale, sink)
+
+ torch.testing.assert_close(output.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize("head_block_size", [1, 2, 4])
+def test_triton_fp8ds_paged_attention_with_sink_direct_matches_state_path(
+ head_block_size: int,
+) -> None:
+ torch.manual_seed(29)
+ block_size = 4
+ k_cache = torch.zeros(
+ 4,
+ block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ block_table = torch.tensor(
+ [[1, 0, 2, 3], [2, 3, 1, 0]],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda")
+ gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
+ q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16)
+ sink = torch.linspace(-0.5, 0.5, 5, device="cuda")
+ scale = 0.0625
+
+ for token_idx in range(seq_lens.shape[0]):
+ start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item())
+ for gather_idx in range(int(gather_lens[token_idx].item())):
+ pos = start_pos + gather_idx
+ physical_block = int(block_table[token_idx, pos // block_size].item())
+ slot = physical_block * block_size + pos % block_size
+ _write_fp8_ds_mla_token(k_cache, slot, block_size)
+
+ max_score = torch.full((2, 5), float("-inf"), device="cuda")
+ denom = torch.zeros((2, 5), device="cuda")
+ acc = torch.zeros((2, 5, 512), device="cuda")
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=k_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=block_size,
+ candidate_offset=0,
+ num_candidates=5,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ head_block_size=1,
+ )
+ expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16)
+ finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, expected)
+
+ actual = torch.empty_like(expected)
+ fp8ds_paged_sparse_mla_attention_with_sink_multihead(
+ q=q,
+ k_cache=k_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=block_size,
+ candidate_offset=0,
+ num_candidates=5,
+ scale=scale,
+ attn_sink=sink,
+ output=actual,
+ head_block_size=head_block_size,
+ )
+
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize("head_block_size", [1, 2, 4])
+def test_triton_fp8ds_global_paged_attention_with_sink_direct_matches_state_path(
+ head_block_size: int,
+) -> None:
+ torch.manual_seed(31)
+ compressed_block_size = 4
+ swa_block_size = 4
+ compressed_cache = torch.zeros(
+ 4,
+ compressed_block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ swa_cache = torch.zeros(
+ 4,
+ swa_block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ slot_ids = torch.tensor(
+ [[0, 3, -1, 8, 1], [7, -1, 4, 0, 8]],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ topk_lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda")
+ block_table = torch.tensor(
+ [[1, 0, 2, 3], [2, 3, 1, 0]],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda")
+ gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
+ q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16)
+ sink = torch.linspace(-1.0, 1.0, 5, device="cuda")
+ scale = 0.0625
+
+ for slot in (0, 1, 3, 4, 7, 8):
+ _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size)
+ for token_idx in range(seq_lens.shape[0]):
+ start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item())
+ for gather_idx in range(int(gather_lens[token_idx].item())):
+ pos = start_pos + gather_idx
+ physical_block = int(block_table[token_idx, pos // swa_block_size].item())
+ slot = physical_block * swa_block_size + pos % swa_block_size
+ _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size)
+
+ comp_max = torch.full((2, 5), float("-inf"), device="cuda")
+ comp_denom = torch.zeros((2, 5), device="cuda")
+ comp_acc = torch.zeros((2, 5, 512), device="cuda")
+ swa_max = torch.full((2, 5), float("-inf"), device="cuda")
+ swa_denom = torch.zeros((2, 5), device="cuda")
+ swa_acc = torch.zeros((2, 5, 512), device="cuda")
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=compressed_cache,
+ slot_ids=slot_ids,
+ lens=topk_lens,
+ block_size=compressed_block_size,
+ candidate_offset=0,
+ scale=scale,
+ max_score=comp_max,
+ denom=comp_denom,
+ acc=comp_acc,
+ head_block_size=1,
+ )
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=swa_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=swa_block_size,
+ candidate_offset=0,
+ num_candidates=5,
+ scale=scale,
+ max_score=swa_max,
+ denom=swa_denom,
+ acc=swa_acc,
+ head_block_size=1,
+ )
+ expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16)
+ finish_two_sparse_mla_attention_states_with_sink(
+ comp_max,
+ comp_denom,
+ comp_acc,
+ swa_max,
+ swa_denom,
+ swa_acc,
+ sink,
+ expected,
+ )
+
+ actual = torch.empty_like(expected)
+ fp8ds_global_paged_sparse_mla_attention_with_sink_multihead(
+ q=q,
+ compressed_k_cache=compressed_cache,
+ slot_ids=slot_ids,
+ topk_lens=topk_lens,
+ compressed_block_size=compressed_block_size,
+ swa_k_cache=swa_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ swa_block_size=swa_block_size,
+ num_compressed_candidates=5,
+ num_swa_candidates=5,
+ scale=scale,
+ attn_sink=sink,
+ output=actual,
+ head_block_size=head_block_size,
+ )
+
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_global_paged_decode_matches_dense_golden_long_offsets() -> None:
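+    # End-to-end decode over larger caches: compressed top-k slots plus a paged
+    # SWA window, merged with an attention sink and compared against the dense
+    # bf16 golden over the concatenated KV.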
+ torch.manual_seed(71)
+ compressed_block_size = 4
+ swa_block_size = 4
+ compressed_cache = torch.zeros(
+ 8,
+ compressed_block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ swa_cache = torch.zeros(
+ 8,
+ swa_block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ compressed_slot_ids = torch.tensor(
+ [
+ [0, 5, -1, 11, 17, 2],
+ [23, -1, 8, 0, 19, 31],
+ [4, -1, -1, 6, 7, 12],
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ topk_lens = torch.tensor([5, 6, 0], dtype=torch.int32, device="cuda")
+ seq_lens = torch.tensor([17, 30, 7], dtype=torch.int32, device="cuda")
+ gather_lens = torch.tensor([5, 6, 4], dtype=torch.int32, device="cuda")
+ block_table = torch.tensor(
+ [
+ [4, 0, 6, 1, 3, 5, 2, 7],
+ [7, 2, 4, 0, 6, 1, 5, 3],
+ [1, 3, 0, 2, 4, 5, 6, 7],
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ q = torch.randn(3, 1, 8, 512, device="cuda", dtype=torch.bfloat16)
+ active_heads = 5
+ sink = torch.linspace(-0.75, 0.75, active_heads, device="cuda")
+ scale = 0.0625
+
+ compressed_kv = _materialize_global_fp8_ds_mla_slots(
+ compressed_cache,
+ compressed_slot_ids,
+ compressed_block_size,
+ )
+ swa_kv = _materialize_paged_fp8_ds_mla_window(
+ swa_cache,
+ seq_lens,
+ gather_lens,
+ block_table,
+ swa_block_size,
+ )
+ compressed_offsets = torch.arange(
+ compressed_slot_ids.shape[1],
+ device="cuda",
+ dtype=torch.int32,
+ )
+ swa_offsets = torch.arange(swa_kv.shape[1], device="cuda", dtype=torch.int32)
+ compressed_valid = (compressed_offsets[None, :] < topk_lens[:, None]) & (
+ compressed_slot_ids >= 0
+ )
+ swa_valid = swa_offsets[None, :] < gather_lens[:, None]
+ expected = _golden_sink_attention(
+ q[:, 0, :active_heads],
+ torch.cat([compressed_kv, swa_kv], dim=1),
+ torch.cat([compressed_valid, swa_valid], dim=1),
+ scale,
+ sink,
+ ).to(torch.bfloat16)
+
+ actual = torch.empty(3, active_heads, 512, device="cuda", dtype=torch.bfloat16)
+ fp8ds_global_paged_sparse_mla_attention_with_sink_multihead(
+ q=q,
+ compressed_k_cache=compressed_cache,
+ slot_ids=compressed_slot_ids,
+ topk_lens=topk_lens,
+ compressed_block_size=compressed_block_size,
+ swa_k_cache=swa_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ swa_block_size=swa_block_size,
+ num_compressed_candidates=compressed_slot_ids.shape[1],
+ num_swa_candidates=swa_kv.shape[1],
+ scale=scale,
+ attn_sink=sink,
+ output=actual,
+ head_block_size=2,
+ num_heads=active_heads,
+ )
+
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_matmul_sparse_mla_attention_with_sink_matches_reference() -> None:
+ torch.manual_seed(41)
+ q = torch.randn(2, 1, 5, 512, device="cuda", dtype=torch.bfloat16)
+ kv = torch.randn(2, 7, 512, device="cuda", dtype=torch.bfloat16)
+ valid_tokens = torch.tensor(
+ [
+ [True, True, False, True, False, True, True],
+ [False, True, True, False, True, False, False],
+ ],
+ dtype=torch.bool,
+ device="cuda",
+ )
+ sink = torch.linspace(-0.25, 0.25, 5, device="cuda")
+ scale = 0.0625
+
+ expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16)
+ sink_aware_reference_attention(
+ q,
+ kv,
+ valid_tokens,
+ scale,
+ sink,
+ expected,
+ )
+
+ actual = torch.empty_like(expected)
+ matmul_sparse_mla_attention_with_sink(
+ q,
+ kv,
+ valid_tokens,
+ scale,
+ sink,
+ actual,
+ num_heads=5,
+ )
+
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_mtp_matmul_global_slots_decode_matches_dense_golden_long_offsets() -> None:
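+    # MTP decode path: dequantize compressed and SWA slots into one combined
+    # bf16 KV buffer, build the combined valid mask, and run the matmul
+    # attention-with-sink against the dense golden.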
+ torch.manual_seed(73)
+ compressed_block_size = 4
+ swa_block_size = 4
+ compressed_cache = torch.zeros(
+ 8,
+ compressed_block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ swa_cache = torch.zeros(
+ 8,
+ swa_block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ compressed_slot_ids = torch.tensor(
+ [
+ [0, 5, -1, 11, 17, 2],
+ [23, -1, 8, 0, 19, 31],
+ [4, -1, -1, 6, 7, 12],
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ topk_lens = torch.tensor([5, 6, 0], dtype=torch.int32, device="cuda")
+ swa_slot_ids = torch.tensor(
+ [
+ [3, 4, 13, 14, 15],
+ [28, 29, 30, 31, -1],
+ [7, 6, 5, -1, -1],
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ swa_lens = torch.tensor([5, 4, 3], dtype=torch.int32, device="cuda")
+ q = torch.randn(3, 1, 8, 512, device="cuda", dtype=torch.bfloat16)
+ active_heads = 5
+ sink = torch.linspace(-0.5, 0.5, active_heads, device="cuda")
+ scale = 0.0625
+
+ compressed_kv = _materialize_global_fp8_ds_mla_slots(
+ compressed_cache,
+ compressed_slot_ids,
+ compressed_block_size,
+ )
+ swa_kv = _materialize_global_fp8_ds_mla_slots(
+ swa_cache,
+ swa_slot_ids,
+ swa_block_size,
+ )
+ combined_kv = torch.empty(
+ 3,
+ compressed_slot_ids.shape[1] + swa_slot_ids.shape[1],
+ 512,
+ dtype=torch.bfloat16,
+ device="cuda",
+ )
+ dequantize_global_slots_k_cache(
+ combined_kv[:, : compressed_slot_ids.shape[1]],
+ compressed_cache,
+ compressed_slot_ids,
+ compressed_block_size,
+ )
+ dequantize_global_slots_k_cache(
+ combined_kv[:, compressed_slot_ids.shape[1] :],
+ swa_cache,
+ swa_slot_ids,
+ swa_block_size,
+ )
+ valid_tokens = torch.empty(
+ combined_kv.shape[:2],
+ dtype=torch.bool,
+ device="cuda",
+ )
+ build_combined_sparse_mla_decode_valid_mask(
+ valid_tokens,
+ compressed_slot_ids,
+ topk_lens,
+ swa_lens,
+ )
+
+ compressed_offsets = torch.arange(
+ compressed_slot_ids.shape[1],
+ device="cuda",
+ dtype=torch.int32,
+ )
+ swa_offsets = torch.arange(swa_slot_ids.shape[1], device="cuda", dtype=torch.int32)
+ compressed_valid = (compressed_offsets[None, :] < topk_lens[:, None]) & (
+ compressed_slot_ids >= 0
+ )
+ swa_valid = swa_offsets[None, :] < swa_lens[:, None]
+ expected_kv = torch.cat([compressed_kv, swa_kv], dim=1)
+ expected_valid = torch.cat([compressed_valid, swa_valid], dim=1)
+ expected = _golden_sink_attention(
+ q[:, 0, :active_heads],
+ expected_kv,
+ expected_valid,
+ scale,
+ sink,
+ ).to(torch.bfloat16)
+
+ torch.testing.assert_close(combined_kv.float(), expected_kv.float(), rtol=0, atol=0)
+ torch.testing.assert_close(valid_tokens, expected_valid)
+
+ actual = torch.empty_like(expected)
+ matmul_sparse_mla_attention_with_sink(
+ q,
+ combined_kv,
+ valid_tokens,
+ scale,
+ sink,
+ actual,
+ num_heads=active_heads,
+ value_block_size=512,
+ candidate_block_size=128,
+ )
+
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_matmul_sparse_mla_attention_accepts_bf16_score_buffer() -> None:
+ torch.manual_seed(67)
+ q = torch.randn(2, 1, 5, 512, device="cuda", dtype=torch.bfloat16)
+ kv = torch.randn(2, 7, 512, device="cuda", dtype=torch.bfloat16)
+ valid_tokens = torch.tensor(
+ [
+ [True, True, False, True, False, True, True],
+ [False, True, True, False, True, False, False],
+ ],
+ dtype=torch.bool,
+ device="cuda",
+ )
+ sink = torch.linspace(-0.25, 0.25, 5, device="cuda")
+ scale = 0.0625
+
+ expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16)
+ sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, expected)
+
+ actual = torch.empty_like(expected)
+ score_buffer = torch.empty(2, 5, 7, device="cuda", dtype=torch.bfloat16)
+ matmul_sparse_mla_attention_with_sink(
+ q,
+ kv,
+ valid_tokens,
+ scale,
+ sink,
+ actual,
+ num_heads=5,
+ score_buffer=score_buffer,
+ value_block_size=512,
+ candidate_block_size=128,
+ )
+
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize(
+ ("candidate_block_size", "value_block_size"),
+ [(32, 128), (64, 128), (64, 256), (128, 512)],
+)
+def test_finish_materialized_scores_candidate_block_matches_reference(
+ candidate_block_size: int,
+ value_block_size: int,
+) -> None:
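+    # Scores are precomputed with a bmm and pre-scaled; the materialized-score
+    # finish kernel must reproduce the sink-aware reference output for the
+    # given candidate/value block sizes.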
+ torch.manual_seed(61)
+ q = torch.randn(3, 1, 7, 512, device="cuda", dtype=torch.bfloat16)
+ kv = torch.randn(3, 13, 512, device="cuda", dtype=torch.bfloat16)
+ valid_tokens = torch.tensor(
+ [
+ [
+ True,
+ True,
+ False,
+ True,
+ False,
+ True,
+ True,
+ False,
+ True,
+ True,
+ False,
+ True,
+ False,
+ ],
+ [
+ False,
+ True,
+ True,
+ False,
+ True,
+ False,
+ False,
+ True,
+ False,
+ True,
+ True,
+ False,
+ True,
+ ],
+ [
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ ],
+ ],
+ dtype=torch.bool,
+ device="cuda",
+ )
+ sink = torch.linspace(-0.25, 0.25, 7, device="cuda")
+ scale = 0.0625
+
+ expected = torch.empty(3, 7, 512, device="cuda", dtype=torch.bfloat16)
+ sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, expected)
+
+ scores = torch.bmm(q[:, 0].float(), kv.float().transpose(1, 2))
+ scores.mul_(scale)
+ actual = torch.empty_like(expected)
+ finish_materialized_sparse_mla_scores_with_sink(
+ scores,
+ kv,
+ valid_tokens,
+ sink,
+ actual,
+ num_heads=7,
+ value_block_size=value_block_size,
+ candidate_block_size=candidate_block_size,
+ )
+
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize("value_block_size", [128, 256])
+def test_finish_materialized_scores_value_block_matches_reference(
+ value_block_size: int,
+) -> None:
+ torch.manual_seed(53)
+ q = torch.randn(3, 1, 7, 512, device="cuda", dtype=torch.bfloat16)
+ kv = torch.randn(3, 11, 512, device="cuda", dtype=torch.bfloat16)
+ valid_tokens = torch.tensor(
+ [
+ [True, True, False, True, False, True, True, False, True, True, False],
+ [False, True, True, False, True, False, False, True, False, True, True],
+ [
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ False,
+ ],
+ ],
+ dtype=torch.bool,
+ device="cuda",
+ )
+ sink = torch.linspace(-0.25, 0.25, 7, device="cuda")
+ scale = 0.0625
+
+ expected = torch.empty(3, 7, 512, device="cuda", dtype=torch.bfloat16)
+ sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, expected)
+
+ scores = torch.bmm(
+ q[:, 0].float(),
+ kv.float().transpose(1, 2),
+ )
+ scores.mul_(scale)
+ actual = torch.empty_like(expected)
+ finish_materialized_sparse_mla_scores_with_sink(
+ scores,
+ kv,
+ valid_tokens,
+ sink,
+ actual,
+ num_heads=7,
+ value_block_size=value_block_size,
+ )
+
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_build_combined_sparse_mla_decode_valid_mask_matches_torch() -> None:
+ compressed_slot_ids = torch.tensor(
+ [
+ [7, 4, -1, 9, 11],
+ [2, -1, 3, 8, 10],
+ [-1, -1, -1, -1, -1],
+ ],
+ device="cuda",
+ dtype=torch.int32,
+ )
+ topk_lens = torch.tensor([4, 3, 0], device="cuda", dtype=torch.int32)
+ swa_lens = torch.tensor([3, 1, 0], device="cuda", dtype=torch.int32)
+ valid_tokens = torch.empty(3, 9, device="cuda", dtype=torch.bool)
+
+ build_combined_sparse_mla_decode_valid_mask(
+ valid_tokens,
+ compressed_slot_ids,
+ topk_lens,
+ swa_lens,
+ )
+
+ comp_offsets = torch.arange(5, device="cuda", dtype=torch.int32)
+ swa_offsets = torch.arange(4, device="cuda", dtype=torch.int32)
+ expected = torch.empty_like(valid_tokens)
+ expected[:, :5] = (comp_offsets[None, :] < topk_lens[:, None]) & (
+ compressed_slot_ids >= 0
+ )
+ expected[:, 5:] = swa_offsets[None, :] < swa_lens[:, None]
+
+ torch.testing.assert_close(valid_tokens, expected)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize("num_heads", [8, 16, 32, 64])
+def test_triton_fp8ds_paged_attention_with_sink_supports_tp_local_heads(
+ num_heads: int,
+) -> None:
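+    # TP-local head counts (8/16/32/64): the direct with-sink entry point must
+    # match finishing an explicitly accumulated multihead state.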
+ torch.manual_seed(37 + num_heads)
+ block_size = 4
+ k_cache = torch.zeros(
+ 4,
+ block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ block_table = torch.tensor(
+ [[1, 0, 2, 3], [2, 3, 1, 0]],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda")
+ gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
+ q = torch.randn(2, 1, num_heads, 512, device="cuda", dtype=torch.bfloat16)
+ sink = torch.linspace(-0.5, 0.5, num_heads, device="cuda")
+ scale = 0.0625
+
+ for token_idx in range(seq_lens.shape[0]):
+ start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item())
+ for gather_idx in range(int(gather_lens[token_idx].item())):
+ pos = start_pos + gather_idx
+ physical_block = int(block_table[token_idx, pos // block_size].item())
+ slot = physical_block * block_size + pos % block_size
+ _write_fp8_ds_mla_token(k_cache, slot, block_size)
+
+ max_score = torch.full((2, num_heads), float("-inf"), device="cuda")
+ denom = torch.zeros((2, num_heads), device="cuda")
+ acc = torch.zeros((2, num_heads, 512), device="cuda")
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=k_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=block_size,
+ candidate_offset=0,
+ num_candidates=5,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ head_block_size=1,
+ )
+ expected = torch.empty(2, num_heads, 512, device="cuda", dtype=torch.bfloat16)
+ finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, expected)
+
+ actual = torch.empty_like(expected)
+ fp8ds_paged_sparse_mla_attention_with_sink_multihead(
+ q=q,
+ k_cache=k_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=block_size,
+ candidate_offset=0,
+ num_candidates=5,
+ scale=scale,
+ attn_sink=sink,
+ output=actual,
+ head_block_size=4,
+ )
+
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+@pytest.mark.parametrize("num_heads", [8, 16, 32, 64])
+def test_triton_fp8ds_global_paged_attention_with_sink_supports_tp_local_heads(
+ num_heads: int,
+) -> None:
+ torch.manual_seed(41 + num_heads)
+ compressed_block_size = 4
+ swa_block_size = 4
+ compressed_cache = torch.zeros(
+ 4,
+ compressed_block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ swa_cache = torch.zeros(
+ 4,
+ swa_block_size,
+ _TOKEN_DATA_SIZE + _SCALE_DIM,
+ dtype=torch.uint8,
+ device="cuda",
+ )
+ slot_ids = torch.tensor(
+ [[0, 3, -1, 8, 1], [7, -1, 4, 0, 8]],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ topk_lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda")
+ block_table = torch.tensor(
+ [[1, 0, 2, 3], [2, 3, 1, 0]],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda")
+ gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")
+ q = torch.randn(2, 1, num_heads, 512, device="cuda", dtype=torch.bfloat16)
+ sink = torch.linspace(-1.0, 1.0, num_heads, device="cuda")
+ scale = 0.0625
+
+ for slot in (0, 1, 3, 4, 7, 8):
+ _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size)
+ for token_idx in range(seq_lens.shape[0]):
+ start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item())
+ for gather_idx in range(int(gather_lens[token_idx].item())):
+ pos = start_pos + gather_idx
+ physical_block = int(block_table[token_idx, pos // swa_block_size].item())
+ slot = physical_block * swa_block_size + pos % swa_block_size
+ _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size)
+
+ comp_max = torch.full((2, num_heads), float("-inf"), device="cuda")
+ comp_denom = torch.zeros((2, num_heads), device="cuda")
+ comp_acc = torch.zeros((2, num_heads, 512), device="cuda")
+ swa_max = torch.full((2, num_heads), float("-inf"), device="cuda")
+ swa_denom = torch.zeros((2, num_heads), device="cuda")
+ swa_acc = torch.zeros((2, num_heads, 512), device="cuda")
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=compressed_cache,
+ slot_ids=slot_ids,
+ lens=topk_lens,
+ block_size=compressed_block_size,
+ candidate_offset=0,
+ scale=scale,
+ max_score=comp_max,
+ denom=comp_denom,
+ acc=comp_acc,
+ head_block_size=1,
+ )
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=swa_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ block_size=swa_block_size,
+ candidate_offset=0,
+ num_candidates=5,
+ scale=scale,
+ max_score=swa_max,
+ denom=swa_denom,
+ acc=swa_acc,
+ head_block_size=1,
+ )
+ expected = torch.empty(2, num_heads, 512, device="cuda", dtype=torch.bfloat16)
+ finish_two_sparse_mla_attention_states_with_sink(
+ comp_max,
+ comp_denom,
+ comp_acc,
+ swa_max,
+ swa_denom,
+ swa_acc,
+ sink,
+ expected,
+ )
+
+ actual = torch.empty_like(expected)
+ fp8ds_global_paged_sparse_mla_attention_with_sink_multihead(
+ q=q,
+ compressed_k_cache=compressed_cache,
+ slot_ids=slot_ids,
+ topk_lens=topk_lens,
+ compressed_block_size=compressed_block_size,
+ swa_k_cache=swa_cache,
+ seq_lens=seq_lens,
+ gather_lens=gather_lens,
+ block_table=block_table,
+ swa_block_size=swa_block_size,
+ num_compressed_candidates=5,
+ num_swa_candidates=5,
+ scale=scale,
+ attn_sink=sink,
+ output=actual,
+ head_block_size=4,
+ )
+
+ torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_indexed_bf16_prefill_chunks_match_reference() -> None:
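+    # Replay the prefill as explicit chunks over both query tokens and top-k
+    # candidates, merging each query chunk with the sink, and compare against
+    # reference_sparse_mla_prefill run with the same chunk sizes.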
+ torch.manual_seed(17)
+ q = torch.randn(5, 5, 16, device="cuda", dtype=torch.bfloat16)
+ q_active = q[:, :3]
+ kv = torch.randn(2, 7, 16, device="cuda", dtype=torch.bfloat16)
+ kv_flat = kv.reshape(-1, q.shape[-1])
+ combined_indices = torch.tensor(
+ [
+ [0, 3, -1, 5, 3, 1],
+ [4, -1, 2, 2, 1, 8],
+ [-1, -1, -1, -1, -1, -1],
+ [8, 0, 9, -1, 7, 4],
+ [13, 12, 0, 12, -1, 3],
+ ],
+ dtype=torch.int64,
+ device="cuda",
+ )
+ combined_lens = torch.tensor([5, 4, 0, 6, 5], dtype=torch.int32, device="cuda")
+ sink = torch.tensor([-0.5, 1.0, 0.25], dtype=torch.float32, device="cuda")
+ scale = 0.375
+ output = torch.empty_like(q_active)
+
+ for token_start in (0, 2, 4):
+ token_end = min(token_start + 2, q.shape[0])
+ q_chunk = q[token_start:token_end]
+ indices_chunk = combined_indices[token_start:token_end]
+ lens_chunk = combined_lens[token_start:token_end]
+ max_score = torch.full(
+ (q_chunk.shape[0], q_active.shape[1]),
+ float("-inf"),
+ device="cuda",
+ )
+ denom = torch.zeros_like(max_score)
+ acc = torch.zeros(
+ q_chunk.shape[0],
+ q_active.shape[1],
+ q_chunk.shape[-1],
+ device="cuda",
+ dtype=torch.float32,
+ )
+ for index_start in (0, 3):
+ index_end = min(index_start + 3, combined_indices.shape[-1])
+ accumulate_indexed_sparse_mla_attention_chunk(
+ q=q_chunk,
+ kv_flat=kv_flat,
+ indices=indices_chunk[:, index_start:index_end],
+ lens=lens_chunk,
+ candidate_offset=index_start,
+ scale=scale,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ )
+ subset_output = torch.empty_like(acc)
+ subset_lse = torch.empty_like(max_score)
+ finish_gathered_sparse_mla_attention(
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ output=subset_output,
+ lse=subset_lse,
+ )
+ merge_sparse_mla_subset_with_sink(
+ subset_output=subset_output,
+ subset_lse=subset_lse,
+ attn_sink=sink,
+ output=output[token_start:token_end],
+ )
+
+ expected = torch.empty_like(q_active)
+ reference_sparse_mla_prefill(
+ q=q_active,
+ kv=kv,
+ combined_indices=combined_indices,
+ combined_lens=combined_lens,
+ scale=scale,
+ attn_sink=sink,
+ output=expected,
+ topk_chunk_size=3,
+ query_chunk_size=2,
+ )
+
+ torch.testing.assert_close(output.float(), expected.float(), rtol=2e-2, atol=2e-2)
+
+
+@pytest.mark.parametrize(
+ ("topk_chunk_size", "query_chunk_size"),
+ [(1, 1), (2, 3), (5, 2)],
+)
+def test_reference_sparse_mla_prefill_matches_dense_golden(
+ topk_chunk_size: int,
+ query_chunk_size: int,
+) -> None:
+ torch.manual_seed(4)
+ scale = 0.375
+ q = torch.randn(4, 2, 3)
+ kv = torch.randn(2, 5, 3)
+ combined_indices = torch.tensor(
+ [
+ [0, 3, -1, 5, 3],
+ [4, -1, 2, 2, 1],
+ [-1, -1, -1, -1, -1],
+ [8, 0, 9, -1, 7],
+ ],
+ dtype=torch.int64,
+ )
+ combined_lens = torch.tensor([4, 3, 0, 5], dtype=torch.int32)
+ sink = torch.tensor([-0.5, 1.0])
+ output = torch.empty_like(q)
+
+ reference_sparse_mla_prefill(
+ q=q,
+ kv=kv,
+ combined_indices=combined_indices,
+ combined_lens=combined_lens,
+ scale=scale,
+ attn_sink=sink,
+ output=output,
+ topk_chunk_size=topk_chunk_size,
+ query_chunk_size=query_chunk_size,
+ )
+
+ kv_flat = kv.reshape(-1, q.shape[-1])
+ offsets = torch.arange(combined_indices.shape[-1])
+ valid_tokens = (offsets[None, :] < combined_lens[:, None]) & (combined_indices >= 0)
+ safe_indices = torch.where(
+ valid_tokens,
+ combined_indices,
+ torch.zeros((), dtype=combined_indices.dtype),
+ ).long()
+ gathered_kv = kv_flat[safe_indices]
+ expected = _golden_sink_attention(q, gathered_kv, valid_tokens, scale, sink)
+
+ torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.parametrize("chunk_size", [1, 2, 5])
+def test_chunked_reference_accumulation_matches_one_shot(chunk_size: int) -> None:
+ torch.manual_seed(3)
+ scale = 0.3
+ q = torch.randn(3, 2, 4)
+ kv = torch.randn(3, 9, 4)
+ valid_tokens = torch.tensor(
+ [
+ [True, False, True, True, False, False, True, False, True],
+ [False, False, False, False, False, False, False, False, False],
+ [True, True, True, False, True, False, True, True, False],
+ ],
+ dtype=torch.bool,
+ )
+ output, lse = _chunked_no_sink_attention(
+ q,
+ kv,
+ valid_tokens,
+ scale,
+ chunk_size,
+ )
+ expected_output, expected_lse = _golden_no_sink_attention(
+ q,
+ kv,
+ valid_tokens,
+ scale,
+ )
+
+ torch.testing.assert_close(output, expected_output, rtol=1e-6, atol=1e-6)
+ torch.testing.assert_close(lse, expected_lse, rtol=1e-6, atol=1e-6)
+
+
+def test_triton_sparse_mla_path_allows_cudagraph_support_by_default(
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1")
+ monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", raising=False)
+
+ mla_spec = MLAAttentionSpec(
+ block_size=256,
+ num_kv_heads=1,
+ head_size=512,
+ dtype=torch.uint8,
+ cache_dtype_str="fp8_ds_mla",
+ alignment=576,
+ compress_ratio=4,
+ model_version="deepseek_v4",
+ )
+ swa_spec = SlidingWindowMLASpec(
+ block_size=64,
+ num_kv_heads=1,
+ head_size=512,
+ dtype=torch.uint8,
+ sliding_window=128,
+ cache_dtype_str="fp8_ds_mla",
+ alignment=576,
+ model_version="deepseek_v4",
+ )
+
+ assert FlashMLASparseMetadataBuilder.get_cudagraph_support(None, mla_spec) is (
+ AttentionCGSupport.UNIFORM_BATCH
+ )
+ assert DeepseekSparseSWAMetadataBuilder.get_cudagraph_support(None, swa_spec) is (
+ AttentionCGSupport.UNIFORM_BATCH
+ )
+
+ vllm_config = SimpleNamespace(
+ compilation_config=SimpleNamespace(
+ mode=CompilationMode.VLLM_COMPILE,
+ compile_sizes=[1, 2],
+ compile_ranges_endpoints=[8192],
+ cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
+ cudagraph_capture_sizes=[1, 2, 4],
+ max_cudagraph_capture_size=4,
+ )
+ )
+ disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config)
+
+ assert vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
+ assert vllm_config.compilation_config.compile_sizes == [1, 2]
+ assert vllm_config.compilation_config.compile_ranges_endpoints == [8192]
+ assert (
+ vllm_config.compilation_config.cudagraph_mode
+ == CUDAGraphMode.FULL_AND_PIECEWISE
+ )
+ assert vllm_config.compilation_config.cudagraph_capture_sizes == [1, 2, 4]
+ assert vllm_config.compilation_config.max_cudagraph_capture_size == 4
+
+
+def test_triton_sparse_mla_path_can_disable_cudagraphs(monkeypatch) -> None:
+ monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1")
+ monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", "0")
+
+ mla_spec = MLAAttentionSpec(
+ block_size=256,
+ num_kv_heads=1,
+ head_size=512,
+ dtype=torch.uint8,
+ cache_dtype_str="fp8_ds_mla",
+ alignment=576,
+ compress_ratio=4,
+ model_version="deepseek_v4",
+ )
+ swa_spec = SlidingWindowMLASpec(
+ block_size=64,
+ num_kv_heads=1,
+ head_size=512,
+ dtype=torch.uint8,
+ sliding_window=128,
+ cache_dtype_str="fp8_ds_mla",
+ alignment=576,
+ model_version="deepseek_v4",
+ )
+
+ assert FlashMLASparseMetadataBuilder.get_cudagraph_support(None, mla_spec) is (
+ AttentionCGSupport.NEVER
+ )
+ assert DeepseekSparseSWAMetadataBuilder.get_cudagraph_support(None, swa_spec) is (
+ AttentionCGSupport.NEVER
+ )
+
+ vllm_config = SimpleNamespace(
+ compilation_config=SimpleNamespace(
+ mode=CompilationMode.VLLM_COMPILE,
+ compile_sizes=[1, 2],
+ compile_ranges_endpoints=[8192],
+ cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
+ cudagraph_capture_sizes=[1, 2, 4],
+ max_cudagraph_capture_size=4,
+ )
+ )
+ disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config)
+
+ assert vllm_config.compilation_config.mode == CompilationMode.NONE
+ assert vllm_config.compilation_config.compile_sizes == []
+ assert vllm_config.compilation_config.compile_ranges_endpoints == []
+ assert vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+ assert vllm_config.compilation_config.cudagraph_capture_sizes == []
+ assert vllm_config.compilation_config.max_cudagraph_capture_size == 0
+
+
+def test_triton_sparse_mla_path_disables_cudagraphs_for_mtp(
+ monkeypatch,
+) -> None:
+ monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1")
+ monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", raising=False)
+
+ mla_spec = MLAAttentionSpec(
+ block_size=256,
+ num_kv_heads=1,
+ head_size=512,
+ dtype=torch.uint8,
+ cache_dtype_str="fp8_ds_mla",
+ alignment=576,
+ compress_ratio=4,
+ model_version="deepseek_v4",
+ )
+ swa_spec = SlidingWindowMLASpec(
+ block_size=64,
+ num_kv_heads=1,
+ head_size=512,
+ dtype=torch.uint8,
+ sliding_window=128,
+ cache_dtype_str="fp8_ds_mla",
+ alignment=576,
+ model_version="deepseek_v4",
+ )
+ vllm_config = SimpleNamespace(
+ speculative_config=SimpleNamespace(
+ method="mtp",
+ num_speculative_tokens=2,
+ ),
+ compilation_config=SimpleNamespace(
+ mode=CompilationMode.VLLM_COMPILE,
+ compile_sizes=[1, 2],
+ compile_ranges_endpoints=[8192],
+ cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE,
+ cudagraph_capture_sizes=[1, 2, 4],
+ max_cudagraph_capture_size=4,
+ ),
+ )
+
+ assert (
+ FlashMLASparseMetadataBuilder.get_cudagraph_support(
+ vllm_config,
+ mla_spec,
+ )
+ is AttentionCGSupport.NEVER
+ )
+ assert (
+ DeepseekSparseSWAMetadataBuilder.get_cudagraph_support(
+ vllm_config,
+ swa_spec,
+ )
+ is AttentionCGSupport.NEVER
+ )
+
+ disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config)
+
+ assert vllm_config.compilation_config.mode == CompilationMode.NONE
+ assert vllm_config.compilation_config.compile_sizes == []
+ assert vllm_config.compilation_config.compile_ranges_endpoints == []
+ assert vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+ assert vllm_config.compilation_config.cudagraph_capture_sizes == []
+ assert vllm_config.compilation_config.max_cudagraph_capture_size == 0
diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py
new file mode 100644
index 000000000000..3fae48146f2d
--- /dev/null
+++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py
@@ -0,0 +1,245 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+
+import vllm.utils.deep_gemm as deep_gemm_utils
+from vllm.model_executor.layers.sparse_attn_indexer import (
+ _decode_logits_width,
+ _decode_topk_logits_width,
+ _sparse_indexer_requires_deep_gemm,
+)
+from vllm.platforms import current_platform
+from vllm.utils.math_utils import cdiv
+from vllm.v1.attention.backends.mla import indexer as mla_indexer
+
+
+def test_decode_logits_width_uses_active_context_bound():
+ assert _decode_logits_width(262144, 1024) == 1024
+ assert _decode_logits_width(4096, 8192) == 4096
+ assert _decode_logits_width(4096, 0) == 4096
+ assert _decode_logits_width(0, 1024) == 0
+
+
+def test_decode_topk_logits_width_keeps_topk_kernel_width():
+ assert _decode_topk_logits_width(262144, 1024, 512) == 1024
+ assert _decode_topk_logits_width(262144, 128, 512) == 512
+ assert _decode_topk_logits_width(300, 128, 512) == 300
+ assert _decode_topk_logits_width(0, 128, 512) == 0
+
+
+def test_sm120_sparse_indexer_does_not_require_deep_gemm(monkeypatch):
+ monkeypatch.setattr(current_platform, "is_cuda", lambda: True)
+ monkeypatch.setattr(
+ current_platform,
+ "is_device_capability_family",
+ lambda capability: capability == 120,
+ )
+
+ assert _sparse_indexer_requires_deep_gemm() is False
+
+
+def test_non_sm120_cuda_sparse_indexer_still_requires_deep_gemm(monkeypatch):
+ monkeypatch.setattr(current_platform, "is_cuda", lambda: True)
+ monkeypatch.setattr(
+ current_platform,
+ "is_device_capability_family",
+ lambda capability: False,
+ )
+
+ assert _sparse_indexer_requires_deep_gemm() is True
+
+
+def test_sm120_paged_mqa_metadata_uses_backend_impl(monkeypatch):
+ monkeypatch.setattr(
+ current_platform,
+ "is_device_capability_family",
+ lambda capability: capability == 120,
+ )
+ lazy_init_calls = []
+ monkeypatch.setattr(
+ deep_gemm_utils, "_lazy_init", lambda: lazy_init_calls.append(1)
+ )
+ expected = torch.tensor([[1, 2]], dtype=torch.int32)
+
+ def fake_deep_gemm_metadata(context_lens, block_size, num_sms):
+ assert context_lens.shape == (2, 1)
+ assert block_size == 256
+ assert num_sms == 4
+ return expected
+
+ monkeypatch.setattr(
+ deep_gemm_utils,
+ "_get_paged_mqa_logits_metadata_impl",
+ fake_deep_gemm_metadata,
+ )
+ context_lens = torch.tensor([[1], [3]], dtype=torch.int32)
+
+ metadata = deep_gemm_utils.get_paged_mqa_logits_metadata(
+ context_lens, block_size=256, num_sms=4
+ )
+
+ assert metadata is expected
+ assert lazy_init_calls == [1]
+
+
+def test_sm120_mla_indexer_skips_deep_gemm_scheduler_metadata(monkeypatch):
+ monkeypatch.setattr(current_platform, "is_cuda", lambda: True)
+ monkeypatch.setattr(
+ current_platform,
+ "is_device_capability_family",
+ lambda capability: capability == 120,
+ )
+ monkeypatch.setattr(mla_indexer, "has_deep_gemm", lambda: True)
+
+ assert not mla_indexer._uses_deep_gemm_scheduler_metadata()
+
+
+def test_cuda_mla_indexer_uses_deep_gemm_scheduler_metadata_off_sm12x(monkeypatch):
+ monkeypatch.setattr(current_platform, "is_cuda", lambda: True)
+ monkeypatch.setattr(
+ current_platform,
+ "is_device_capability_family",
+ lambda capability: False,
+ )
+ monkeypatch.setattr(mla_indexer, "has_deep_gemm", lambda: True)
+
+ assert mla_indexer._uses_deep_gemm_scheduler_metadata()
+
+
+def test_sm120_fp8_mqa_fallbacks_do_not_initialize_deep_gemm(monkeypatch):
+ monkeypatch.setattr(
+ current_platform,
+ "is_device_capability_family",
+ lambda capability: capability == 120,
+ )
+
+ def fail_lazy_init():
+ raise AssertionError("SM120 FP8 MQA should not initialize DeepGEMM")
+
+ monkeypatch.setattr(deep_gemm_utils, "_lazy_init", fail_lazy_init)
+
+ mqa_result = torch.empty(1)
+ paged_result = torch.empty(1)
+ calls = []
+
+ def fake_mqa_fallback(*args, **kwargs):
+ calls.append("mqa")
+ return mqa_result
+
+ def fake_paged_fallback(*args, **kwargs):
+ calls.append("paged")
+ return paged_result
+
+ monkeypatch.setattr(deep_gemm_utils, "_fp8_mqa_logits_sm12x", fake_mqa_fallback)
+ monkeypatch.setattr(
+ deep_gemm_utils, "_fp8_paged_mqa_logits_sm12x", fake_paged_fallback
+ )
+
+ assert (
+ deep_gemm_utils.fp8_fp4_mqa_logits(
+ (torch.empty(1, 1, 1), None),
+ (torch.empty(1, 1), torch.empty(1)),
+ torch.empty(1, 1),
+ torch.empty(1, dtype=torch.int32),
+ torch.empty(1, dtype=torch.int32),
+ clean_logits=False,
+ )
+ is mqa_result
+ )
+ assert (
+ deep_gemm_utils.fp8_fp4_paged_mqa_logits(
+ (torch.empty(1, 1, 1, 1), None),
+ torch.empty(1, 1, 1, 5, dtype=torch.uint8),
+ torch.empty(1, 1),
+ torch.empty(1, 1, dtype=torch.int32),
+ torch.empty(1, 1, dtype=torch.int32),
+ torch.empty(1, dtype=torch.int32),
+ max_model_len=1,
+ clean_logits=False,
+ )
+ is paged_result
+ )
+ assert calls == ["mqa", "paged"]
+
+
+@pytest.mark.skipif(
+ not current_platform.is_device_capability_family(120), reason="SM120 only"
+)
+def test_sm120_paged_mqa_direct_topk_matches_truncated_decode_width(
+ monkeypatch: pytest.MonkeyPatch,
+):
+ torch.manual_seed(7)
+ batch_size, next_n, num_heads, head_dim = 2, 2, 8, 32
+ block_size, max_model_len, num_blocks = 4, 64, 16
+ active_max_len = 13
+ topk_tokens = 6
+ monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None)
+ monkeypatch.setattr(deep_gemm_utils, "_SM120_PAGED_MQA_TOPK_CHUNK_SIZE", 7)
+
+ q = torch.randn(
+ batch_size,
+ next_n,
+ num_heads,
+ head_dim,
+ device="cuda",
+ dtype=torch.bfloat16,
+ )
+ q_fp8 = q.to(torch.float8_e4m3fn).contiguous()
+ kv = torch.randn(
+ num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16
+ )
+ kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0
+ kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn)
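+    # Fused indexer cache layout: per token, head_dim fp8 value bytes followed
+    # by the four bytes of its float32 scale.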
+ fused_kv = torch.empty(
+ num_blocks,
+ block_size,
+ 1,
+ head_dim + 4,
+ device="cuda",
+ dtype=torch.uint8,
+ )
+ fused_kv[..., :head_dim] = kv_fp8.view(torch.uint8)
+ fused_kv[..., head_dim:] = kv_scale.contiguous().view(torch.uint8)
+
+ weights = torch.randn(
+ batch_size * next_n, num_heads, device="cuda", dtype=torch.float32
+ )
+ context_lens = torch.tensor(
+ [[5, active_max_len], [9, 12]], device="cuda", dtype=torch.int32
+ )
+ block_tables = (
+ torch.arange(
+ batch_size * cdiv(max_model_len, block_size),
+ device="cuda",
+ dtype=torch.int32,
+ ).reshape(batch_size, -1)
+ % num_blocks
+ )
+
+ full_width_topk = torch.empty(
+ batch_size * next_n, topk_tokens, device="cuda", dtype=torch.int32
+ )
+ truncated_width_topk = torch.empty_like(full_width_topk)
+
+ assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices(
+ (q_fp8, None),
+ fused_kv,
+ weights,
+ context_lens,
+ block_tables,
+ max_model_len,
+ full_width_topk,
+ )
+ assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices(
+ (q_fp8, None),
+ fused_kv,
+ weights,
+ context_lens,
+ block_tables,
+ active_max_len,
+ truncated_width_topk,
+ )
+
+ torch.testing.assert_close(truncated_width_topk, full_width_topk, rtol=0, atol=0)
diff --git a/tests/v1/attention/test_sparse_attn_indexer.py b/tests/v1/attention/test_sparse_attn_indexer.py
new file mode 100644
index 000000000000..eb09cf058d8f
--- /dev/null
+++ b/tests/v1/attention/test_sparse_attn_indexer.py
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.model_executor.layers.sparse_attn_indexer import (
+ SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH,
+ SM120_SHORT_ROW_TOPK_MAX_WIDTH,
+ _should_use_sm120_short_row_topk_decode,
+)
+
+
+@pytest.mark.parametrize(
+ ("topk_tokens", "logits_width", "num_rows", "is_cuda_sm120", "expected"),
+ [
+ (512, SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH, 32, True, True),
+ (512, 8192, 16, True, True),
+ (512, 8192, 32, True, True),
+ (512, 12288, 32, True, False),
+ (512, SM120_SHORT_ROW_TOPK_MAX_WIDTH, 1, True, False),
+ (512, 4096, 1, False, False),
+ (2048, 4096, 1, True, False),
+ ],
+)
+def test_sm120_short_row_topk_decode_selector(
+ topk_tokens: int,
+ logits_width: int,
+ num_rows: int,
+ is_cuda_sm120: bool,
+ expected: bool,
+) -> None:
+ assert (
+ _should_use_sm120_short_row_topk_decode(
+ topk_tokens,
+ logits_width,
+ num_rows,
+ is_cuda_sm120,
+ )
+ is expected
+ )
diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py
index 22acc748d24b..8becd568b680 100644
--- a/tests/v1/attention/test_sparse_mla_backends.py
+++ b/tests/v1/attention/test_sparse_mla_backends.py
@@ -8,6 +8,7 @@
import pytest
import torch
+import vllm.utils.deep_gemm as deep_gemm_utils
from tests.v1.attention.test_mla_backends import (
BATCH_SPECS,
BatchSpec,
@@ -42,9 +43,16 @@
FlashMLASparseBackend,
triton_convert_req_index_to_global_index,
)
-from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks
+from vllm.v1.attention.backends.mla.indexer import (
+ sparse_indexer_max_logits_bytes,
+ split_indexer_prefill_chunks,
+)
from vllm.v1.attention.backends.utils import split_prefill_chunks
from vllm.v1.attention.ops import flashmla
+from vllm.v1.attention.ops.deepseek_v4_ops import (
+ combine_topk_swa_indices,
+ compute_global_topk_indices_and_lens,
+)
SPARSE_BACKEND_BATCH_SPECS = {
name: BATCH_SPECS[name]
@@ -67,6 +75,487 @@
DEVICE_TYPE = current_platform.device_type
+def _make_packed_fp8_indexer_cache(
+ kv_fp8: torch.Tensor,
+ kv_scale: torch.Tensor,
+) -> torch.Tensor:
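+    """Pack fp8 KV values and their float32 scales into a flat uint8 cache.
+
+    Within each block, the first ``block_size * head_dim`` bytes hold the fp8
+    values and the trailing bytes hold the raw scale bytes.
+    """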
+ num_blocks, block_size, num_kv_heads, head_dim = kv_fp8.shape
+ assert num_kv_heads == 1
+ kv_scale_bytes = kv_scale.contiguous().view(torch.uint8).reshape(
+ num_blocks, block_size, num_kv_heads, -1
+ )
+ scale_bytes = kv_scale_bytes.shape[-1]
+ fused_kv = torch.empty(
+ num_blocks,
+ block_size,
+ head_dim + scale_bytes,
+ device=kv_fp8.device,
+ dtype=torch.uint8,
+ )
+ fused_kv_blocks = fused_kv.view(num_blocks, -1)
+ value_end = block_size * head_dim
+ scale_end = value_end + block_size * scale_bytes
+ fused_kv_blocks[:, :value_end] = kv_fp8.view(torch.uint8).reshape(
+ num_blocks, -1
+ )
+ fused_kv_blocks[:, value_end:scale_end] = kv_scale_bytes.reshape(
+ num_blocks, -1
+ )
+ return fused_kv
+
+
+def test_sm120_fp8_mqa_logits_chunk_sizes_cap_large_scores():
+ assert deep_gemm_utils._fp8_mqa_logits_head_chunk_size(128, 128, 32) == 8
+ assert deep_gemm_utils._fp8_mqa_logits_head_chunk_size(8192, 8192, 32) == 1
+ assert deep_gemm_utils._fp8_mqa_logits_k_chunk_size(128, 128, 8) == 128
+ assert deep_gemm_utils._fp8_mqa_logits_k_chunk_size(8192, 8192, 1) == 2048
+
+
+@pytest.mark.skipif(
+ not current_platform.is_device_capability_family(120), reason="SM120 only"
+)
+def test_sm120_tf32_hc_prenorm_gemm_fallback_matches_split_abi(
+ monkeypatch: pytest.MonkeyPatch,
+):
+ torch.manual_seed(0)
+ num_tokens, out_features, hidden_size = 7, 12, 64
+ x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
+ fn = torch.randn(out_features, hidden_size, device="cuda", dtype=torch.float32)
+
+ out = torch.empty(num_tokens, out_features, device="cuda", dtype=torch.float32)
+ sqrsum = torch.empty(num_tokens, device="cuda", dtype=torch.float32)
+ deep_gemm_utils._tf32_hc_prenorm_gemm_torch(x, fn, out, sqrsum, num_split=1)
+
+ expected_out = x.float() @ fn.T
+ expected_sqrsum = x.float().square().sum(dim=-1)
+ torch.testing.assert_close(out, expected_out, rtol=0, atol=0)
+ torch.testing.assert_close(sqrsum, expected_sqrsum, rtol=0, atol=0)
+
+ split_out = torch.empty(3, num_tokens, out_features, device="cuda")
+ split_sqrsum = torch.empty(3, num_tokens, device="cuda")
+ deep_gemm_utils._tf32_hc_prenorm_gemm_torch(
+ x, fn, split_out, split_sqrsum, num_split=3
+ )
+ torch.testing.assert_close(split_out.sum(dim=0), expected_out, rtol=0, atol=0)
+ torch.testing.assert_close(split_sqrsum.sum(dim=0), expected_sqrsum, rtol=0, atol=0)
+
+ monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None)
+ monkeypatch.setattr(deep_gemm_utils, "_tf32_hc_prenorm_gemm_impl", None)
+ wrapper_out = torch.empty_like(split_out)
+ wrapper_sqrsum = torch.empty_like(split_sqrsum)
+ deep_gemm_utils.tf32_hc_prenorm_gemm(
+ x, fn, wrapper_out, wrapper_sqrsum, num_split=3
+ )
+ torch.testing.assert_close(
+ wrapper_out.sum(dim=0), expected_out, rtol=2e-2, atol=2e-2
+ )
+ torch.testing.assert_close(
+ wrapper_sqrsum.sum(dim=0), expected_sqrsum, rtol=1e-4, atol=1e-4
+ )
+
+
+@pytest.mark.skipif(
+ not current_platform.is_device_capability_family(120), reason="SM120 only"
+)
+def test_sm120_fp8_paged_mqa_logits_fallback_matches_reference(
+ monkeypatch: pytest.MonkeyPatch,
+):
+ torch.manual_seed(1)
+ batch_size, next_n, num_heads, head_dim = 2, 2, 4, 32
+ block_size, max_model_len, num_blocks = 4, 12, 4
+
+ q = torch.randn(
+ batch_size,
+ next_n,
+ num_heads,
+ head_dim,
+ device="cuda",
+ dtype=torch.bfloat16,
+ )
+ q_fp8 = q.to(torch.float8_e4m3fn)
+ kv = torch.randn(
+ num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16
+ )
+ kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0
+ kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn)
+ fused_kv = _make_packed_fp8_indexer_cache(kv_fp8, kv_scale)
+
+ weights = torch.randn(
+ batch_size * next_n, num_heads, device="cuda", dtype=torch.float32
+ )
+ context_lens = torch.tensor([[3, 6], [7, 11]], device="cuda", dtype=torch.int32)
+ block_tables = torch.tensor(
+ [[0, 1, 2], [1, 2, 3]], device="cuda", dtype=torch.int32
+ )
+ expected = torch.full(
+ (batch_size * next_n, max_model_len),
+ float("-inf"),
+ device="cuda",
+ dtype=torch.float32,
+ )
+ kv_dequant = kv_fp8.float() * kv_scale
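+    # Dense reference: dequantize every valid context token, dot it with the
+    # fp8 query per head, apply ReLU, and reduce with the per-head weights.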
+ for batch_idx in range(batch_size):
+ for next_idx in range(next_n):
+ row = batch_idx * next_n + next_idx
+ for token_idx in range(int(context_lens[batch_idx, next_idx].item())):
+ block = int(block_tables[batch_idx, token_idx // block_size].item())
+ offset = token_idx % block_size
+ score = (
+ q_fp8[batch_idx, next_idx].float() * kv_dequant[block, offset, 0]
+ ).sum(dim=1)
+ expected[row, token_idx] = (score.relu() * weights[row]).sum()
+
+ monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None)
+ monkeypatch.setattr(deep_gemm_utils, "_fp8_fp4_paged_mqa_logits_impl", None)
+
+ def fail_torch_path(*args, **kwargs):
+ raise AssertionError("torch paged fallback should not be used")
+
+ monkeypatch.setattr(deep_gemm_utils, "_fp8_paged_mqa_logits_torch", fail_torch_path)
+ actual = deep_gemm_utils.fp8_fp4_paged_mqa_logits(
+ (q_fp8.contiguous(), None),
+ fused_kv,
+ weights,
+ context_lens,
+ block_tables,
+ schedule_metadata=torch.empty(0, device="cuda", dtype=torch.int32),
+ max_model_len=max_model_len,
+ clean_logits=False,
+ )
+ torch.testing.assert_close(actual, expected, rtol=0, atol=1e-5)
+
+ from vllm.model_executor.layers.deepseek_v4_triton_kernels import (
+ fp8_paged_mqa_logits_triton,
+ )
+
+ triton_actual = fp8_paged_mqa_logits_triton(
+ q_fp8.contiguous(), fused_kv, weights, context_lens, block_tables, max_model_len
+ )
+ assert torch.equal(torch.isneginf(triton_actual), torch.isneginf(expected))
+ finite = torch.isfinite(expected)
+ assert (triton_actual[finite] - expected[finite]).abs().max() < 2e-2
+
+
+@pytest.mark.skipif(
+ not current_platform.is_device_capability_family(120), reason="SM120 only"
+)
+def test_sm120_fp8_paged_mqa_rowwise_logits_matches_reference():
+ torch.manual_seed(11)
+ batch_size, next_n, num_heads, head_dim = 2, 1, 8, 64
+ block_size, max_model_len, num_blocks = 4, 18, 8
+
+ q = torch.randn(
+ batch_size,
+ next_n,
+ num_heads,
+ head_dim,
+ device="cuda",
+ dtype=torch.bfloat16,
+ )
+ q_fp8 = q.to(torch.float8_e4m3fn).contiguous()
+ kv = torch.randn(
+ num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16
+ )
+ kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0
+ kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn)
+ fused_kv = _make_packed_fp8_indexer_cache(kv_fp8, kv_scale)
+
+ weights = torch.randn(
+ batch_size * next_n, num_heads, device="cuda", dtype=torch.float32
+ )
+ context_lens = torch.tensor([[7], [17]], device="cuda", dtype=torch.int32)
+ block_tables = (
+ torch.arange(
+ batch_size * cdiv(max_model_len, block_size),
+ device="cuda",
+ dtype=torch.int32,
+ ).reshape(batch_size, -1)
+ % num_blocks
+ )
+
+ from vllm.model_executor.layers.deepseek_v4_triton_kernels import (
+ fp8_paged_mqa_logits_rowwise_triton,
+ )
+
+ actual = fp8_paged_mqa_logits_rowwise_triton(
+ q_fp8, fused_kv, weights, context_lens, block_tables, max_model_len
+ )
+ expected = deep_gemm_utils._fp8_paged_mqa_logits_torch(
+ (q_fp8, None), fused_kv, weights, context_lens, block_tables, max_model_len
+ )
+
+ assert torch.equal(torch.isneginf(actual), torch.isneginf(expected))
+ finite = torch.isfinite(expected)
+ assert (actual[finite] - expected[finite]).abs().max() < 2e-2
+
+
+@pytest.mark.skipif(
+ not current_platform.is_device_capability_family(120), reason="SM120 only"
+)
+def test_sm120_fp8_paged_mqa_topk_indices_streams_chunks(
+ monkeypatch: pytest.MonkeyPatch,
+):
+ torch.manual_seed(3)
+ batch_size, next_n, num_heads, head_dim = 2, 2, 8, 32
+ block_size, max_model_len, num_blocks = 4, 20, 8
+ topk_tokens = 5
+ monkeypatch.setattr(
+ deep_gemm_utils,
+ "_SM120_PAGED_MQA_TOPK_CHUNK_SIZE",
+ 7,
+ )
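+    # Patch torch.cat to fail so the chunked top-k path must reuse its
+    # preallocated candidate buffers instead of concatenating new ones.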
+ monkeypatch.setattr(
+ torch,
+ "cat",
+ lambda *args, **kwargs: (_ for _ in ()).throw(
+ AssertionError("paged MQA top-k should reuse candidate buffers")
+ ),
+ )
+
+ q = torch.randn(
+ batch_size,
+ next_n,
+ num_heads,
+ head_dim,
+ device="cuda",
+ dtype=torch.bfloat16,
+ )
+ q_fp8 = q.to(torch.float8_e4m3fn)
+ kv = torch.randn(
+ num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16
+ )
+ kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0
+ kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn)
+ fused_kv = _make_packed_fp8_indexer_cache(kv_fp8, kv_scale)
+
+ weights = torch.randn(
+ batch_size * next_n, num_heads, device="cuda", dtype=torch.float32
+ )
+ context_lens = torch.tensor([[3, 11], [17, 20]], device="cuda", dtype=torch.int32)
+ block_tables = (
+ torch.arange(
+ batch_size * cdiv(max_model_len, block_size),
+ device="cuda",
+ dtype=torch.int32,
+ ).reshape(batch_size, -1)
+ % num_blocks
+ )
+ topk_indices = torch.empty(
+ batch_size * next_n, topk_tokens, device="cuda", dtype=torch.int32
+ )
+
+ assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices(
+ (q_fp8.contiguous(), None),
+ fused_kv,
+ weights,
+ context_lens,
+ block_tables,
+ max_model_len,
+ topk_indices,
+ )
+
+ logits = deep_gemm_utils._fp8_paged_mqa_logits_torch(
+ (q_fp8.contiguous(), None),
+ fused_kv,
+ weights,
+ context_lens,
+ block_tables,
+ max_model_len,
+ )
+ expected = torch.full_like(topk_indices, -1)
+ flat_context_lens = context_lens.reshape(-1)
+ for row in range(batch_size * next_n):
+ valid_count = int(flat_context_lens[row].item())
+ row_topk = min(topk_tokens, valid_count)
+ if row_topk > 0:
+ expected[row, :row_topk] = (
+ logits[row].topk(row_topk).indices.to(torch.int32)
+ )
+
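+    # The chunked kernel may order tied candidates differently from torch.topk,
+    # so compare the selected index sets per row and require -1 padding beyond
+    # the valid count.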
+ for row in range(batch_size * next_n):
+ row_topk = min(topk_tokens, int(flat_context_lens[row].item()))
+ assert set(topk_indices[row, :row_topk].tolist()) == set(
+ expected[row, :row_topk].tolist()
+ )
+ assert torch.all(topk_indices[row, row_topk:] == -1)
+
+
+@pytest.mark.skipif(
+ not current_platform.is_device_capability_family(120), reason="SM120 only"
+)
+def test_sm120_fp8_mqa_logits_torch_path_streams_head_chunks(
+ monkeypatch: pytest.MonkeyPatch,
+):
+ torch.manual_seed(0)
+ seq_len, seq_len_kv, num_heads, head_dim = 9, 17, 32, 32
+ monkeypatch.setattr(
+ deep_gemm_utils,
+ "_SM120_MQA_LOGITS_MAX_SCORE_BYTES",
+ seq_len * 5 * 4,
+ )
+
+ q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
+ kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16)
+ weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32)
+ cu_seqlen_ks = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3
+ cu_seqlen_ke = torch.minimum(
+ torch.arange(seq_len, device="cuda", dtype=torch.int32) + 4,
+ torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32),
+ )
+
+ q_fp8 = q.to(torch.float8_e4m3fn)
+ kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4)
+ kv_scale = (kv_amax / 448.0).squeeze(1).contiguous()
+ kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn)
+
+ logits = deep_gemm_utils._fp8_mqa_logits_torch(
+ (q_fp8, None),
+ (kv_fp8, kv_scale),
+ weights,
+ cu_seqlen_ks,
+ cu_seqlen_ke,
+ clean_logits=True,
+ )
+
+ kv_dequant = kv_fp8.float() * kv_scale[:, None]
+ score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant)
+ ref_logits = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0)
+ offsets = torch.arange(seq_len_kv, device="cuda")
+ valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & (
+ offsets[None, :] < cu_seqlen_ke[:, None]
+ )
+ ref_logits = ref_logits.masked_fill(~valid, float("-inf"))
+
+ assert torch.equal(torch.isneginf(logits), torch.isneginf(ref_logits))
+ finite = torch.isfinite(ref_logits)
+ assert (logits[finite] - ref_logits[finite]).abs().max() < 1e-4
+
+
+@pytest.mark.skipif(
+ not current_platform.is_device_capability_family(120), reason="SM120 only"
+)
+def test_sm120_fp8_mqa_logits_wrapper_uses_triton_when_deepgemm_missing(
+ monkeypatch: pytest.MonkeyPatch,
+):
+ torch.manual_seed(2)
+ seq_len, seq_len_kv, num_heads, head_dim = 5, 13, 8, 32
+
+ q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
+ kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16)
+ weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32)
+ cu_seqlen_ks = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3
+ cu_seqlen_ke = torch.minimum(
+ cu_seqlen_ks + 6,
+ torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32),
+ )
+
+ q_fp8 = q.to(torch.float8_e4m3fn)
+ kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4)
+ kv_scale = (kv_amax / 448.0).squeeze(1).contiguous()
+ kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn)
+
+ kv_dequant = kv_fp8.float() * kv_scale[:, None]
+ score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant)
+ expected = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0)
+ offsets = torch.arange(seq_len_kv, device="cuda")
+ valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & (
+ offsets[None, :] < cu_seqlen_ke[:, None]
+ )
+ expected = expected.masked_fill(~valid, float("-inf"))
+
+ monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None)
+ monkeypatch.setattr(deep_gemm_utils, "_fp8_fp4_mqa_logits_impl", None)
+
+ def fail_torch_path(*args, **kwargs):
+ raise AssertionError("torch fallback should not be used")
+
+ monkeypatch.setattr(deep_gemm_utils, "_fp8_mqa_logits_torch", fail_torch_path)
+ actual = deep_gemm_utils.fp8_fp4_mqa_logits(
+ (q_fp8, None),
+ (kv_fp8, kv_scale),
+ weights,
+ cu_seqlen_ks,
+ cu_seqlen_ke,
+ clean_logits=True,
+ )
+
+ assert torch.equal(torch.isneginf(actual), torch.isneginf(expected))
+ finite = torch.isfinite(expected)
+ assert (actual[finite] - expected[finite]).abs().max() < 2e-2
+
+
+@pytest.mark.skipif(
+ not current_platform.is_device_capability_family(120), reason="SM120 only"
+)
+def test_sm120_fp8_mqa_logits_topk_streams_k_chunks(
+ monkeypatch: pytest.MonkeyPatch,
+):
+ torch.manual_seed(1)
+ seq_len, seq_len_kv, num_heads, head_dim = 11, 23, 16, 32
+ topk_tokens = 5
+ monkeypatch.setattr(
+ deep_gemm_utils,
+ "_SM120_MQA_LOGITS_MAX_SCORE_BYTES",
+ seq_len * 5 * 4,
+ )
+ monkeypatch.setattr(
+ torch,
+ "cat",
+ lambda *args, **kwargs: (_ for _ in ()).throw(
+ AssertionError("MQA top-k should reuse candidate buffers")
+ ),
+ )
+
+ q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
+ kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16)
+ weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32)
+ cu_seqlen_ks = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 4
+ valid_lens = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 7
+ cu_seqlen_ke = torch.minimum(
+ cu_seqlen_ks + valid_lens,
+ torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32),
+ )
+
+ q_fp8 = q.to(torch.float8_e4m3fn)
+ kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4)
+ kv_scale = (kv_amax / 448.0).squeeze(1).contiguous()
+ kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn)
+
+ topk_indices = deep_gemm_utils._fp8_mqa_logits_topk_torch(
+ (q_fp8, None),
+ (kv_fp8, kv_scale),
+ weights,
+ cu_seqlen_ks,
+ cu_seqlen_ke,
+ topk_tokens,
+ )
+
+ logits = deep_gemm_utils._fp8_mqa_logits_torch(
+ (q_fp8, None),
+ (kv_fp8, kv_scale),
+ weights,
+ cu_seqlen_ks,
+ cu_seqlen_ke,
+ clean_logits=True,
+ )
+ expected = torch.full_like(topk_indices, -1)
+ for row in range(seq_len):
+ valid_count = int((cu_seqlen_ke[row] - cu_seqlen_ks[row]).item())
+ row_topk = min(topk_tokens, valid_count)
+ if row_topk > 0:
+ expected[row, :row_topk] = (
+ logits[row].topk(row_topk).indices.to(torch.int32)
+ )
+
+ for row in range(seq_len):
+ valid_count = int((cu_seqlen_ke[row] - cu_seqlen_ks[row]).item())
+ row_topk = min(topk_tokens, valid_count)
+ assert set(topk_indices[row, :row_topk].tolist()) == set(
+ expected[row, :row_topk].tolist()
+ )
+ assert torch.all(topk_indices[row, row_topk:] == -1)
+
+
def _float_to_e8m0_truncate(f: float) -> float:
"""Simulate SM100's float -> e8m0 -> bf16 scale conversion.
e8m0 format only stores the exponent (power of 2).
@@ -218,8 +707,14 @@ def test_sparse_backend_decode_correctness(
if not ok:
pytest.skip(reason)
elif backend_cls == FlashInferMLASparseBackend:
- if not current_platform.has_device_capability(100):
- pytest.skip("FlashInferMLASparseBackend requires SM 10.0 or higher")
+ capability = current_platform.get_device_capability()
+ if capability is None or not backend_cls.supports_compute_capability(
+ capability
+ ):
+ pytest.skip(
+ "FlashInferMLASparseBackend does not support "
+ f"{capability} on this platform"
+ )
batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name]
use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla"
@@ -781,6 +1276,119 @@ def test_split_indexer_prefill_chunks(
assert out == expected
+def test_sparse_indexer_max_logits_bytes_uses_sm12x_safe_default(monkeypatch):
+ monkeypatch.delenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", raising=False)
+
+ assert sparse_indexer_max_logits_bytes(is_sm12x=True) == 256 * 1024 * 1024
+ assert sparse_indexer_max_logits_bytes(is_sm12x=False) == 512 * 1024 * 1024
+
+
+def test_sparse_indexer_max_logits_bytes_honors_env_override(monkeypatch):
+ monkeypatch.setenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "384")
+
+ assert sparse_indexer_max_logits_bytes(is_sm12x=True) == 384 * 1024 * 1024
+ assert sparse_indexer_max_logits_bytes(is_sm12x=False) == 384 * 1024 * 1024
+
+
+def test_compute_global_topk_indices_supports_in_place_output():
+ device = torch.device(DEVICE_TYPE)
+ block_size = 4
+ topk_indices = torch.tensor(
+ [[0, 3, 4, -1], [2, 5, -1, -1], [1, 7, -1, -1]],
+ dtype=torch.int32,
+ device=device,
+ )
+ token_to_req = torch.tensor([0, 1, 1], dtype=torch.int32, device=device)
+ block_table = torch.tensor(
+ [[10, 11, 12], [20, 21, 22]], dtype=torch.int32, device=device
+ )
+ is_valid = torch.tensor([True, True, False], device=device)
+
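+    # Expected mapping: global = block_table[req][i // block_size] * block_size
+    # + i % block_size; rows flagged invalid collapse to -1 with length 0.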
+ expected_indices = torch.tensor(
+ [
+ [40, 43, 44, -1],
+ [82, 85, -1, -1],
+ [-1, -1, -1, -1],
+ ],
+ dtype=torch.int32,
+ device=device,
+ )
+ expected_lens = torch.tensor([3, 2, 0], dtype=torch.int32, device=device)
+
+ out, lens = compute_global_topk_indices_and_lens(
+ topk_indices,
+ token_to_req,
+ block_table,
+ block_size,
+ is_valid,
+ )
+ torch.testing.assert_close(out, expected_indices, rtol=0, atol=0)
+ torch.testing.assert_close(lens, expected_lens, rtol=0, atol=0)
+
+ in_place = topk_indices.clone()
+ provided_lens = torch.empty(3, dtype=torch.int32, device=device)
+ out, lens = compute_global_topk_indices_and_lens(
+ in_place,
+ token_to_req,
+ block_table,
+ block_size,
+ is_valid,
+ global_topk_indices=in_place,
+ topk_lens=provided_lens,
+ )
+ assert out is in_place
+ assert lens is provided_lens
+ torch.testing.assert_close(in_place, expected_indices, rtol=0, atol=0)
+ torch.testing.assert_close(provided_lens, expected_lens, rtol=0, atol=0)
+
+
+def test_combine_topk_swa_indices_supports_workspace_outputs():
+ device = torch.device(DEVICE_TYPE)
+ num_tokens = 6
+ topk = 4
+ window_size = 8
+ topk_indices = (
+ torch.arange(num_tokens * topk, dtype=torch.int32, device=device)
+ .reshape(num_tokens, topk)
+ .remainder(5)
+ )
+ query_start_loc = torch.tensor([0, num_tokens], dtype=torch.int32, device=device)
+ seq_lens = torch.tensor([20], dtype=torch.int32, device=device)
+ gather_lens = torch.tensor([8], dtype=torch.int32, device=device)
+
+ expected_indices, expected_lens = combine_topk_swa_indices(
+ topk_indices,
+ query_start_loc,
+ seq_lens,
+ gather_lens,
+ window_size,
+ 4,
+ topk,
+ 16,
+ 12,
+ )
+ workspace_indices = torch.empty_like(expected_indices)
+ workspace_lens = torch.empty_like(expected_lens)
+ actual_indices, actual_lens = combine_topk_swa_indices(
+ topk_indices,
+ query_start_loc,
+ seq_lens,
+ gather_lens,
+ window_size,
+ 4,
+ topk,
+ 16,
+ 12,
+ combined_indices=workspace_indices,
+ combined_lens=workspace_lens,
+ )
+
+ assert actual_indices.data_ptr() == workspace_indices.data_ptr()
+ assert actual_lens.data_ptr() == workspace_lens.data_ptr()
+ torch.testing.assert_close(actual_indices, expected_indices, rtol=0, atol=0)
+ torch.testing.assert_close(actual_lens, expected_lens, rtol=0, atol=0)
+
+
def test_split_indexer_prefill_chunks_single_request_overflow():
"""Test that single request exceeding budget is sub-chunked on query dim."""
seq_lens = torch.tensor([1000, 50])
diff --git a/tests/v1/attention/test_sparse_mla_env.py b/tests/v1/attention/test_sparse_mla_env.py
new file mode 100644
index 000000000000..9745ea169cc2
--- /dev/null
+++ b/tests/v1/attention/test_sparse_mla_env.py
@@ -0,0 +1,96 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import os
+from collections.abc import Iterator
+from contextlib import contextmanager
+
+import torch
+
+from vllm.envs import environment_variables
+from vllm.v1.attention.backends.mla.sparse_mla_env import (
+ is_triton_sparse_mla_enabled,
+ triton_sparse_mla_cudagraphs_allowed,
+ triton_sparse_mla_head_block_size,
+ triton_sparse_mla_query_chunk_size,
+ triton_sparse_mla_topk_chunk_size,
+)
+
+_SPARSE_MLA_ENV_NAMES = (
+ "VLLM_TRITON_MLA_SPARSE",
+ "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE",
+ "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE",
+ "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH",
+ "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE",
+)
+
+
+@contextmanager
+def _patched_sparse_mla_env(**updates: str) -> Iterator[None]:
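+    """Clear the sparse-MLA env vars, apply ``updates``, and restore on exit."""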
+ previous = {name: os.environ.get(name) for name in _SPARSE_MLA_ENV_NAMES}
+ try:
+ for name in _SPARSE_MLA_ENV_NAMES:
+ os.environ.pop(name, None)
+ os.environ.update(updates)
+ yield
+ finally:
+ for name, value in previous.items():
+ if value is None:
+ os.environ.pop(name, None)
+ else:
+ os.environ[name] = value
+
+
+def test_triton_sparse_mla_env_uses_new_name() -> None:
+ with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE="0"):
+ assert not is_triton_sparse_mla_enabled(torch.device("cpu"))
+
+ with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE="1"):
+ assert is_triton_sparse_mla_enabled(torch.device("cpu"))
+
+
+def test_sparse_mla_cudagraph_env_defaults_to_allowed() -> None:
+ with _patched_sparse_mla_env():
+ assert triton_sparse_mla_cudagraphs_allowed()
+
+ with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH="0"):
+ assert not triton_sparse_mla_cudagraphs_allowed()
+
+ with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH="1"):
+ assert triton_sparse_mla_cudagraphs_allowed()
+
+
+def test_sparse_mla_head_block_env_accepts_supported_values() -> None:
+ with _patched_sparse_mla_env():
+ assert triton_sparse_mla_head_block_size() is None
+
+ with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="1"):
+ assert triton_sparse_mla_head_block_size() == 1
+
+ with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="2"):
+ assert triton_sparse_mla_head_block_size() == 2
+
+ with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="4"):
+ assert triton_sparse_mla_head_block_size() == 4
+
+
+def test_sparse_mla_head_block_env_ignores_invalid_values() -> None:
+ for value in ("0", "3", "invalid"):
+ with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE=value):
+ assert triton_sparse_mla_head_block_size() is None
+
+
+def test_sparse_mla_head_block_env_is_registered_with_vllm_envs() -> None:
+ assert "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE" in environment_variables
+
+ with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="4"):
+ assert environment_variables["VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE"]() == 4
+
+
+def test_sparse_mla_chunk_env_defaults_invalid_values() -> None:
+ with _patched_sparse_mla_env(
+ VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE="invalid",
+ VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE="-7",
+ ):
+ assert triton_sparse_mla_topk_chunk_size() == 512
+ assert triton_sparse_mla_query_chunk_size() == 1
diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py
index c35c38911a1a..c9f9fdd58fe2 100644
--- a/tests/v1/core/test_prefix_caching.py
+++ b/tests/v1/core/test_prefix_caching.py
@@ -36,6 +36,8 @@
KVCacheConfig,
KVCacheGroupSpec,
MambaSpec,
+ MLAAttentionSpec,
+ SlidingWindowMLASpec,
SlidingWindowSpec,
)
@@ -2573,6 +2575,215 @@ def test_can_fit_full_sequence_swa_cap_admits_long_prompt():
)
+def test_deepseek_v4_mla_keeps_hybrid_aligned_prompt_blocks_after_decode():
+ hash_block_size = 2
+ full_block_size = 8
+ swa_block_size = 2
+ prompt_tokens = 35
+ chunk_tokens = 4 * full_block_size
+ expected_hit_tokens = (
+ (prompt_tokens - 1) // full_block_size * full_block_size
+ )
+
+ config = KVCacheConfig(
+ num_blocks=70,
+ kv_cache_tensors=[],
+ kv_cache_groups=[
+ KVCacheGroupSpec(
+ ["layer_full"],
+ MLAAttentionSpec(
+ block_size=full_block_size,
+ num_kv_heads=1,
+ head_size=1,
+ dtype=torch.uint8,
+ cache_dtype_str="fp8_ds_mla",
+ model_version="deepseek_v4",
+ ),
+ ),
+ KVCacheGroupSpec(
+ ["layer_swa_mla_0"],
+ SlidingWindowMLASpec(
+ block_size=swa_block_size,
+ num_kv_heads=1,
+ head_size=1,
+ dtype=torch.uint8,
+ sliding_window=2 * swa_block_size,
+ cache_dtype_str="fp8_ds_mla",
+ model_version="deepseek_v4",
+ ),
+ ),
+ KVCacheGroupSpec(
+ ["layer_swa_mla_1"],
+ SlidingWindowMLASpec(
+ block_size=swa_block_size,
+ num_kv_heads=1,
+ head_size=1,
+ dtype=torch.uint8,
+ sliding_window=2 * swa_block_size,
+ cache_dtype_str="fp8_ds_mla",
+ model_version="deepseek_v4",
+ ),
+ ),
+ KVCacheGroupSpec(
+ ["layer_swa_mla_compressor_state"],
+ SlidingWindowMLASpec(
+ block_size=swa_block_size,
+ num_kv_heads=1,
+ head_size=1,
+ dtype=torch.float32,
+ sliding_window=2 * swa_block_size,
+ ),
+ ),
+ ],
+ )
+ manager = KVCacheManager(
+ config,
+ max_model_len=128,
+ max_num_batched_tokens=chunk_tokens,
+ enable_caching=True,
+ hash_block_size=hash_block_size,
+ )
+
+ def run_request(request: Request, num_decode_tokens: int) -> int:
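+        """Prefill the prompt in chunk-sized steps, then decode token by token.
+
+        Returns the prefix-cache hit length reported when the request was
+        admitted via get_computed_blocks.
+        """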
+ computed_blocks, num_computed_tokens = manager.get_computed_blocks(request)
+ computed_so_far = num_computed_tokens
+ remaining_prompt_tokens = request.num_prompt_tokens - num_computed_tokens
+ first_chunk = True
+ while remaining_prompt_tokens > 0:
+ num_new_tokens = min(chunk_tokens, remaining_prompt_tokens)
+ allocated = manager.allocate_slots(
+ request,
+ num_new_tokens,
+ num_computed_tokens if first_chunk else 0,
+ computed_blocks if first_chunk else None,
+ )
+ assert allocated is not None
+ computed_so_far += num_new_tokens
+ request.num_computed_tokens = computed_so_far
+ remaining_prompt_tokens -= num_new_tokens
+ first_chunk = False
+
+ for i in range(num_decode_tokens):
+ request.append_output_token_ids(10_000 + i)
+ allocated = manager.allocate_slots(request, 1)
+ assert allocated is not None
+ computed_so_far += 1
+ request.num_computed_tokens = computed_so_far
+ return num_computed_tokens
+
+ prompt_a = list(range(prompt_tokens))
+ req_a = make_request("a", prompt_a, hash_block_size, sha256)
+ assert run_request(req_a, num_decode_tokens=0) == 0
+ manager.free(req_a)
+
+ warm_a = make_request("warm_a", prompt_a, hash_block_size, sha256)
+ assert run_request(warm_a, num_decode_tokens=8) == expected_hit_tokens
+ assert manager.get_num_common_prefix_blocks("warm_a")[0] >= (
+ expected_hit_tokens // full_block_size
+ )
+ manager.free(warm_a)
+
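+    # Simulate allocator pressure by cycling every free block through an
+    # allocate/free round before looking the same prompt up again.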
+ pressure_blocks = manager.block_pool.get_new_blocks(
+ manager.block_pool.get_num_free_blocks()
+ )
+ manager.block_pool.free_blocks(reversed(pressure_blocks))
+
+ req_a_again = make_request("a_again", prompt_a, hash_block_size, sha256)
+ _, num_computed_tokens = manager.get_computed_blocks(req_a_again)
+ assert num_computed_tokens == expected_hit_tokens
+
+
+def test_deepseek_v4_mla_protected_prompt_blocks_do_not_block_admission():
+ block_size = 8
+ prompt_tokens = 4 * block_size + 3
+ protected_blocks_per_prompt = (prompt_tokens - 1) // block_size
+ num_prompts = 10
+ num_blocks = 80
+ manager = KVCacheManager(
+ KVCacheConfig(
+ num_blocks=num_blocks,
+ kv_cache_tensors=[],
+ kv_cache_groups=[
+ KVCacheGroupSpec(
+ ["layer_full"],
+ MLAAttentionSpec(
+ block_size=block_size,
+ num_kv_heads=1,
+ head_size=1,
+ dtype=torch.uint8,
+ cache_dtype_str="fp8_ds_mla",
+ model_version="deepseek_v4",
+ ),
+ )
+ ],
+ ),
+ max_model_len=512,
+ max_num_batched_tokens=128,
+ enable_caching=True,
+ hash_block_size=block_size,
+ )
+ mla_manager = manager.coordinator.single_type_managers[0]
+
+ for i in range(num_prompts):
+ prompt = list(range(i * 1000, i * 1000 + prompt_tokens))
+ req = make_request(f"protected_{i}", prompt, block_size, sha256)
+ assert manager.allocate_slots(req, prompt_tokens) is not None
+ req.num_computed_tokens = prompt_tokens
+ manager.free(req)
+
+ assert len(mla_manager._protected_prompt_block_ids) == (
+ num_prompts * protected_blocks_per_prompt
+ )
+ assert manager.block_pool.get_num_free_blocks() < 64
+
+ long_req = make_request(
+ "long",
+ list(range(100_000, 100_000 + 64 * block_size)),
+ block_size,
+ sha256,
+ )
+ assert (
+ manager.allocate_slots(long_req, block_size, full_sequence_must_fit=True)
+ is not None
+ )
+
+
+def test_reset_prefix_cache_releases_deepseek_v4_mla_protected_blocks():
+ block_size = 8
+ prompt_tokens = 4 * block_size + 3
+ manager = KVCacheManager(
+ KVCacheConfig(
+ num_blocks=32,
+ kv_cache_tensors=[],
+ kv_cache_groups=[
+ KVCacheGroupSpec(
+ ["layer_full"],
+ MLAAttentionSpec(
+ block_size=block_size,
+ num_kv_heads=1,
+ head_size=1,
+ dtype=torch.uint8,
+ cache_dtype_str="fp8_ds_mla",
+ model_version="deepseek_v4",
+ ),
+ )
+ ],
+ ),
+ max_model_len=512,
+ max_num_batched_tokens=128,
+ enable_caching=True,
+ hash_block_size=block_size,
+ )
+
+ req = make_request("protected", list(range(prompt_tokens)), block_size, sha256)
+ assert manager.allocate_slots(req, prompt_tokens) is not None
+ req.num_computed_tokens = prompt_tokens
+ manager.free(req)
+
+ assert manager.coordinator.single_type_managers[0]._protected_prompt_block_ids
+ assert manager.reset_prefix_cache()
+
+
def test_can_fit_full_sequence_full_attention_still_gates_oversized():
"""The cap only loosens the SWA group; a prompt that exceeds the
full-attention pool capacity must still be rejected."""
diff --git a/tests/v1/executor/test_ray_utils.py b/tests/v1/executor/test_ray_utils.py
index 8da9d5459e73..a83b513baca8 100644
--- a/tests/v1/executor/test_ray_utils.py
+++ b/tests/v1/executor/test_ray_utils.py
@@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from types import SimpleNamespace
+
import numpy as np
+from vllm.v1.executor import ray_utils
from vllm.v1.executor.ray_utils import detach_zero_copy_from_model_runner_output
from vllm.v1.outputs import LogprobsLists, LogprobsTensors, ModelRunnerOutput
@@ -52,3 +55,46 @@ def test_detach_zero_copy_from_model_runner_output_copies_only_numpy_views():
assert detached_logprobs.sampled_token_ranks.flags.writeable
assert detached_logprobs.cu_num_generated_tokens is cu_num_generated_tokens
assert output.prompt_logprobs_dict["req-0"] is prompt_logprobs
+
+
+def test_cluster_device_warning_uses_ray_cluster_resources(monkeypatch):
+ warnings = []
+
+ monkeypatch.setattr(
+ ray_utils,
+ "ray",
+ SimpleNamespace(cluster_resources=lambda: {"GPU": 2}),
+ )
+ monkeypatch.setattr(
+ ray_utils.logger,
+ "warning",
+ lambda *args, **kwargs: warnings.append(args),
+ )
+
+ ray_utils._warn_if_insufficient_cluster_devices(
+ SimpleNamespace(world_size=2), "GPU"
+ )
+
+ assert warnings == []
+
+
+def test_cluster_device_warning_reports_cluster_shortage(monkeypatch):
+ warnings = []
+
+ monkeypatch.setattr(
+ ray_utils,
+ "ray",
+ SimpleNamespace(cluster_resources=lambda: {"GPU": 1}),
+ )
+ monkeypatch.setattr(
+ ray_utils.logger,
+ "warning",
+ lambda *args, **kwargs: warnings.append(args),
+ )
+
+ ray_utils._warn_if_insufficient_cluster_devices(
+ SimpleNamespace(world_size=2), "GPU"
+ )
+
+ assert len(warnings) == 1
+ assert "distributed world size" in warnings[0][0]
diff --git a/tools/compare_vllm_http_logprobs_oracle.py b/tools/compare_vllm_http_logprobs_oracle.py
new file mode 100755
index 000000000000..85a3d6aade55
--- /dev/null
+++ b/tools/compare_vllm_http_logprobs_oracle.py
@@ -0,0 +1,431 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Compare current vLLM HTTP logprobs with a captured oracle bundle.
+
+The expected oracle format is a directory with request_*.json and matching
+response_*.json files produced from /v1/completions with token-id logprob keys.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+import urllib.error
+import urllib.request
+from pathlib import Path
+from typing import Any
+
+Json = dict[str, Any]
+
+
+class TokenNormalizer:
+ """Normalize token-id placeholders to decoded token strings."""
+
+ def __init__(self, decode_token_id):
+ self._decode_token_id = decode_token_id
+
+ def token_key(self, token: Any) -> str:
+ key = _token_key(token)
+ if not key.startswith("token_id:"):
+ return key
+ try:
+ token_id = int(key.removeprefix("token_id:"))
+ except ValueError:
+ return key
+ return self._decode_token_id(token_id)
+
+
+def _token_key(token: Any) -> str:
+ if isinstance(token, str):
+ return token
+ if isinstance(token, int):
+ return f"token_id:{token}"
+ return str(token)
+
+
+def _choice(response: Json) -> Json:
+ choices = response.get("choices")
+ if not isinstance(choices, list) or not choices:
+ raise ValueError("response has no choices[0]")
+ choice = choices[0]
+ if not isinstance(choice, dict):
+ raise ValueError("response choices[0] is not an object")
+ return choice
+
+
+def _normalize_token(token: Any, normalizer: TokenNormalizer | None) -> str:
+ if normalizer is None:
+ return _token_key(token)
+ return normalizer.token_key(token)
+
+
+def _generated_tokens(
+ response: Json, normalizer: TokenNormalizer | None = None
+) -> list[str]:
+ choice = _choice(response)
+ logprobs = choice.get("logprobs") or {}
+ tokens = logprobs.get("tokens")
+ if isinstance(tokens, list) and tokens:
+ return [_normalize_token(token, normalizer) for token in tokens]
+ token_ids = choice.get("token_ids")
+ if isinstance(token_ids, list):
+ return [_normalize_token(token, normalizer) for token in token_ids]
+ return []
+
+
+def _prompt_token_ids(response: Json) -> list[int] | None:
+ token_ids = _choice(response).get("prompt_token_ids")
+ if not isinstance(token_ids, list):
+ return None
+ return [int(token) for token in token_ids]
+
+
+def _token_logprobs(response: Json) -> list[float | None]:
+ logprobs = _choice(response).get("logprobs") or {}
+ values = logprobs.get("token_logprobs")
+ if not isinstance(values, list):
+ return []
+ out: list[float | None] = []
+ for value in values:
+ out.append(float(value) if value is not None else None)
+ return out
+
+
+def _top_logprobs(
+ response: Json, normalizer: TokenNormalizer | None = None
+) -> list[dict[str, float]]:
+ logprobs = _choice(response).get("logprobs") or {}
+ values = logprobs.get("top_logprobs")
+ if not isinstance(values, list):
+ return []
+ out: list[dict[str, float]] = []
+ for step in values:
+ if not isinstance(step, dict):
+ out.append({})
+ continue
+ out.append(
+ {
+ _normalize_token(token, normalizer): float(logprob)
+ for token, logprob in step.items()
+ }
+ )
+ return out
+
+
+def _first_mismatch(oracle_tokens: list[str], actual_tokens: list[str]) -> Json | None:
+ for step, (oracle_token, actual_token) in enumerate(
+ zip(oracle_tokens, actual_tokens)
+ ):
+ if oracle_token != actual_token:
+ return {"step": step, "oracle": oracle_token, "actual": actual_token}
+ if len(oracle_tokens) != len(actual_tokens):
+ step = min(len(oracle_tokens), len(actual_tokens))
+ return {
+ "step": step,
+ "oracle": oracle_tokens[step] if step < len(oracle_tokens) else None,
+ "actual": actual_tokens[step] if step < len(actual_tokens) else None,
+ }
+ return None
+
+
+def _matching_prefix(oracle_tokens: list[str], actual_tokens: list[str]) -> int:
+ count = 0
+ for oracle_token, actual_token in zip(oracle_tokens, actual_tokens):
+ if oracle_token != actual_token:
+ break
+ count += 1
+ return count
+
+
+def _top_keys(top_logprobs: dict[str, float], top_n: int) -> list[str]:
+ return list(top_logprobs.keys())[:top_n]
+
+
+def _mean(values: list[float]) -> float | None:
+ return sum(values) / len(values) if values else None
+
+
+def _max(values: list[float]) -> float | None:
+ return max(values) if values else None
+
+
+def compare_response(
+ case_name: str,
+ oracle_response: Json,
+ actual_response: Json,
+ *,
+ top_n: int = 50,
+ normalizer: TokenNormalizer | None = None,
+) -> Json:
+ oracle_tokens = _generated_tokens(oracle_response, normalizer)
+ actual_tokens = _generated_tokens(actual_response, normalizer)
+ oracle_top = _top_logprobs(oracle_response, normalizer)
+ actual_top = _top_logprobs(actual_response, normalizer)
+ oracle_token_logprobs = _token_logprobs(oracle_response)
+ actual_token_logprobs = _token_logprobs(actual_response)
+
+ steps = min(
+ len(oracle_tokens), len(actual_tokens), len(oracle_top), len(actual_top)
+ )
+ top1_matches = 0
+ topk_overlaps: list[float] = []
+ common_logprob_errors: list[float] = []
+ chosen_token_logprob_errors: list[float] = []
+
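+    # Per-step metrics: greedy top-1 agreement, overlap of the top-N token keys,
+    # and absolute logprob differences on shared top-N tokens and on the chosen
+    # token when both runs generated it.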
+ for step in range(steps):
+ oracle_keys = _top_keys(oracle_top[step], top_n)
+ actual_keys = _top_keys(actual_top[step], top_n)
+ if oracle_keys and actual_keys and oracle_keys[0] == actual_keys[0]:
+ top1_matches += 1
+
+ oracle_set = set(oracle_keys)
+ actual_set = set(actual_keys)
+ if oracle_set:
+ topk_overlaps.append(len(oracle_set & actual_set) / len(oracle_set))
+ for token in oracle_set & actual_set:
+ common_logprob_errors.append(
+ abs(oracle_top[step][token] - actual_top[step][token])
+ )
+
+ if (
+ oracle_tokens[step] == actual_tokens[step]
+ and step < len(oracle_token_logprobs)
+ and step < len(actual_token_logprobs)
+ and oracle_token_logprobs[step] is not None
+ and actual_token_logprobs[step] is not None
+ ):
+ chosen_token_logprob_errors.append(
+ abs(oracle_token_logprobs[step] - actual_token_logprobs[step])
+ )
+
+ oracle_prompt_ids = _prompt_token_ids(oracle_response)
+ actual_prompt_ids = _prompt_token_ids(actual_response)
+ prompt_ids_match = (
+ None
+ if oracle_prompt_ids is None or actual_prompt_ids is None
+ else oracle_prompt_ids == actual_prompt_ids
+ )
+
+ first_mismatch = _first_mismatch(oracle_tokens, actual_tokens)
+ return {
+ "case": case_name,
+ "tokens_match": first_mismatch is None,
+ "prompt_token_ids_match": prompt_ids_match,
+ "first_token_mismatch": first_mismatch,
+ "matching_prefix_tokens": _matching_prefix(oracle_tokens, actual_tokens),
+ "oracle_token_count": len(oracle_tokens),
+ "actual_token_count": len(actual_tokens),
+ "compared_steps": steps,
+ "top1_matches": top1_matches,
+ "top1_match_rate": top1_matches / steps if steps else None,
+ "topk_overlap_mean": _mean(topk_overlaps),
+ "topk_overlap_min": min(topk_overlaps) if topk_overlaps else None,
+ "max_common_logprob_abs_error": _max(common_logprob_errors),
+ "mean_common_logprob_abs_error": _mean(common_logprob_errors),
+ "max_chosen_token_logprob_abs_error": _max(chosen_token_logprob_errors),
+ "mean_chosen_token_logprob_abs_error": _mean(chosen_token_logprob_errors),
+ }
+
+
+def _load_json(path: Path) -> Json:
+ with path.open(encoding="utf-8") as f:
+ data = json.load(f)
+ if not isinstance(data, dict):
+ raise ValueError(f"{path} is not a JSON object")
+ return data
+
+
+def load_oracle_cases(oracle_dir: Path) -> list[tuple[str, Json, Json]]:
+ cases: list[tuple[str, Json, Json]] = []
+ request_paths = sorted(oracle_dir.glob("request_*.json"))
+ if not request_paths:
+ raise ValueError(f"{oracle_dir} has no request_*.json files")
+ for request_path in request_paths:
+ suffix = request_path.stem.removeprefix("request_")
+ response_path = oracle_dir / f"response_{suffix}.json"
+ if not response_path.exists():
+ raise ValueError(f"missing {response_path.name} for {request_path.name}")
+ cases.append((suffix, _load_json(request_path), _load_json(response_path)))
+ return cases
+
+
+def post_completion(base_url: str, payload: Json, timeout: float) -> Json:
+ url = f"{base_url.rstrip('/')}/v1/completions"
+ encoded = json.dumps(payload).encode("utf-8")
+ request = urllib.request.Request(
+ url,
+ data=encoded,
+ headers={"Content-Type": "application/json"},
+ method="POST",
+ )
+ try:
+ with urllib.request.urlopen(request, timeout=timeout) as response:
+ body = response.read().decode("utf-8")
+ except urllib.error.HTTPError as exc:
+ body = exc.read().decode("utf-8", errors="replace")
+ raise RuntimeError(f"HTTP {exc.code} from {url}: {body}") from exc
+ data = json.loads(body)
+ if not isinstance(data, dict):
+ raise ValueError(f"{url} returned non-object JSON")
+ return data
+
+
+def load_token_normalizer(
+ tokenizer: str,
+ *,
+ tokenizer_mode: str,
+ trust_remote_code: bool,
+) -> TokenNormalizer:
+ from vllm.tokenizers import get_tokenizer
+
+ hf_tokenizer = get_tokenizer(
+ tokenizer,
+ tokenizer_mode=tokenizer_mode,
+ trust_remote_code=trust_remote_code,
+ )
+
+ def decode_token_id(token_id: int) -> str:
+ return hf_tokenizer.decode([token_id])
+
+ return TokenNormalizer(decode_token_id)
+
+
+def summarize_reports(reports: list[Json]) -> Json:
+ return {
+ "case_count": len(reports),
+ "all_tokens_match": all(report["tokens_match"] for report in reports),
+ "all_prompt_token_ids_match": all(
+ report["prompt_token_ids_match"] is not False for report in reports
+ ),
+ "min_top1_match_rate": min(
+ (
+ report["top1_match_rate"]
+ for report in reports
+ if report["top1_match_rate"] is not None
+ ),
+ default=None,
+ ),
+ "min_topk_overlap_mean": min(
+ (
+ report["topk_overlap_mean"]
+ for report in reports
+ if report["topk_overlap_mean"] is not None
+ ),
+ default=None,
+ ),
+ "max_common_logprob_abs_error": max(
+ (
+ report["max_common_logprob_abs_error"]
+ for report in reports
+ if report["max_common_logprob_abs_error"] is not None
+ ),
+ default=None,
+ ),
+ "max_chosen_token_logprob_abs_error": max(
+ (
+ report["max_chosen_token_logprob_abs_error"]
+ for report in reports
+ if report["max_chosen_token_logprob_abs_error"] is not None
+ ),
+ default=None,
+ ),
+ }
+
+
+def _fails_thresholds(
+ summary: Json, reports: list[Json], args: argparse.Namespace
+) -> bool:
+ failed = False
+ if args.strict_tokens and not summary["all_tokens_match"]:
+ failed = True
+ if args.strict_prompt_token_ids and not summary["all_prompt_token_ids_match"]:
+ failed = True
+ if args.min_top1_match_rate is not None:
+ value = summary["min_top1_match_rate"]
+ failed = failed or value is None or value < args.min_top1_match_rate
+ if args.min_topk_overlap_mean is not None:
+ value = summary["min_topk_overlap_mean"]
+ failed = failed or value is None or value < args.min_topk_overlap_mean
+ if args.max_common_logprob_abs_error is not None:
+ value = summary["max_common_logprob_abs_error"]
+ failed = failed or value is None or value > args.max_common_logprob_abs_error
+ if args.max_chosen_token_logprob_abs_error is not None:
+ value = summary["max_chosen_token_logprob_abs_error"]
+ failed = (
+ failed or value is None or value > args.max_chosen_token_logprob_abs_error
+ )
+ if args.fail_on_first_mismatch:
+ failed = failed or any(report["first_token_mismatch"] for report in reports)
+ return failed
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument("--oracle-dir", required=True, type=Path)
+ parser.add_argument("--base-url", default="http://127.0.0.1:8000")
+ parser.add_argument("--timeout", type=float, default=240.0)
+ parser.add_argument("--top-n", type=int, default=50)
+ parser.add_argument("--output", type=Path)
+ parser.add_argument("--model", help="Override the model field from request_*.json")
+ parser.add_argument(
+ "--tokenizer",
+ help=(
+ "Tokenizer used to decode token_id: logprob keys before "
+ "comparing them with text token keys."
+ ),
+ )
+ parser.add_argument("--tokenizer-mode", default="auto")
+ parser.add_argument("--trust-remote-code", action="store_true")
+ parser.add_argument("--strict-tokens", action="store_true")
+ parser.add_argument("--strict-prompt-token-ids", action="store_true")
+ parser.add_argument("--fail-on-first-mismatch", action="store_true")
+ parser.add_argument("--min-top1-match-rate", type=float)
+ parser.add_argument("--min-topk-overlap-mean", type=float)
+ parser.add_argument("--max-common-logprob-abs-error", type=float)
+ parser.add_argument("--max-chosen-token-logprob-abs-error", type=float)
+ return parser.parse_args()
+
+
+def main() -> int:
+ args = parse_args()
+ cases = load_oracle_cases(args.oracle_dir)
+ normalizer = None
+ if args.tokenizer:
+ normalizer = load_token_normalizer(
+ args.tokenizer,
+ tokenizer_mode=args.tokenizer_mode,
+ trust_remote_code=args.trust_remote_code,
+ )
+ reports: list[Json] = []
+ for suffix, request_payload, oracle_response in cases:
+ payload = dict(request_payload)
+ if args.model:
+ payload["model"] = args.model
+ actual_response = post_completion(args.base_url, payload, args.timeout)
+ reports.append(
+ compare_response(
+ f"request_{suffix}",
+ oracle_response,
+ actual_response,
+ top_n=args.top_n,
+ normalizer=normalizer,
+ )
+ )
+
+ summary = summarize_reports(reports)
+ result: Json = {"summary": summary, "cases": reports}
+ text = json.dumps(result, indent=2, sort_keys=True)
+ if args.output:
+ args.output.write_text(text + "\n", encoding="utf-8")
+ print(text)
+ return 1 if _fails_thresholds(summary, reports, args) else 0
+
+
+if __name__ == "__main__":
+ try:
+ raise SystemExit(main())
+ except (OSError, RuntimeError, ValueError, json.JSONDecodeError) as exc:
+ print(f"error: {exc}", file=sys.stderr)
+ raise SystemExit(2) from exc
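
A tiny in-memory illustration of the threshold gating implemented by
_fails_thresholds above (all values are made up; run it in the same module):

import argparse

summary = {
    "all_tokens_match": False,
    "all_prompt_token_ids_match": True,
    "min_top1_match_rate": 0.98,
    "min_topk_overlap_mean": 0.92,
    "max_common_logprob_abs_error": 0.03,
    "max_chosen_token_logprob_abs_error": 0.01,
}
args = argparse.Namespace(
    strict_tokens=True,
    strict_prompt_token_ids=False,
    fail_on_first_mismatch=False,
    min_top1_match_rate=0.95,
    min_topk_overlap_mean=None,
    max_common_logprob_abs_error=0.05,
    max_chosen_token_logprob_abs_error=None,
)
# strict_tokens is enabled and all_tokens_match is False, so main() exits 1
# even though every numeric tolerance above is satisfied.
assert _fails_thresholds(summary, reports=[], args=args)
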
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 0f02a92681c1..91e5a8913e88 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -750,6 +750,7 @@ class CompilationConfig:
"vllm::sparse_attn_indexer",
"vllm::rocm_aiter_sparse_attn_indexer",
"vllm::deepseek_v4_attention",
+ "vllm::deepseek_v4_fp8_einsum",
]
def compute_hash(self) -> str:
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index cfe0857b679e..78de5bf2ef1e 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -399,6 +399,12 @@ class ConversationMessage(TypedDict, total=False):
reasoning_content: str | None
"""Deprecated: The reasoning content for interleaved thinking."""
+ prefix: bool
+ """Whether this assistant message is a prefix for continuation."""
+
+ wo_eos: bool
+ """Whether this message should be rendered without an EOS marker."""
+
tools: list[ChatCompletionFunctionToolParam] | None
"""The tools for developer role."""
@@ -1751,6 +1757,8 @@ def _parse_chat_message_content(
role = message["role"]
content = message.get("content")
reasoning = message.get("reasoning")
+ if reasoning is None:
+ reasoning = message.get("reasoning_content")
if content is None:
content = []
@@ -1780,6 +1788,9 @@ def _parse_chat_message_content(
result_msg["reasoning_content"] = cast(
str, reasoning
) # keep compatibility
+ if parsed_msg.get("prefix"):
+ result_msg["prefix"] = True
+ result_msg["wo_eos"] = True
elif role == "tool":
parsed_msg = _ToolParser(message)
if "tool_call_id" in parsed_msg:
diff --git a/vllm/entrypoints/openai/chat_completion/batch_serving.py b/vllm/entrypoints/openai/chat_completion/batch_serving.py
index cc49909b8361..df6aff4a45c3 100644
--- a/vllm/entrypoints/openai/chat_completion/batch_serving.py
+++ b/vllm/entrypoints/openai/chat_completion/batch_serving.py
@@ -159,8 +159,12 @@ async def create_batch_chat_completion(
self.override_max_tokens,
)
single_request = single_requests[i]
+ chat_template_kwargs = self._effective_chat_template_kwargs(single_request)
sampling_params = single_request.to_sampling_params(
- max_tokens, self.default_sampling_params
+ max_tokens,
+ self.default_sampling_params,
+ chat_template_kwargs=chat_template_kwargs,
+ model_config=self.model_config,
)
self._log_inputs(
sub_request_id,
diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py
index c92cc13da01f..b0a73770da35 100644
--- a/vllm/entrypoints/openai/chat_completion/protocol.py
+++ b/vllm/entrypoints/openai/chat_completion/protocol.py
@@ -62,6 +62,15 @@ class ChatMessage(OpenAIBaseModel):
# vLLM-specific fields that are not in OpenAI spec
reasoning: str | None = None
+ reasoning_content: str | None = None
+
+ @model_validator(mode="after")
+ def _populate_reasoning_content_alias(self) -> "ChatMessage":
+ if self.reasoning_content is None and self.reasoning is not None:
+ self.reasoning_content = self.reasoning
+ elif self.reasoning is None and self.reasoning_content is not None:
+ self.reasoning = self.reasoning_content
+ return self
class ChatCompletionLogProb(OpenAIBaseModel):
@@ -164,6 +173,10 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
type: Literal["function"] = "function"
+class DeepSeekThinkingParam(OpenAIBaseModel):
+ type: Literal["enabled", "disabled"] = "enabled"
+
+
class ChatCompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
@@ -209,6 +222,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
"part of the standard OpenAI API specification."
),
)
+ thinking: DeepSeekThinkingParam | None = None
+ deepseek_v4_sampling_override: bool = Field(
+ default=True,
+ description=(
+ "Apply DeepSeek V4 official sampling defaults when thinking is "
+ "enabled. This only affects the DeepSeek V4 family and can be "
+ "disabled per request."
+ ),
+ )
thinking_token_budget: int | None = None
include_reasoning: bool = True
parallel_tool_calls: bool | None = True
@@ -435,21 +457,79 @@ def build_chat_params(
default_template: str | None,
default_template_content_format: ChatTemplateContentFormatOption,
) -> ChatParams:
+ chat_kwargs = merge_kwargs(
+ self.chat_template_kwargs,
+ dict(
+ add_generation_prompt=self.add_generation_prompt,
+ continue_final_message=self.continue_final_message,
+ documents=self.documents,
+ reasoning_effort=self.reasoning_effort,
+ ),
+ )
return ChatParams(
chat_template=self.chat_template or default_template,
chat_template_content_format=default_template_content_format,
- chat_template_kwargs=merge_kwargs(
- self.chat_template_kwargs,
- dict(
- add_generation_prompt=self.add_generation_prompt,
- continue_final_message=self.continue_final_message,
- documents=self.documents,
- reasoning_effort=self.reasoning_effort,
- ),
- ),
+ chat_template_kwargs=chat_kwargs,
media_io_kwargs=self.media_io_kwargs,
)
+ def _is_deepseek_v4_model(self, model_config: ModelConfig | None = None) -> bool:
+ hf_config = getattr(model_config, "hf_config", None)
+ if getattr(hf_config, "model_type", None) == "deepseek_v4":
+ return True
+
+ architectures = getattr(hf_config, "architectures", None) or ()
+ if any("deepseekv4" in str(arch).replace("_", "").lower()
+ for arch in architectures):
+ return True
+
+ model = (self.model or "").lower().replace("_", "-")
+ return "deepseek-v4" in model
+
+ def apply_chat_template_kwargs(
+ self,
+ chat_template_kwargs: dict[str, Any],
+ *,
+ model_config: ModelConfig | None = None,
+ ) -> dict[str, Any]:
+ """Apply request-level DeepSeek API compatibility knobs.
+
+ DeepSeek's OpenAI-compatible API exposes ``thinking`` as a top-level
+ request field, while vLLM's DeepSeek tokenizer consumes it as a chat
+ template kwarg. Keep the translation at the protocol boundary so the
+ tokenizer and reasoning parser see the same effective state.
+ """
+ chat_template_kwargs = dict(chat_template_kwargs)
+ if not self._is_deepseek_v4_model(model_config):
+ return chat_template_kwargs
+
+ if self.thinking is not None:
+ enabled = self.thinking.type == "enabled"
+ chat_template_kwargs["thinking"] = enabled
+ chat_template_kwargs["enable_thinking"] = enabled
+ elif (
+ "thinking" not in chat_template_kwargs
+ and "enable_thinking" not in chat_template_kwargs
+ ):
+ chat_template_kwargs["thinking"] = True
+ chat_template_kwargs["enable_thinking"] = True
+
+ return chat_template_kwargs
+
+ def _use_deepseek_v4_sampling_override(self) -> bool:
+ return self.deepseek_v4_sampling_override
+
+ @staticmethod
+ def _is_thinking_enabled(
+ chat_template_kwargs: dict[str, Any] | None,
+ ) -> bool:
+ if chat_template_kwargs is None:
+ return False
+ return bool(
+ chat_template_kwargs.get("thinking")
+ or chat_template_kwargs.get("enable_thinking")
+ )
+
def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams:
if self.max_completion_tokens is not None:
max_output_tokens: int | None = self.max_completion_tokens
@@ -499,6 +579,9 @@ def to_sampling_params(
self,
max_tokens: int,
default_sampling_params: dict,
+ *,
+ chat_template_kwargs: dict[str, Any] | None = None,
+ model_config: ModelConfig | None = None,
) -> SamplingParams:
# Default parameters
if (repetition_penalty := self.repetition_penalty) is None:
@@ -523,6 +606,21 @@ def to_sampling_params(
"min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]
)
+ if (
+ self._is_deepseek_v4_model(model_config)
+ and self._use_deepseek_v4_sampling_override()
+ and self._is_thinking_enabled(chat_template_kwargs)
+ ):
+ temperature = self._DEFAULT_SAMPLING_PARAMS["temperature"]
+ top_p = self._DEFAULT_SAMPLING_PARAMS["top_p"]
+ top_k = self._DEFAULT_SAMPLING_PARAMS["top_k"]
+ min_p = self._DEFAULT_SAMPLING_PARAMS["min_p"]
+ presence_penalty = 0.0
+ frequency_penalty = 0.0
+ else:
+ presence_penalty = self.presence_penalty or 0.0
+ frequency_penalty = self.frequency_penalty or 0.0
+
prompt_logprobs = self.prompt_logprobs
if prompt_logprobs is None and self.echo:
prompt_logprobs = self.top_logprobs
@@ -565,8 +663,8 @@ def to_sampling_params(
extra_args["kv_transfer_params"] = self.kv_transfer_params
return SamplingParams.from_optional(
n=self.n,
- presence_penalty=self.presence_penalty,
- frequency_penalty=self.frequency_penalty,
+ presence_penalty=presence_penalty,
+ frequency_penalty=frequency_penalty,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
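
A rough sketch of the request-level behavior added above; constructing a real
ChatCompletionRequest may need additional fields, so treat this as a shape
sketch rather than a runnable test:

request = ChatCompletionRequest(
    model="deepseek-v4",  # any model name the DeepSeek V4 heuristic matches
    messages=[{"role": "user", "content": "hi"}],
    thinking={"type": "enabled"},
    temperature=1.3,  # ignored while the sampling override is active
)
kwargs = request.apply_chat_template_kwargs({}, model_config=None)
assert kwargs == {"thinking": True, "enable_thinking": True}
# With deepseek_v4_sampling_override left at its default of True,
# to_sampling_params(..., chat_template_kwargs=kwargs, model_config=None)
# replaces temperature/top_p/top_k/min_p with _DEFAULT_SAMPLING_PARAMS and
# zeroes the presence/frequency penalties; pass
# "deepseek_v4_sampling_override": false to keep the user-supplied values.
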
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index 694ff80047c7..2dad7f948b5b 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -190,7 +190,7 @@ def warmup(self) -> None:
def _effective_chat_template_kwargs(
self, request: ChatCompletionRequest
) -> dict[str, Any]:
- return (
+ chat_template_kwargs = (
request.build_chat_params(
self.chat_template,
self.chat_template_content_format,
@@ -198,6 +198,10 @@ def _effective_chat_template_kwargs(
.with_defaults(self.default_chat_template_kwargs)
.chat_template_kwargs
)
+ return request.apply_chat_template_kwargs(
+ chat_template_kwargs,
+ model_config=self.model_config,
+ )
async def render_chat_request(
self,
@@ -300,6 +304,8 @@ async def create_chat_completion(
sampling_params = request.to_sampling_params(
max_tokens,
self.default_sampling_params,
+ chat_template_kwargs=chat_template_kwargs,
+ model_config=self.model_config,
)
self._log_inputs(
diff --git a/vllm/entrypoints/openai/engine/protocol.py b/vllm/entrypoints/openai/engine/protocol.py
index 890af0300efc..8c531c6d5a93 100644
--- a/vllm/entrypoints/openai/engine/protocol.py
+++ b/vllm/entrypoints/openai/engine/protocol.py
@@ -268,8 +268,17 @@ class DeltaMessage(OpenAIBaseModel):
role: str | None = None
content: str | None = None
reasoning: str | None = None
+ reasoning_content: str | None = None
tool_calls: list[DeltaToolCall] = Field(default_factory=list)
+ @model_validator(mode="after")
+ def _populate_reasoning_content_alias(self) -> "DeltaMessage":
+ if self.reasoning_content is None and self.reasoning is not None:
+ self.reasoning_content = self.reasoning
+ elif self.reasoning is None and self.reasoning_content is not None:
+ self.reasoning = self.reasoning_content
+ return self
+
class GenerationError(Exception):
"""raised when finish_reason indicates internal server error (500)"""
diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py
index 967899229ada..d6da0dba6f4e 100644
--- a/vllm/entrypoints/serve/render/serving.py
+++ b/vllm/entrypoints/serve/render/serving.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
+from dataclasses import replace
from http import HTTPStatus
from typing import Any, cast
@@ -165,7 +166,21 @@ async def render_chat_request(
self.default_sampling_params,
self.override_max_tokens,
)
- params = request.to_sampling_params(max_tokens, self.default_sampling_params)
+ chat_template_kwargs = request.apply_chat_template_kwargs(
+ request.build_chat_params(
+ self.chat_template,
+ self.chat_template_content_format,
+ )
+ .with_defaults(self.default_chat_template_kwargs)
+ .chat_template_kwargs,
+ model_config=self.model_config,
+ )
+ params = request.to_sampling_params(
+ max_tokens,
+ self.default_sampling_params,
+ chat_template_kwargs=chat_template_kwargs,
+ model_config=self.model_config,
+ )
request_id = f"chatcmpl-{random_uuid()}"
@@ -556,6 +571,17 @@ async def preprocess_chat(
default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None),
default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None),
)
+ apply_chat_template_kwargs = getattr(
+ request, "apply_chat_template_kwargs", None
+ )
+ if apply_chat_template_kwargs is not None:
+ chat_params = replace(
+ chat_params,
+ chat_template_kwargs=apply_chat_template_kwargs(
+ chat_params.chat_template_kwargs,
+ model_config=self.model_config,
+ ),
+ )
(conversation,), (engine_input,) = await renderer.render_chat_async(
[messages],
diff --git a/vllm/envs.py b/vllm/envs.py
index ded474dc085a..07a3e55025de 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -166,6 +166,12 @@
VLLM_MOE_USE_DEEP_GEMM: bool = True
VLLM_USE_DEEP_GEMM_E8M0: bool = True
VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True
+ VLLM_TRITON_MLA_SPARSE: bool | None = None
+ VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512
+ VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256
+ VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH: bool = True
+ VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE: int | None = None
+ VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE: bool | None = None
VLLM_DEEP_GEMM_WARMUP: Literal[
"skip",
"full",
@@ -249,6 +255,7 @@
VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_USE_V2_MODEL_RUNNER: bool = False
+ VLLM_DEEPSEEK_V4_USE_MEGA_MOE: bool = False
VLLM_LOG_MODEL_INSPECTION: bool = False
VLLM_DEBUG_MFU_METRICS: bool = False
VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY: bool = False
@@ -1275,6 +1282,34 @@ def _get_or_set_default() -> str:
"VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool(
int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1"))
),
+ # Experimental sparse MLA fallback controls.
+ # ``VLLM_TRITON_MLA_SPARSE`` unset means auto-select where FlashMLA sparse
+ # is unavailable; set 0/1 to force-disable/force-enable the fallback.
+ "VLLM_TRITON_MLA_SPARSE": lambda: (
+ None
+ if os.getenv("VLLM_TRITON_MLA_SPARSE") is None
+ else os.getenv("VLLM_TRITON_MLA_SPARSE", "").lower()
+ in ("1", "true", "yes", "on")
+ ),
+ "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE": lambda: maybe_convert_int(
+ os.getenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", "512")
+ ),
+ "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE": lambda: maybe_convert_int(
+ os.getenv("VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE", "256")
+ ),
+ "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH": lambda: (
+ os.getenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", "1").lower()
+ in ("1", "true", "yes", "on")
+ ),
+ "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE": lambda: maybe_convert_int(
+ os.getenv("VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE")
+ ),
+ "VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE": lambda: (
+ None
+ if os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE") is None
+ else os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE", "").lower()
+ in ("1", "true", "yes", "on")
+ ),
# DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
# JIT all the required kernels before model execution so there is no
# JIT'ing in the hot-path. However, this warmup increases the engine
@@ -1711,6 +1746,12 @@ def _get_or_set_default() -> str:
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
),
+    # Optional override for the DeepGEMM MegaMoE fused expert kernel in
+    # DeepSeek V4. Defaults to 0, in which case kernel_config.moe_backend
+    # decides; set to 1 to force-enable this path during bring-up.
+ "VLLM_DEEPSEEK_V4_USE_MEGA_MOE": lambda: bool(
+ int(os.getenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "0"))
+ ),
# Log model inspection after loading.
# If enabled, logs a transformers-style hierarchical view of the model
# with quantization methods and attention backends.
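
A minimal sketch of the tri-state parsing convention used above for
VLLM_TRITON_MLA_SPARSE and VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE (the helper
name below is illustrative, not part of vLLM):

import os

def _tri_state_env(name: str) -> bool | None:
    raw = os.getenv(name)
    if raw is None:
        return None  # unset: let the backend auto-select at runtime
    return raw.lower() in ("1", "true", "yes", "on")

os.environ.pop("VLLM_TRITON_MLA_SPARSE", None)
assert _tri_state_env("VLLM_TRITON_MLA_SPARSE") is None
os.environ["VLLM_TRITON_MLA_SPARSE"] = "0"
assert _tri_state_env("VLLM_TRITON_MLA_SPARSE") is False
os.environ["VLLM_TRITON_MLA_SPARSE"] = "true"
assert _tri_state_env("VLLM_TRITON_MLA_SPARSE") is True
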
diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
index 618084029159..6baedd3bbcbc 100644
--- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
+++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
@@ -7,6 +7,9 @@
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils import replace_parameter
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+ _upcast_e8m0_to_fp32,
+)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
@@ -26,6 +29,20 @@
)
+def _is_sm12x_compute_capability(compute_capability) -> bool:
+ if compute_capability is None:
+ return current_platform.is_device_capability_family(120)
+
+ if isinstance(compute_capability, tuple):
+ return compute_capability[0] == 12
+
+ to_int = getattr(compute_capability, "to_int", None)
+ if callable(to_int):
+ return to_int() // 10 == 12
+
+ return int(compute_capability) // 10 == 12
+
+
class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
def is_supported(
@@ -196,6 +213,9 @@ def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
@classmethod
def is_supported(cls, compute_capability=None):
+ if _is_sm12x_compute_capability(compute_capability):
+ return False, "CUTLASS block-scaled FP8 GEMM is not supported on SM12x."
+
if not CUTLASS_BLOCK_FP8_SUPPORTED:
return (
False,
@@ -219,6 +239,31 @@ def can_implement(cls, config: FP8ScaledMMLinearLayerConfig):
)
return True, None
+ def process_weights_after_loading(self, layer: torch.nn.Module):
+ super().process_weights_after_loading(layer)
+ params = self._get_layer_params(layer)
+ weight_scale = (
+ params.weight_scale
+ if params.weight_scale_inv is None
+ else params.weight_scale_inv
+ )
+ scale_attr_name = (
+ params.WEIGHT_SCALE
+ if params.weight_scale_inv is None
+ else params.WEIGHT_SCALE_INV
+ )
+ e8m0_dtype = getattr(torch, "float8_e8m0fnu", None)
+ if (
+ e8m0_dtype is not None
+ and weight_scale is not None
+ and weight_scale.dtype == e8m0_dtype
+ ):
+ replace_parameter(
+ layer,
+ scale_attr_name,
+ _upcast_e8m0_to_fp32(weight_scale),
+ )
+
def apply_block_scaled_mm(
self,
A: torch.Tensor,
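
A rough reference for the e8m0 weight-scale upcast applied in
process_weights_after_loading above, assuming e8m0 stores only a biased
power-of-two exponent (the in-tree _upcast_e8m0_to_fp32 helper may handle
edge cases differently):

import torch

def upcast_e8m0_to_fp32_reference(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # Reinterpret the 8-bit exponent bytes and decode value = 2 ** (byte - 127),
    # mirroring the tl.exp2(encoded_scale - 127.0) decode used by the Triton
    # decode kernel later in this patch.
    exponent = scale_e8m0.view(torch.uint8).to(torch.float32)
    return torch.exp2(exponent - 127.0)
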
diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py
index 494d61338084..ba520ce476a0 100644
--- a/vllm/model_executor/layers/deepseek_v4_attention.py
+++ b/vllm/model_executor/layers/deepseek_v4_attention.py
@@ -18,15 +18,21 @@
ReplicatedLinear,
)
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
+from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_einsum
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.ops.deepseek_v4_ops import (
combine_topk_swa_indices,
compute_global_topk_indices_and_lens,
dequantize_and_gather_k_cache,
+ dequantize_combined_sparse_mla_decode_kv,
fused_indexer_q_rope_quant,
fused_inv_rope_fp8_quant,
fused_q_kv_rmsnorm,
+ sparse_prefill_combined_topk_size,
+)
+from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import (
+ deepseek_v4_sm12x_fp8_einsum,
)
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_forward_decode_fallback,
@@ -44,7 +50,10 @@
VllmConfig,
get_current_vllm_config,
)
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import PluggableLayer
@@ -58,7 +67,6 @@
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
)
-from vllm.platforms import current_platform
from vllm.utils.multi_stream_utils import (
execute_in_parallel,
maybe_execute_in_parallel,
@@ -73,6 +81,26 @@
DeepseekV4IndexerBackend,
get_max_prefill_buffer_size,
)
+from vllm.v1.attention.backends.mla.sparse_mla_env import (
+ disable_triton_sparse_mla_cudagraphs_if_enabled,
+ is_triton_sparse_mla_enabled,
+ is_triton_sparse_mla_enabled_for_platform,
+ triton_sparse_mla_matmul_decode_enabled,
+ triton_sparse_mla_query_chunk_size,
+ triton_sparse_mla_topk_chunk_size,
+)
+from vllm.v1.attention.backends.mla.sparse_mla_kernels import (
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead,
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead,
+ accumulate_indexed_sparse_mla_attention_chunk,
+ build_combined_sparse_mla_decode_valid_mask,
+ finish_sparse_mla_attention_with_sink,
+ finish_two_sparse_mla_attention_states_with_sink,
+ fp8ds_global_paged_sparse_mla_attention_with_sink_multihead,
+ fp8ds_paged_sparse_mla_attention_with_sink_multihead,
+ matmul_sparse_mla_attention_with_sink,
+ sparse_mla_decode_head_block_size,
+)
from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache
from vllm.v1.attention.ops.flashmla import (
flash_mla_sparse_fwd,
@@ -83,10 +111,88 @@
logger = init_logger(__name__)
+
+def _sparse_mla_prefill_workspace_bounds(
+ seq_lens_cpu: torch.Tensor,
+ gather_lens_cpu: torch.Tensor,
+ compress_ratio: int,
+ swa_only: bool,
+) -> tuple[int, int]:
+ if seq_lens_cpu.numel() == 0:
+ return 0, 0
+
+ max_gather_len = int(gather_lens_cpu.max().item())
+ if swa_only:
+ return 0, max_gather_len
+
+ compressed_region_size = int((seq_lens_cpu // compress_ratio).max().item())
+ return compressed_region_size, compressed_region_size + max_gather_len
+
+
+def _sparse_mla_prefill_gather_len_upper_bound(
+ *,
+ max_model_len: int,
+ max_num_batched_tokens: int,
+ window_size: int,
+) -> tuple[int, int]:
+ max_query_chunk_tokens = max(1, min(max_model_len, max_num_batched_tokens))
+ max_prefix_len = max(max_model_len - max_query_chunk_tokens, 0)
+ max_gather_len = max_query_chunk_tokens + min(
+ max_prefix_len,
+ max(window_size - 1, 0),
+ )
+ return max_query_chunk_tokens, max_gather_len
+
+
+def _deepseek_v4_fp8_einsum_config(
+ capability_major: int,
+) -> tuple[tuple[int, int, int], bool]:
+ if capability_major == 10:
+ return (1, 1, 128), True
+ return (1, 128, 128), False
+
+
+def _use_deepseek_v4_sm12x_triton_fp8_einsum(
+ equation: str,
+ recipe: list[int],
+ b_scale: torch.Tensor,
+) -> bool:
+ capability = current_platform.get_device_capability()
+ e8m0_dtype = getattr(torch, "float8_e8m0fnu", None)
+ return (
+ capability is not None
+ and capability.major == 12
+ and equation == "bhr,hdr->bhd"
+ and tuple(recipe) == (1, 128, 128)
+ and b_scale.dtype in (torch.float32, e8m0_dtype)
+ )
+
+
+def _allocate_deepseek_v4_wo_a_output(
+ num_tokens: int,
+ num_groups: int,
+ output_rank: int,
+ dtype: torch.dtype,
+ device: torch.device,
+) -> torch.Tensor:
+ shape = (num_tokens, num_groups, output_rank)
+ if torch.compiler.is_compiling():
+ # Workspace growth can call torch.accelerator.empty_cache(), which
+ # Dynamo intentionally refuses to trace. During compilation this is a
+ # normal graph allocation, matching the o_padded allocation above.
+ return torch.empty(shape, dtype=dtype, device=device)
+
+ (output,) = current_workspace_manager().get_simultaneous(
+ (shape, dtype),
+ )
+ return output
+
+
# Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather
# workspace allocated at _forward_prefill (and the matching profile-time
# reservation in attention_impl's dummy-run branch).
PREFILL_CHUNK_SIZE = 4
+_DEFAULT_SPARSE_MLA_TOPK_TOKENS = 2048
@dataclass
@@ -172,6 +278,8 @@ def __init__(
self.compress_ratio = compress_ratio if compress_ratio is not None else 1
self.prefix = prefix
+ disable_triton_sparse_mla_cudagraphs_if_enabled(mla_modules.vllm_config)
+
# Extract config from vllm_config
config = mla_modules.vllm_config.model_config.hf_config
tp_size = get_tensor_model_parallel_world_size()
@@ -202,12 +310,13 @@ def __init__(
self.wo_b = mla_modules.wo_b
# Pick fp8_einsum recipe based on GPU arch:
- # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128
- # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1
+ # SM90/SM120: FP32 block scales stay [g, r/128, d/128].
+ # SM100: INT32 packed scales become [g, r, ...].
cap = current_platform.get_device_capability()
assert cap is not None, "DeepseekV4 attention requires a CUDA device"
- self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128)
- self._tma_aligned_scales = cap.major >= 10
+ self._einsum_recipe, self._tma_aligned_scales = _deepseek_v4_fp8_einsum_config(
+ cap.major
+ )
self.rotary_emb = mla_modules.rotary_emb
self.indexer_rotary_emb = mla_modules.indexer_rotary_emb
@@ -336,10 +445,12 @@ def forward(
wo_a_fp8 = self.wo_a.weight
wo_a_scale = self.wo_a.weight_scale_inv
- z = torch.empty(
- (num_tokens, self.n_local_groups, self.o_lora_rank),
- device=o.device,
- dtype=torch.bfloat16,
+ z = _allocate_deepseek_v4_wo_a_output(
+ num_tokens,
+ self.n_local_groups,
+ self.o_lora_rank,
+ torch.bfloat16,
+ hidden_states.device,
)
torch.ops.vllm.deepseek_v4_fp8_einsum(
o_fp8,
@@ -494,21 +605,8 @@ def wq_b_kv_insert() -> torch.Tensor:
# Handle dummy run (no metadata).
if not isinstance(attn_metadata, dict):
- # Reserve _forward_prefill's bf16-gather workspace; the dummy
- # run returns before mla_attn runs, so without this the shared
- # workspace locks below the real prefill size.
- sub = self.mla_attn
- swa_only = sub.compress_ratio <= 1
- N = (
- 0
- if swa_only
- else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio
- )
- M = N + sub.window_size + sub.max_num_batched_tokens
- current_workspace_manager().get_simultaneous(
- ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
- )
out.zero_()
+ self.mla_attn._reserve_prefill_workspace()
return
# Pad q to FlashMLA-required head count (64 or 128)
@@ -594,6 +692,68 @@ def deepseek_v4_fp8_einsum(
equation: str,
recipe: list[int],
) -> None:
+ if equation == "bhr,hdr->bhd" and b.dim() == 2:
+ num_groups = out.shape[1]
+ out_rank = out.shape[2]
+ hidden_size = a.shape[2]
+ if b.shape[0] % out_rank != 0:
+ raise RuntimeError(
+ "DeepSeek V4 fp8 einsum weight rows must be divisible by "
+ f"out_rank={out_rank}, got {b.shape[0]}"
+ )
+ b_groups = b.shape[0] // out_rank
+ group_start = 0
+ if b_groups != num_groups:
+ if b_groups % num_groups != 0:
+ raise RuntimeError(
+ "DeepSeek V4 fp8 einsum weight groups must match the "
+ "TP-local output groups or be an integer multiple of "
+ f"them, got weight_groups={b_groups}, "
+ f"output_groups={num_groups}"
+ )
+ group_partitions = b_groups // num_groups
+ group_start = (
+ get_tensor_model_parallel_rank() % group_partitions
+ ) * num_groups
+ b = b.view(b_groups, out_rank, hidden_size)
+ if group_start != 0 or b_groups != num_groups:
+ b = b.narrow(0, group_start, num_groups)
+
+ if b_scale.dim() == 2:
+ scale_mn = recipe[1]
+ scale_k_pack = 4 if b_scale.dtype == torch.int32 else 1
+ scale_k = recipe[2] * scale_k_pack
+ scale_out_blocks = (out_rank + scale_mn - 1) // scale_mn
+ scale_hidden_blocks = (hidden_size + scale_k - 1) // scale_k
+ if b_scale.shape[0] % scale_out_blocks != 0:
+ raise RuntimeError(
+ "DeepSeek V4 fp8 einsum scale rows must be divisible by "
+ f"scale_out_blocks={scale_out_blocks}, "
+ f"got {b_scale.shape[0]}"
+ )
+ scale_groups = b_scale.shape[0] // scale_out_blocks
+ if scale_groups not in (num_groups, b_groups):
+ raise RuntimeError(
+ "DeepSeek V4 fp8 einsum scale groups must match the "
+ "TP-local output groups or weight groups, got "
+ f"scale_groups={scale_groups}, output_groups={num_groups}, "
+ f"weight_groups={b_groups}"
+ )
+ b_scale = b_scale.view(
+ scale_groups,
+ scale_out_blocks,
+ scale_hidden_blocks,
+ )
+ if scale_groups == b_groups and scale_groups != num_groups:
+ b_scale = b_scale.narrow(0, group_start, num_groups)
+ elif b_scale.dim() == 3 and b_scale.shape[0] == b_groups:
+ if b_groups != num_groups:
+ b_scale = b_scale.narrow(0, group_start, num_groups)
+
+ if _use_deepseek_v4_sm12x_triton_fp8_einsum(equation, recipe, b_scale):
+ deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, out)
+ return
+
fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe))
@@ -711,7 +871,11 @@ def __init__(
assert cache_config is not None
cache_config.cache_dtype = "fp8_ds_mla"
kv_cache_dtype = "fp8_ds_mla"
- logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.")
+ logger.info_once(
+ "Using DeepSeek's fp8_ds_mla KV cache format. To use standard "
+ "fp8 kv-cache format, please set `--attention-backend "
+ "FLASHINFER_MLA_SPARSE`"
+ )
self.kv_cache_dtype = kv_cache_dtype
@@ -724,6 +888,73 @@ def __init__(
self.kv_cache = torch.tensor([])
+ def _prefill_workspace_topk_bound(self) -> int:
+ if self.compress_ratio <= 1:
+ return 0
+ if (
+ self.topk_indices_buffer is not None
+ and self.topk_indices_buffer.ndim > 0
+ and self.topk_indices_buffer.shape[-1] > 0
+ ):
+ return int(self.topk_indices_buffer.shape[-1])
+ indexer_topk = getattr(self.indexer, "topk_tokens", None)
+ if indexer_topk is not None:
+ return int(indexer_topk)
+ return _DEFAULT_SPARSE_MLA_TOPK_TOKENS
+
+ def _prefill_workspace_reservation_specs(
+ self,
+ ) -> tuple[tuple[tuple[int, ...], torch.dtype], ...]:
+ max_model_len = max(1, int(self.max_model_len))
+ max_num_batched_tokens = max(1, int(self.max_num_batched_tokens))
+ window_size = max(1, int(self.window_size))
+ compress_ratio = max(1, int(self.compress_ratio))
+ head_dim = int(self.head_dim)
+ num_heads = int(self.num_heads)
+
+ max_query_chunk_tokens, max_gather_len = (
+ _sparse_mla_prefill_gather_len_upper_bound(
+ max_model_len=max_model_len,
+ max_num_batched_tokens=max_num_batched_tokens,
+ window_size=window_size,
+ )
+ )
+ if compress_ratio <= 1:
+ m_bound = max_gather_len
+ else:
+ compressed_region_size = max_model_len // compress_ratio
+ m_bound = compressed_region_size + max_gather_len
+
+ combined_topk = sparse_prefill_combined_topk_size(
+ DeepseekV4MLAAttention._prefill_workspace_topk_bound(self),
+ window_size,
+ )
+ specs: list[tuple[tuple[int, ...], torch.dtype]] = [
+ ((PREFILL_CHUNK_SIZE, m_bound, head_dim), torch.bfloat16),
+ ((max_query_chunk_tokens, combined_topk), torch.int32),
+ ((max_query_chunk_tokens,), torch.int32),
+ ]
+ if is_triton_sparse_mla_enabled_for_platform():
+ query_chunk_size = min(
+ max_query_chunk_tokens,
+ triton_sparse_mla_query_chunk_size(),
+ )
+ specs.extend(
+ [
+ ((query_chunk_size, num_heads), torch.float32),
+ ((query_chunk_size, num_heads), torch.float32),
+ ((query_chunk_size, num_heads, head_dim), torch.float32),
+ ]
+ )
+ return tuple(specs)
+
+ def _reserve_prefill_workspace(self) -> None:
+ try:
+ workspace_manager = current_workspace_manager()
+ except AssertionError:
+ return
+ workspace_manager.get_simultaneous(*self._prefill_workspace_reservation_specs())
+
def get_attn_backend(self) -> type[AttentionBackend]:
return DeepseekV4FlashMLASparseBackend
@@ -743,6 +974,332 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
model_version="deepseek_v4",
)
+ def _forward_sparse_mla_swa_decode_triton(
+ self,
+ q: torch.Tensor,
+ swa_k_cache: torch.Tensor,
+ swa_metadata: "DeepseekSparseSWAMetadata",
+ output: torch.Tensor,
+ ) -> None:
+ num_decodes = swa_metadata.num_decodes
+ num_decode_tokens = swa_metadata.num_decode_tokens
+ mtp_decode = num_decode_tokens != num_decodes
+
+ swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens]
+ swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens]
+ max_swa_len = swa_metadata.decode_swa_indices.shape[-1]
+ head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens)
+ if not mtp_decode:
+ fp8ds_paged_sparse_mla_attention_with_sink_multihead(
+ q=q,
+ k_cache=swa_k_cache,
+ seq_lens=swa_metadata.seq_lens[:num_decodes],
+ gather_lens=swa_lens,
+ block_table=swa_metadata.block_table[:num_decodes],
+ block_size=swa_metadata.block_size,
+ candidate_offset=0,
+ num_candidates=max_swa_len,
+ scale=self.scale,
+ attn_sink=self.attn_sink,
+ output=output,
+ head_block_size=head_block_size,
+ num_heads=self.num_heads,
+ )
+ if output.shape[1] > self.num_heads:
+ output[:, self.num_heads :].zero_()
+ return
+
+ (
+ swa_max_score,
+ swa_denom,
+ swa_acc,
+ ) = current_workspace_manager().get_simultaneous(
+ ((num_decode_tokens, self.num_heads), torch.float32),
+ ((num_decode_tokens, self.num_heads), torch.float32),
+ ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32),
+ )
+ swa_max_score.fill_(float("-inf"))
+ swa_denom.zero_()
+ swa_acc.zero_()
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=swa_k_cache,
+ slot_ids=swa_indices,
+ lens=swa_lens,
+ block_size=swa_metadata.block_size,
+ scale=self.scale,
+ max_score=swa_max_score,
+ denom=swa_denom,
+ acc=swa_acc,
+ head_block_size=head_block_size,
+ )
+ finish_sparse_mla_attention_with_sink(
+ swa_max_score,
+ swa_denom,
+ swa_acc,
+ self.attn_sink,
+ output=output,
+ )
+ if output.shape[1] > self.num_heads:
+ output[:, self.num_heads :].zero_()
+
+ def _forward_sparse_mla_compressed_decode_triton(
+ self,
+ q: torch.Tensor,
+ compressed_k_cache: torch.Tensor,
+ swa_k_cache: torch.Tensor,
+ topk_indices: torch.Tensor,
+ topk_lens: torch.Tensor,
+ swa_metadata: "DeepseekSparseSWAMetadata",
+ attn_metadata: FlashMLASparseMetadata,
+ output: torch.Tensor,
+ ) -> None:
+ if self.compress_ratio not in (4, 128):
+ raise NotImplementedError(
+ "Triton sparse MLA compressed decode currently supports "
+ f"compress_ratio=4 or 128, got {self.compress_ratio}"
+ )
+
+ num_decodes = swa_metadata.num_decodes
+ num_decode_tokens = swa_metadata.num_decode_tokens
+ mtp_decode = num_decode_tokens != num_decodes
+
+ max_swa_len = swa_metadata.decode_swa_indices.shape[-1]
+ compressed_block_size = attn_metadata.block_size // self.compress_ratio
+ compressed_topk = topk_indices.shape[-1]
+ topk_chunk_size = min(
+ compressed_topk,
+ triton_sparse_mla_topk_chunk_size(),
+ )
+ compressed_slot_ids = topk_indices[:, 0, :]
+ swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens]
+ swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens]
+ head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens)
+ if (
+ not mtp_decode
+ and compressed_topk <= topk_chunk_size
+ and triton_sparse_mla_matmul_decode_enabled()
+ ):
+ total_candidates = compressed_topk + max_swa_len
+ (
+ combined_kv,
+ valid_tokens,
+ score_buffer,
+ ) = current_workspace_manager().get_simultaneous(
+ (
+ (num_decode_tokens, total_candidates, q.shape[-1]),
+ torch.bfloat16,
+ ),
+ ((num_decode_tokens, total_candidates), torch.bool),
+ ((num_decode_tokens, self.num_heads, total_candidates), torch.bfloat16),
+ )
+ dequantize_combined_sparse_mla_decode_kv(
+ combined_kv,
+ compressed_k_cache,
+ compressed_slot_ids,
+ compressed_block_size,
+ swa_k_cache,
+ swa_metadata.seq_lens[:num_decodes],
+ swa_lens,
+ swa_metadata.block_table[:num_decodes],
+ swa_metadata.block_size,
+ )
+
+ build_combined_sparse_mla_decode_valid_mask(
+ valid_tokens,
+ compressed_slot_ids,
+ topk_lens,
+ swa_lens,
+ )
+ use_dot_finish = num_decode_tokens <= 16
+ matmul_sparse_mla_attention_with_sink(
+ q=q,
+ kv=combined_kv,
+ valid_tokens=valid_tokens,
+ scale=self.scale,
+ attn_sink=self.attn_sink,
+ output=output,
+ num_heads=self.num_heads,
+ score_buffer=score_buffer,
+ value_block_size=512 if use_dot_finish else 256,
+ candidate_block_size=128 if use_dot_finish else None,
+ )
+ return
+
+ if not mtp_decode and compressed_topk <= topk_chunk_size:
+ fp8ds_global_paged_sparse_mla_attention_with_sink_multihead(
+ q=q,
+ compressed_k_cache=compressed_k_cache,
+ slot_ids=compressed_slot_ids,
+ topk_lens=topk_lens,
+ compressed_block_size=compressed_block_size,
+ swa_k_cache=swa_k_cache,
+ seq_lens=swa_metadata.seq_lens[:num_decodes],
+ gather_lens=swa_lens,
+ block_table=swa_metadata.block_table[:num_decodes],
+ swa_block_size=swa_metadata.block_size,
+ num_compressed_candidates=compressed_topk,
+ num_swa_candidates=max_swa_len,
+ scale=self.scale,
+ attn_sink=self.attn_sink,
+ output=output,
+ head_block_size=head_block_size,
+ num_heads=self.num_heads,
+ )
+ if output.shape[1] > self.num_heads:
+ output[:, self.num_heads :].zero_()
+ return
+
+ (
+ comp_max_score,
+ comp_denom,
+ comp_acc,
+ swa_max_score,
+ swa_denom,
+ swa_acc,
+ ) = current_workspace_manager().get_simultaneous(
+ ((num_decode_tokens, self.num_heads), torch.float32),
+ ((num_decode_tokens, self.num_heads), torch.float32),
+ ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32),
+ ((num_decode_tokens, self.num_heads), torch.float32),
+ ((num_decode_tokens, self.num_heads), torch.float32),
+ ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32),
+ )
+ comp_max_score.fill_(float("-inf"))
+ comp_denom.zero_()
+ comp_acc.zero_()
+ swa_max_score.fill_(float("-inf"))
+ swa_denom.zero_()
+ swa_acc.zero_()
+
+ for chunk_start in range(0, compressed_topk, topk_chunk_size):
+ chunk_end = min(chunk_start + topk_chunk_size, compressed_topk)
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=compressed_k_cache,
+ slot_ids=compressed_slot_ids[:, chunk_start:chunk_end],
+ lens=topk_lens,
+ block_size=compressed_block_size,
+ candidate_offset=chunk_start,
+ scale=self.scale,
+ max_score=comp_max_score,
+ denom=comp_denom,
+ acc=comp_acc,
+ head_block_size=head_block_size,
+ )
+ if mtp_decode:
+ accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=swa_k_cache,
+ slot_ids=swa_indices,
+ lens=swa_lens,
+ block_size=swa_metadata.block_size,
+ scale=self.scale,
+ max_score=swa_max_score,
+ denom=swa_denom,
+ acc=swa_acc,
+ head_block_size=head_block_size,
+ )
+ else:
+ accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead(
+ q=q,
+ k_cache=swa_k_cache,
+ seq_lens=swa_metadata.seq_lens[:num_decodes],
+ gather_lens=swa_lens,
+ block_table=swa_metadata.block_table[:num_decodes],
+ block_size=swa_metadata.block_size,
+ candidate_offset=0,
+ num_candidates=max_swa_len,
+ scale=self.scale,
+ max_score=swa_max_score,
+ denom=swa_denom,
+ acc=swa_acc,
+ head_block_size=head_block_size,
+ )
+ finish_two_sparse_mla_attention_states_with_sink(
+ comp_max_score,
+ comp_denom,
+ comp_acc,
+ swa_max_score,
+ swa_denom,
+ swa_acc,
+ self.attn_sink,
+ output=output,
+ )
+ if output.shape[1] > self.num_heads:
+ output[:, self.num_heads :].zero_()
+
+ def _forward_sparse_mla_prefill_triton(
+ self,
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ combined_indices: torch.Tensor,
+ combined_lens: torch.Tensor,
+ output: torch.Tensor,
+ state_buffers: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
+ ) -> None:
+ kv_flat = kv.reshape(-1, q.shape[-1])
+ topk_chunk_size = min(
+ combined_indices.shape[-1],
+ triton_sparse_mla_topk_chunk_size(),
+ )
+ query_chunk_size = min(
+ q.shape[0],
+ triton_sparse_mla_query_chunk_size(),
+ )
+ if state_buffers is None:
+ (
+ max_score_buffer,
+ denom_buffer,
+ output_buffer,
+ ) = current_workspace_manager().get_simultaneous(
+ ((query_chunk_size, self.num_heads), torch.float32),
+ ((query_chunk_size, self.num_heads), torch.float32),
+ ((query_chunk_size, self.num_heads, q.shape[-1]), torch.float32),
+ )
+ else:
+ max_score_buffer, denom_buffer, output_buffer = state_buffers
+
+ for token_start in range(0, q.shape[0], query_chunk_size):
+ token_end = min(token_start + query_chunk_size, q.shape[0])
+ q_chunk = q[token_start:token_end]
+ indices_chunk_full = combined_indices[token_start:token_end]
+ lens_chunk = combined_lens[token_start:token_end]
+ num_tokens = token_end - token_start
+ max_score = max_score_buffer[:num_tokens]
+ denom = denom_buffer[:num_tokens]
+ subset_acc = output_buffer[:num_tokens]
+ max_score.fill_(float("-inf"))
+ denom.zero_()
+ subset_acc.zero_()
+
+ for index_start in range(0, combined_indices.shape[-1], topk_chunk_size):
+ index_end = min(
+ index_start + topk_chunk_size,
+ combined_indices.shape[-1],
+ )
+ accumulate_indexed_sparse_mla_attention_chunk(
+ q=q_chunk,
+ kv_flat=kv_flat,
+ indices=indices_chunk_full[:, index_start:index_end],
+ lens=lens_chunk,
+ candidate_offset=index_start,
+ scale=self.scale,
+ max_score=max_score,
+ denom=denom,
+ acc=subset_acc,
+ )
+
+ finish_sparse_mla_attention_with_sink(
+ max_score,
+ denom,
+ subset_acc,
+ self.attn_sink,
+ output=output[token_start:token_end],
+ )
+ if output.shape[1] > self.num_heads:
+ output[token_start:token_end, self.num_heads :].zero_()
+
def forward(
self,
q: torch.Tensor,
@@ -823,12 +1380,14 @@ def _forward_decode(
if self.compress_ratio == 4:
# C4A: local indices differ per layer (filled by Indexer).
assert self.topk_indices_buffer is not None
+ local_topk_indices = self.topk_indices_buffer[:num_decode_tokens]
global_indices, topk_lens = compute_global_topk_indices_and_lens(
- self.topk_indices_buffer[:num_decode_tokens],
+ local_topk_indices,
swa_metadata.token_to_req_indices,
attn_metadata.block_table[:num_decodes],
block_size,
is_valid,
+ global_topk_indices=local_topk_indices,
)
topk_indices = global_indices.view(num_decode_tokens, 1, -1)
else:
@@ -867,9 +1426,35 @@ def _forward_decode(
# Use unsqueeze to preserve strides (handles padded blocks correctly)
swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2)
# Reshape KV cache to (num_blocks, block_size, 1, head_bytes)
+ compressed_k_cache = kv_cache
if kv_cache is not None:
kv_cache = kv_cache.unsqueeze(-2)
+ if is_triton_sparse_mla_enabled(q.device):
+ if swa_only:
+ self._forward_sparse_mla_swa_decode_triton(
+ q=q,
+ swa_k_cache=self.swa_cache_layer.kv_cache,
+ swa_metadata=swa_metadata,
+ output=output,
+ )
+ return
+ if self.compress_ratio in (4, 128):
+ assert compressed_k_cache is not None
+ assert attn_metadata is not None
+ assert topk_indices is not None
+ assert topk_lens is not None
+ self._forward_sparse_mla_compressed_decode_triton(
+ q=q,
+ compressed_k_cache=compressed_k_cache,
+ swa_k_cache=self.swa_cache_layer.kv_cache,
+ topk_indices=topk_indices,
+ topk_lens=topk_lens,
+ swa_metadata=swa_metadata,
+ attn_metadata=attn_metadata,
+ output=output,
+ )
+ return
# One FlashMLASchedMeta per layer type, shared across all same-type
# layers within this decode step. The first forward call per type
# triggers the in-kernel planner (allocating tile_scheduler_metadata
@@ -932,8 +1517,12 @@ def _forward_prefill(
# Use pre-computed prefill metadata.
seq_lens = swa_metadata.prefill_seq_lens
gather_lens = swa_metadata.prefill_gather_lens
+ seq_lens_cpu = swa_metadata.prefill_seq_lens_cpu
+ gather_lens_cpu = swa_metadata.prefill_gather_lens_cpu
assert seq_lens is not None
assert gather_lens is not None
+ assert seq_lens_cpu is not None
+ assert gather_lens_cpu is not None
# Derive prefill-local token offsets from the full query_start_loc_cpu.
query_start_loc_cpu = swa_metadata.query_start_loc_cpu
@@ -952,24 +1541,69 @@ def _forward_prefill(
assert attn_metadata is not None
topk_indices = attn_metadata.c128a_prefill_topk_indices
top_k = topk_indices.shape[-1]
- # Compressed region must fit the full compressed pool (seq_len //
- # compress_ratio), not just top_k. top_k bounds how many indices
- # the indexer selects, not the pool size it indexes into.
- N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio
else:
# NOTE(woosuk): topk_indices will not be used for SWA-only layers.
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[num_decode_tokens:]
top_k = 0
- N = 0
- M = N + self.window_size + self.max_num_batched_tokens
+ N, M = _sparse_mla_prefill_workspace_bounds(
+ seq_lens_cpu=seq_lens_cpu,
+ gather_lens_cpu=gather_lens_cpu,
+ compress_ratio=self.compress_ratio,
+ swa_only=swa_only,
+ )
num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE
+ max_query_chunk_tokens = 0
+ for chunk_idx in range(num_chunks):
+ chunk_start = chunk_idx * PREFILL_CHUNK_SIZE
+ chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills)
+ query_start = (
+ query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base
+ )
+ query_end = (
+ query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base
+ )
+ max_query_chunk_tokens = max(
+ max_query_chunk_tokens, int(query_end - query_start)
+ )
+ combined_topk = sparse_prefill_combined_topk_size(top_k, self.window_size)
workspace_manager = current_workspace_manager()
- kv = workspace_manager.get_simultaneous(
- ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
- )[0]
+ triton_sparse_mla_enabled = is_triton_sparse_mla_enabled(q.device)
+ if triton_sparse_mla_enabled:
+ query_chunk_size = min(q.shape[0], triton_sparse_mla_query_chunk_size())
+ (
+ kv,
+ combined_indices_buffer,
+ combined_lens_buffer,
+ max_score_buffer,
+ denom_buffer,
+ output_buffer,
+ ) = workspace_manager.get_simultaneous(
+ ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
+ ((max_query_chunk_tokens, combined_topk), torch.int32),
+ ((max_query_chunk_tokens,), torch.int32),
+ ((query_chunk_size, self.num_heads), torch.float32),
+ ((query_chunk_size, self.num_heads), torch.float32),
+ ((query_chunk_size, self.num_heads, q.shape[-1]), torch.float32),
+ )
+ prefill_state_buffers = (
+ max_score_buffer,
+ denom_buffer,
+ output_buffer,
+ )
+ else:
+ (
+ kv,
+ combined_indices_buffer,
+ combined_lens_buffer,
+ ) = workspace_manager.get_simultaneous(
+ ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16),
+ ((max_query_chunk_tokens, combined_topk), torch.int32),
+ ((max_query_chunk_tokens,), torch.int32),
+ )
+ prefill_state_buffers = None
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * PREFILL_CHUNK_SIZE
chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills)
@@ -1008,6 +1642,7 @@ def _forward_prefill(
query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base
)
+ query_tokens = query_end - query_start
combined_indices, combined_lens = combine_topk_swa_indices(
topk_indices[query_start:query_end],
query_start_loc[
@@ -1020,8 +1655,21 @@ def _forward_prefill(
top_k,
M,
N,
+ combined_indices=combined_indices_buffer[:query_tokens],
+ combined_lens=combined_lens_buffer[:query_tokens],
)
+ if triton_sparse_mla_enabled:
+ self._forward_sparse_mla_prefill_triton(
+ q=q[query_start:query_end],
+ kv=kv[:chunk_size],
+ combined_indices=combined_indices,
+ combined_lens=combined_lens,
+ output=output[query_start:query_end],
+ state_buffers=prefill_state_buffers,
+ )
+ continue
+
if current_platform.is_rocm():
rocm_sparse_attn_prefill(
q=q[query_start:query_end],
@@ -1033,16 +1681,17 @@ def _forward_prefill(
attn_sink=self.attn_sink,
output=output[query_start:query_end],
)
- else:
- output_chunk, _, _ = flash_mla_sparse_fwd(
- q=q[query_start:query_end],
- kv=kv.view(-1, 1, q.shape[-1]),
- indices=combined_indices.unsqueeze(1),
- sm_scale=self.scale,
- attn_sink=self.attn_sink,
- topk_length=combined_lens,
- out=output[query_start:query_end],
- )
+ continue
+
+ output_chunk, _, _ = flash_mla_sparse_fwd(
+ q=q[query_start:query_end],
+ kv=kv.view(-1, 1, q.shape[-1]),
+ indices=combined_indices.unsqueeze(1),
+ sm_scale=self.scale,
+ attn_sink=self.attn_sink,
+ topk_length=combined_lens,
+ out=output[query_start:query_end],
+ )
class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase):
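
A hedged PyTorch reference for merging the two partial online-softmax states
(compressed top-k and SWA) that _forward_sparse_mla_compressed_decode_triton
accumulates above. The real finish_two_sparse_mla_attention_states_with_sink
kernel may track maxima in log2 space and differ in layout; the math shown
here, in natural-log space, is the standard log-sum-exp merge with the
attention sink folded in as one extra logit per head whose value vector is
zero:

import torch

def merge_two_states_with_sink_reference(
    max_a: torch.Tensor,    # [tokens, heads] running max of state A
    denom_a: torch.Tensor,  # [tokens, heads] sum of exp(logit - max_a)
    acc_a: torch.Tensor,    # [tokens, heads, dim] weighted value sum of state A
    max_b: torch.Tensor,
    denom_b: torch.Tensor,
    acc_b: torch.Tensor,
    sink: torch.Tensor,     # [heads] per-head sink logit
) -> torch.Tensor:
    new_max = torch.maximum(torch.maximum(max_a, max_b), sink)
    scale_a = torch.exp(max_a - new_max)
    scale_b = torch.exp(max_b - new_max)
    denom = denom_a * scale_a + denom_b * scale_b + torch.exp(sink - new_max)
    acc = acc_a * scale_a.unsqueeze(-1) + acc_b * scale_b.unsqueeze(-1)
    return acc / denom.unsqueeze(-1)
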
diff --git a/vllm/model_executor/layers/deepseek_v4_triton_kernels.py b/vllm/model_executor/layers/deepseek_v4_triton_kernels.py
new file mode 100644
index 000000000000..b5048c5fd013
--- /dev/null
+++ b/vllm/model_executor/layers/deepseek_v4_triton_kernels.py
@@ -0,0 +1,1282 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Triton fallback kernels used by the local DeepSeek V4 path."""
+
+import torch
+
+from vllm.triton_utils import LOG2E, tl, triton
+
+DEEPSEEK_V4_MLA_HEAD_DIM = 512
+FP8_DS_MLA_FP8_DIM = 448
+FP8_DS_MLA_SCALE_GROUP = 64
+FP8_DS_MLA_SCALE_BYTES = 8
+FP8_DS_MLA_TOKEN_BYTES = 576
+
+
+def _view_packed_fp8_paged_mqa_kv_cache(
+ kv_cache: torch.Tensor,
+ head_dim: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """Return FP8 values and fp32 scales from indexer cache block storage."""
+ if kv_cache.dtype != torch.uint8:
+ raise TypeError(f"Expected uint8 kv_cache, got {kv_cache.dtype}")
+ if kv_cache.dim() == 3:
+ num_blocks, block_size, head_dim_with_scale = kv_cache.shape
+ num_kv_heads = 1
+ elif kv_cache.dim() == 4:
+ num_blocks, block_size, num_kv_heads, head_dim_with_scale = kv_cache.shape
+ else:
+ raise ValueError(
+ f"Expected 3D or 4D kv_cache, got {kv_cache.dim()} dimensions"
+ )
+ if num_kv_heads != 1:
+ raise ValueError(f"Expected one KV head, got {num_kv_heads}")
+
+ scale_bytes = head_dim_with_scale - head_dim
+ if scale_bytes <= 0 or scale_bytes % torch.float32.itemsize != 0:
+ raise ValueError(
+ "Expected kv_cache last dimension to contain FP8 values followed "
+ f"by fp32 scale bytes; got head_dim={head_dim}, "
+ f"last_dim={head_dim_with_scale}"
+ )
+
+ block_stride = kv_cache.stride(0)
+ base_storage_offset = kv_cache.storage_offset()
+ scale_elems = scale_bytes // torch.float32.itemsize
+ kv_values = torch.as_strided(
+ kv_cache,
+ size=(num_blocks, block_size, 1, head_dim),
+ stride=(block_stride, head_dim, head_dim, 1),
+ storage_offset=base_storage_offset,
+ ).view(torch.float8_e4m3fn)
+ kv_scale = torch.as_strided(
+ kv_cache,
+ size=(num_blocks, block_size, 1, scale_bytes),
+ stride=(block_stride, scale_bytes, scale_bytes, 1),
+ storage_offset=base_storage_offset + block_size * head_dim,
+ ).view(torch.float32)
+ return kv_values, kv_scale[..., :scale_elems]
+
+
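
A hedged numeric illustration of the packed per-block layout assumed by the
helper above (head_dim and block_size are arbitrary example values; only the
"fp8 values followed by fp32 scales" split matters):

import torch

num_blocks, block_size, head_dim = 2, 64, 128
token_bytes = head_dim + 4  # 1 fp8 byte per dim plus one fp32 scale per token
cache = torch.zeros(num_blocks, block_size, token_bytes, dtype=torch.uint8)
values, scales = _view_packed_fp8_paged_mqa_kv_cache(cache, head_dim=head_dim)
assert values.shape == (num_blocks, block_size, 1, head_dim)
assert values.dtype == torch.float8_e4m3fn
assert scales.shape == (num_blocks, block_size, 1, 1)
assert scales.dtype == torch.float32
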
+@triton.jit
+def _sparse_attention_bf16_kernel(
+ q_ptr,
+ kv_ptr,
+ indices_ptr,
+ lengths_ptr,
+ sink_ptr,
+ out_ptr,
+ num_tokens: tl.constexpr,
+ num_heads: tl.constexpr,
+ seq_kv: tl.constexpr,
+ index_topk: tl.constexpr,
+ sm_scale_log2: tl.constexpr,
+ stride_qt: tl.constexpr,
+ stride_qh: tl.constexpr,
+ stride_qd: tl.constexpr,
+ stride_kv_t: tl.constexpr,
+ stride_kv_d: tl.constexpr,
+ stride_indices_t: tl.constexpr,
+ stride_indices_k: tl.constexpr,
+ stride_out_t: tl.constexpr,
+ stride_out_h: tl.constexpr,
+ stride_out_d: tl.constexpr,
+ BLOCK_H: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+ HAS_SINK: tl.constexpr,
+ LOG2E_CONST: tl.constexpr,
+):
+ token_id = tl.program_id(0)
+ head_block = tl.program_id(1)
+ heads = head_block * BLOCK_H + tl.arange(0, BLOCK_H)
+ offs_d = tl.arange(0, BLOCK_D)
+ mask_h = heads < num_heads
+
+ q = tl.load(
+ q_ptr
+ + token_id * stride_qt
+ + heads[:, None] * stride_qh
+ + offs_d[None, :] * stride_qd,
+ mask=mask_h[:, None],
+ other=0.0,
+ )
+
+ if HAS_SINK:
+ sink = tl.load(sink_ptr + heads, mask=mask_h, other=-float("inf"))
+ e_max = sink * LOG2E_CONST
+ e_sum = tl.where(mask_h, 1.0, 0.0)
+ else:
+ e_max = tl.full((BLOCK_H,), -float("inf"), dtype=tl.float32)
+ e_sum = tl.zeros((BLOCK_H,), dtype=tl.float32)
+ acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32)
+
+ length = tl.load(lengths_ptr + token_id)
+ for start in range(0, index_topk, BLOCK_N):
+ offs_n = start + tl.arange(0, BLOCK_N)
+ idx = tl.load(
+ indices_ptr + token_id * stride_indices_t + offs_n * stride_indices_k,
+ mask=offs_n < index_topk,
+ other=-1,
+ )
+ mask_kv = (offs_n < length) & (idx >= 0) & (idx < seq_kv)
+ k = tl.load(
+ kv_ptr + idx[None, :] * stride_kv_t + offs_d[:, None] * stride_kv_d,
+ mask=mask_kv[None, :],
+ other=0.0,
+ )
+ qk = tl.dot(q, k.to(q.dtype)) * sm_scale_log2
+ qk = tl.where(
+ mask_h[:, None] & mask_kv[None, :],
+ qk,
+ -3.4028234663852886e38,
+ )
+
+ v = tl.load(
+ kv_ptr + idx[:, None] * stride_kv_t + offs_d[None, :] * stride_kv_d,
+ mask=mask_kv[:, None],
+ other=0.0,
+ )
+
+ n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+ re_scale = tl.exp2(e_max - n_e_max)
+ p = tl.exp2(qk - n_e_max[:, None])
+ p = tl.where(mask_h[:, None] & mask_kv[None, :], p, 0.0)
+ acc = acc * re_scale[:, None] + tl.dot(p.to(v.dtype), v)
+ e_sum = e_sum * re_scale + tl.sum(p, 1)
+ e_max = n_e_max
+
+ acc = acc / tl.maximum(e_sum, 1.0e-20)[:, None]
+ tl.store(
+ out_ptr
+ + token_id * stride_out_t
+ + heads[:, None] * stride_out_h
+ + offs_d[None, :] * stride_out_d,
+ acc.to(tl.bfloat16),
+ mask=mask_h[:, None],
+ )
+
+
+def sparse_attention_triton(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ indices: torch.Tensor,
+ lengths: torch.Tensor,
+ scale: float,
+ attn_sink: torch.Tensor | None,
+ out: torch.Tensor,
+) -> None:
+ if indices.ndim == 3:
+ indices = indices.squeeze(1)
+ if kv.ndim == 3:
+ kv = kv.squeeze(1)
+
+ num_tokens, num_heads, head_dim = q.shape
+ if num_tokens == 0:
+ return
+ if head_dim != DEEPSEEK_V4_MLA_HEAD_DIM:
+ raise ValueError(
+ "DeepSeek V4 sparse Triton fallback expects "
+ f"D={DEEPSEEK_V4_MLA_HEAD_DIM}, got {head_dim}"
+ )
+ assert kv.shape[-1] == head_dim
+ assert out.shape[-1] == head_dim
+
+ grid = (num_tokens, triton.cdiv(num_heads, 8))
+ _sparse_attention_bf16_kernel[grid](
+ q,
+ kv,
+ indices,
+ lengths,
+ attn_sink if attn_sink is not None else q,
+ out,
+ num_tokens,
+ num_heads,
+ kv.shape[0],
+ indices.shape[-1],
+ scale * LOG2E,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ kv.stride(0),
+ kv.stride(1),
+ indices.stride(0),
+ indices.stride(1),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ BLOCK_H=8,
+ BLOCK_N=16,
+ BLOCK_D=DEEPSEEK_V4_MLA_HEAD_DIM,
+ HAS_SINK=attn_sink is not None,
+ LOG2E_CONST=LOG2E,
+ num_warps=8,
+ )
+
+
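
A hedged eager-mode reference for sparse_attention_triton above, mirroring
the kernel's gather-then-softmax semantics (including the attention sink,
which acts as one extra logit per head with a zero value vector); useful for
spot-checking the kernel on tiny shapes:

import torch

def sparse_attention_reference(
    q: torch.Tensor,        # [tokens, heads, dim]
    kv: torch.Tensor,       # [seq_kv, dim], used as both K and V
    indices: torch.Tensor,  # [tokens, topk]; -1 or out-of-range means padding
    lengths: torch.Tensor,  # [tokens] number of valid indices per token
    scale: float,
    attn_sink: torch.Tensor | None,  # [heads] or None
) -> torch.Tensor:
    out = torch.zeros_like(q)
    for t in range(q.shape[0]):
        idx = indices[t, : int(lengths[t])]
        idx = idx[(idx >= 0) & (idx < kv.shape[0])]
        k = kv[idx].float()                  # [n, dim]
        logits = q[t].float() @ k.T * scale  # [heads, n]
        if attn_sink is not None:
            logits = torch.cat([attn_sink.float()[:, None], logits], dim=1)
        probs = torch.softmax(logits, dim=-1)
        if attn_sink is not None:
            probs = probs[:, 1:]             # drop the sink column
        out[t] = (probs @ k).to(out.dtype)
    return out
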
+@triton.jit
+def _decode_sparse_attention_fp8_kernel(
+ q_ptr,
+ swa_cache_fp8_ptr,
+ swa_cache_bf16_ptr,
+ swa_cache_u8_ptr,
+ swa_indices_ptr,
+ swa_lens_ptr,
+ extra_cache_fp8_ptr,
+ extra_cache_bf16_ptr,
+ extra_cache_u8_ptr,
+ extra_indices_ptr,
+ extra_lens_ptr,
+ sink_ptr,
+ out_ptr,
+ num_tokens: tl.constexpr,
+ num_heads: tl.constexpr,
+ swa_index_topk: tl.constexpr,
+ extra_index_topk: tl.constexpr,
+ swa_num_blocks: tl.constexpr,
+ extra_num_blocks: tl.constexpr,
+ swa_block_size: tl.constexpr,
+ extra_block_size: tl.constexpr,
+ swa_stride_block_bytes: tl.constexpr,
+ extra_stride_block_bytes: tl.constexpr,
+ sm_scale_log2: tl.constexpr,
+ stride_qt: tl.constexpr,
+ stride_qh: tl.constexpr,
+ stride_qd: tl.constexpr,
+ stride_swa_indices_t: tl.constexpr,
+ stride_swa_indices_k: tl.constexpr,
+ stride_extra_indices_t: tl.constexpr,
+ stride_extra_indices_k: tl.constexpr,
+ stride_out_t: tl.constexpr,
+ stride_out_h: tl.constexpr,
+ stride_out_d: tl.constexpr,
+ BLOCK_H: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+ FP8_DIM: tl.constexpr,
+ SCALE_GROUP: tl.constexpr,
+ SCALE_BYTES: tl.constexpr,
+ TOKEN_BYTES: tl.constexpr,
+ HAS_EXTRA: tl.constexpr,
+ HAS_SINK: tl.constexpr,
+ LOG2E_CONST: tl.constexpr,
+):
+ token_id = tl.program_id(0)
+ head_block = tl.program_id(1)
+ heads = head_block * BLOCK_H + tl.arange(0, BLOCK_H)
+ offs_d = tl.arange(0, BLOCK_D)
+ mask_h = heads < num_heads
+
+ q = tl.load(
+ q_ptr
+ + token_id * stride_qt
+ + heads[:, None] * stride_qh
+ + offs_d[None, :] * stride_qd,
+ mask=mask_h[:, None],
+ other=0.0,
+ )
+
+ if HAS_SINK:
+ sink = tl.load(sink_ptr + heads, mask=mask_h, other=-float("inf"))
+ e_max = sink * LOG2E_CONST
+ e_sum = tl.where(mask_h, 1.0, 0.0)
+ else:
+ e_max = tl.full((BLOCK_H,), -float("inf"), dtype=tl.float32)
+ e_sum = tl.zeros((BLOCK_H,), dtype=tl.float32)
+ acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32)
+
+ swa_len = tl.load(swa_lens_ptr + token_id)
+ extra_len = tl.load(extra_lens_ptr + token_id) if HAS_EXTRA else 0
+ total_len = extra_len + swa_len
+
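+ # Walk the concatenated [extra | swa] top-k lists: columns below extra_len
+ # come from the extra cache, the rest map back into the sliding-window list.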
+ for start in range(0, extra_index_topk + swa_index_topk, BLOCK_N):
+ offs_n = start + tl.arange(0, BLOCK_N)
+ use_extra = HAS_EXTRA & (offs_n < extra_len)
+ use_swa = (offs_n >= extra_len) & (offs_n < total_len)
+
+ extra_cols = offs_n
+ swa_cols = offs_n - extra_len
+ extra_idx = tl.load(
+ extra_indices_ptr
+ + token_id * stride_extra_indices_t
+ + extra_cols * stride_extra_indices_k,
+ mask=HAS_EXTRA & (extra_cols < extra_index_topk),
+ other=-1,
+ )
+ swa_idx = tl.load(
+ swa_indices_ptr
+ + token_id * stride_swa_indices_t
+ + swa_cols * stride_swa_indices_k,
+ mask=(swa_cols >= 0) & (swa_cols < swa_index_topk),
+ other=-1,
+ )
+ idx = tl.where(use_extra, extra_idx, swa_idx)
+
+ extra_block = idx // extra_block_size
+ extra_pos = idx - extra_block * extra_block_size
+ swa_block = idx // swa_block_size
+ swa_pos = idx - swa_block * swa_block_size
+ valid_extra = use_extra & (idx >= 0) & (extra_block < extra_num_blocks)
+ valid_swa = use_swa & (idx >= 0) & (swa_block < swa_num_blocks)
+ valid = valid_extra | valid_swa
+
+ extra_token_base = extra_block * extra_stride_block_bytes
+ extra_token_base += extra_pos * TOKEN_BYTES
+ swa_token_base = swa_block * swa_stride_block_bytes
+ swa_token_base += swa_pos * TOKEN_BYTES
+ token_base = tl.where(use_extra, extra_token_base, swa_token_base)
+ block_size = tl.where(use_extra, extra_block_size, swa_block_size)
+ stride_block_bytes = tl.where(
+ use_extra, extra_stride_block_bytes, swa_stride_block_bytes
+ )
+ pos = tl.where(use_extra, extra_pos, swa_pos)
+
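+ # Packed DS-MLA cache layout: each token stores FP8_DIM fp8 bytes followed
+ # by a bf16 tail; per-SCALE_GROUP ue8m0 scales (decoded as exp2(u8 - 127))
+ # sit after all token payloads of the block.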
+ is_fp8 = offs_d < FP8_DIM
+ scale_offsets = (
+ tl.where(use_extra, extra_block, swa_block)[:, None]
+ * stride_block_bytes[:, None]
+ + block_size[:, None] * TOKEN_BYTES
+ + pos[:, None] * SCALE_BYTES
+ + (offs_d[None, :] // SCALE_GROUP)
+ )
+ encoded_scale = tl.load(
+ tl.where(use_extra[:, None], extra_cache_u8_ptr, swa_cache_u8_ptr)
+ + scale_offsets,
+ mask=valid[:, None] & is_fp8[None, :],
+ other=127,
+ ).to(tl.float32)
+ fp8_scale = tl.exp2(encoded_scale - 127.0)
+
+ fp8_offsets = token_base[:, None] + offs_d[None, :]
+ fp8_vals = (
+ tl.load(
+ tl.where(use_extra[:, None], extra_cache_fp8_ptr, swa_cache_fp8_ptr)
+ + fp8_offsets,
+ mask=valid[:, None] & is_fp8[None, :],
+ other=0.0,
+ ).to(tl.float32)
+ * fp8_scale
+ )
+
+ bf16_offsets = (token_base[:, None] + FP8_DIM) // 2
+ bf16_offsets += offs_d[None, :] - FP8_DIM
+ bf16_vals = tl.load(
+ tl.where(use_extra[:, None], extra_cache_bf16_ptr, swa_cache_bf16_ptr)
+ + bf16_offsets,
+ mask=valid[:, None] & (~is_fp8[None, :]),
+ other=0.0,
+ ).to(tl.float32)
+ k = tl.where(is_fp8[None, :], fp8_vals, bf16_vals)
+
+ qk = tl.dot(q, tl.trans(k.to(q.dtype))) * sm_scale_log2
+ qk = tl.where(
+ mask_h[:, None] & valid[None, :],
+ qk,
+ -3.4028234663852886e38,
+ )
+
+ n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+ re_scale = tl.exp2(e_max - n_e_max)
+ p = tl.exp2(qk - n_e_max[:, None])
+ p = tl.where(mask_h[:, None] & valid[None, :], p, 0.0)
+ acc = acc * re_scale[:, None] + tl.dot(p.to(k.dtype), k)
+ e_sum = e_sum * re_scale + tl.sum(p, 1)
+ e_max = n_e_max
+
+ acc = acc / tl.maximum(e_sum, 1.0e-20)[:, None]
+ tl.store(
+ out_ptr
+ + token_id * stride_out_t
+ + heads[:, None] * stride_out_h
+ + offs_d[None, :] * stride_out_d,
+ acc.to(tl.bfloat16),
+ mask=mask_h[:, None],
+ )
+
+
+def decode_sparse_attention_triton(
+ q: torch.Tensor,
+ swa_cache: torch.Tensor,
+ swa_indices: torch.Tensor,
+ swa_lens: torch.Tensor,
+ scale: float,
+ attn_sink: torch.Tensor | None,
+ out: torch.Tensor,
+ extra_cache: torch.Tensor | None = None,
+ extra_indices: torch.Tensor | None = None,
+ extra_lens: torch.Tensor | None = None,
+) -> None:
+ if swa_indices.ndim == 3:
+ swa_indices = swa_indices.squeeze(1)
+ if extra_indices is not None and extra_indices.ndim == 3:
+ extra_indices = extra_indices.squeeze(1)
+
+ num_tokens, num_heads, head_dim = q.shape
+ if num_tokens == 0:
+ return
+ if head_dim != DEEPSEEK_V4_MLA_HEAD_DIM:
+ raise ValueError(
+ "DeepSeek V4 decode Triton fallback expects "
+ f"D={DEEPSEEK_V4_MLA_HEAD_DIM}, got {head_dim}"
+ )
+ has_extra = (
+ extra_cache is not None and extra_indices is not None and extra_lens is not None
+ )
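+ # Without an extra cache, pass the SWA tensors as placeholders;
+ # HAS_EXTRA=False keeps the kernel from using them.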
+ if not has_extra:
+ extra_cache = swa_cache
+ extra_indices = swa_indices[:, :1]
+ extra_lens = swa_lens
+
+ assert extra_cache is not None
+ assert extra_indices is not None
+ assert extra_lens is not None
+ grid = (num_tokens, triton.cdiv(num_heads, 8))
+ _decode_sparse_attention_fp8_kernel[grid](
+ q,
+ swa_cache.view(torch.float8_e4m3fn),
+ swa_cache.view(torch.bfloat16),
+ swa_cache,
+ swa_indices,
+ swa_lens,
+ extra_cache.view(torch.float8_e4m3fn),
+ extra_cache.view(torch.bfloat16),
+ extra_cache,
+ extra_indices,
+ extra_lens,
+ attn_sink if attn_sink is not None else q,
+ out,
+ num_tokens,
+ num_heads,
+ swa_indices.shape[-1],
+ extra_indices.shape[-1] if has_extra else 0,
+ swa_cache.shape[0],
+ extra_cache.shape[0],
+ swa_cache.shape[1],
+ extra_cache.shape[1],
+ swa_cache.stride(0),
+ extra_cache.stride(0),
+ scale * LOG2E,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ swa_indices.stride(0),
+ swa_indices.stride(1),
+ extra_indices.stride(0),
+ extra_indices.stride(1),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ BLOCK_H=8,
+ BLOCK_N=16,
+ BLOCK_D=DEEPSEEK_V4_MLA_HEAD_DIM,
+ FP8_DIM=FP8_DS_MLA_FP8_DIM,
+ SCALE_GROUP=FP8_DS_MLA_SCALE_GROUP,
+ SCALE_BYTES=FP8_DS_MLA_SCALE_BYTES,
+ TOKEN_BYTES=FP8_DS_MLA_TOKEN_BYTES,
+ HAS_EXTRA=has_extra,
+ HAS_SINK=attn_sink is not None,
+ LOG2E_CONST=LOG2E,
+ num_warps=8,
+ )
+
+
+@triton.jit
+def _deepseek_v4_fp8_einsum_triton_kernel(
+ a_ptr,
+ a_scale_ptr,
+ b_ptr,
+ b_scale_ptr,
+ out_ptr,
+ B: tl.constexpr,
+ G: tl.constexpr,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ a_stride_b: tl.constexpr,
+ a_stride_g: tl.constexpr,
+ a_stride_k: tl.constexpr,
+ as_stride_b: tl.constexpr,
+ as_stride_g: tl.constexpr,
+ as_stride_kb: tl.constexpr,
+ b_stride_g: tl.constexpr,
+ b_stride_n: tl.constexpr,
+ b_stride_k: tl.constexpr,
+ bs_stride_g: tl.constexpr,
+ bs_stride_nb: tl.constexpr,
+ bs_stride_kb: tl.constexpr,
+ out_stride_b: tl.constexpr,
+ out_stride_g: tl.constexpr,
+ out_stride_n: tl.constexpr,
+ BLOCK_B: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid_b = tl.program_id(0)
+ pid_g = tl.program_id(1)
+ pid_n = tl.program_id(2)
+
+ offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ acc = tl.zeros((BLOCK_B, BLOCK_N), dtype=tl.float32)
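+ # Block-scaled FP8 GEMM: A carries one scale per (row, 128-wide K block)
+ # and B one per (128x128 N/K block); BLOCK_K == 128, so kb and
+ # offs_n // BLOCK_K index those scale blocks directly.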
+ for k0 in range(0, K, BLOCK_K):
+ k = k0 + offs_k
+ kb = k0 // BLOCK_K
+
+ a = tl.load(
+ a_ptr
+ + offs_b[:, None] * a_stride_b
+ + pid_g * a_stride_g
+ + k[None, :] * a_stride_k,
+ mask=(offs_b[:, None] < B) & (k[None, :] < K),
+ other=0.0,
+ )
+ b = tl.load(
+ b_ptr
+ + pid_g * b_stride_g
+ + offs_n[:, None] * b_stride_n
+ + k[None, :] * b_stride_k,
+ mask=(offs_n[:, None] < N) & (k[None, :] < K),
+ other=0.0,
+ )
+ a_s = tl.load(
+ a_scale_ptr
+ + offs_b * as_stride_b
+ + pid_g * as_stride_g
+ + kb * as_stride_kb,
+ mask=offs_b < B,
+ other=0.0,
+ ).to(tl.float32)
+ b_s = tl.load(
+ b_scale_ptr
+ + pid_g * bs_stride_g
+ + (offs_n // BLOCK_K) * bs_stride_nb
+ + kb * bs_stride_kb,
+ mask=offs_n < N,
+ other=0.0,
+ ).to(tl.float32)
+ acc += (
+ tl.dot(a, tl.trans(b), out_dtype=tl.float32) * a_s[:, None] * b_s[None, :]
+ )
+
+ tl.store(
+ out_ptr
+ + offs_b[:, None] * out_stride_b
+ + pid_g * out_stride_g
+ + offs_n[None, :] * out_stride_n,
+ acc,
+ mask=(offs_b[:, None] < B) & (offs_n[None, :] < N),
+ )
+
+
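+# E8M0 scales are a bare exponent byte; shifting it into the float32
+# exponent field reconstructs 2**(e - 127).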
+def _e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor:
+ return (scale.view(torch.uint8).to(torch.int32) << 23).view(torch.float32)
+
+
+def _unpack_int32_e8m0_scales(
+ packed_scale: torch.Tensor,
+ num_blocks: int,
+) -> torch.Tensor:
+ shifts = torch.arange(4, device=packed_scale.device, dtype=torch.int32) * 8
+ unpacked = (packed_scale.to(torch.int32).unsqueeze(-1) >> shifts) & 0xFF
+ unpacked = unpacked.reshape(*packed_scale.shape[:-1], -1)[..., :num_blocks]
+ return (unpacked << 23).view(torch.float32)
+
+
+def _normalize_deepseek_v4_fp8_einsum_inputs(
+ a: torch.Tensor,
+ a_scale: torch.Tensor,
+ b: torch.Tensor,
+ b_scale: torch.Tensor,
+ out: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ B, G, K = a.shape
+ _, out_g, N = out.shape
+ assert out_g == G
+ k_blocks = triton.cdiv(K, 128)
+ n_blocks = triton.cdiv(N, 128)
+
+ if b.ndim == 2:
+ b = b.view(G, N, K)
+ if b_scale.ndim == 2:
+ b_scale = b_scale.view(G, n_blocks, k_blocks)
+
+ if a_scale.dtype == torch.int32:
+ a_scale = _unpack_int32_e8m0_scales(a_scale, k_blocks)
+ if b_scale.dtype == torch.int32:
+ b_scale = _unpack_int32_e8m0_scales(b_scale, k_blocks)
+
+ if a_scale.dtype == torch.float8_e8m0fnu:
+ a_scale = _e8m0_to_fp32(a_scale)
+ if b_scale.dtype == torch.float8_e8m0fnu:
+ b_scale = _e8m0_to_fp32(b_scale)
+
+ return a, a_scale.contiguous(), b, b_scale.contiguous()
+
+
+def deepseek_v4_fp8_einsum_triton(
+ a: torch.Tensor,
+ a_scale: torch.Tensor,
+ b: torch.Tensor,
+ b_scale: torch.Tensor,
+ out: torch.Tensor,
+) -> None:
+ a, a_scale, b, b_scale = _normalize_deepseek_v4_fp8_einsum_inputs(
+ a, a_scale, b, b_scale, out
+ )
+ B, G, K = a.shape
+ N = out.shape[-1]
+ grid = (triton.cdiv(B, 16), G, triton.cdiv(N, 32))
+ _deepseek_v4_fp8_einsum_triton_kernel[grid](
+ a,
+ a_scale,
+ b,
+ b_scale,
+ out,
+ B,
+ G,
+ N,
+ K,
+ a.stride(0),
+ a.stride(1),
+ a.stride(2),
+ a_scale.stride(0),
+ a_scale.stride(1),
+ a_scale.stride(2),
+ b.stride(0),
+ b.stride(1),
+ b.stride(2),
+ b_scale.stride(0),
+ b_scale.stride(1),
+ b_scale.stride(2),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ BLOCK_B=16,
+ BLOCK_N=32,
+ BLOCK_K=128,
+ num_warps=4,
+ )
+
+
+@triton.jit
+def _fp8_mqa_logits_kernel(
+ q_ptr,
+ k_ptr,
+ scale_ptr,
+ weights_ptr,
+ cu_seqlen_ks_ptr,
+ cu_seqlen_ke_ptr,
+ logits_ptr,
+ num_q: tl.constexpr,
+ seq_len_kv: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ stride_qm: tl.constexpr,
+ stride_qh: tl.constexpr,
+ stride_qd: tl.constexpr,
+ stride_kn: tl.constexpr,
+ stride_kd: tl.constexpr,
+ stride_wm: tl.constexpr,
+ stride_wh: tl.constexpr,
+ stride_lm: tl.constexpr,
+ stride_ln: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_d = tl.arange(0, BLOCK_D)
+
+ valid_m = offs_m < num_q
+ valid_n = offs_n < seq_len_kv
+ seq_start = tl.load(cu_seqlen_ks_ptr + offs_m, mask=valid_m, other=0)
+ seq_end = tl.load(cu_seqlen_ke_ptr + offs_m, mask=valid_m, other=0)
+ seq_mask = (offs_n[None, :] >= seq_start[:, None]) & (
+ offs_n[None, :] < seq_end[:, None]
+ )
+
+ logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
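+ # For each head, accumulate ReLU(q @ k^T * kv_scale) weighted by the
+ # per-(query, head) weight; the head loop reduces into one logits tile.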
+ for h in tl.range(0, num_heads):
+ scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for d0 in tl.range(0, head_dim, BLOCK_D):
+ d = d0 + offs_d
+ q = tl.load(
+ q_ptr
+ + offs_m[:, None] * stride_qm
+ + h * stride_qh
+ + d[None, :] * stride_qd,
+ mask=valid_m[:, None] & (d[None, :] < head_dim),
+ other=0.0,
+ ).to(tl.float32)
+ k = tl.load(
+ k_ptr + offs_n[:, None] * stride_kn + d[None, :] * stride_kd,
+ mask=valid_n[:, None] & (d[None, :] < head_dim),
+ other=0.0,
+ ).to(tl.float32)
+ scores += tl.dot(q, tl.trans(k), input_precision="tf32")
+ scale = tl.load(scale_ptr + offs_n, mask=valid_n, other=0.0)
+ weighted = tl.maximum(scores * scale[None, :], 0.0)
+ weight = tl.load(
+ weights_ptr + offs_m * stride_wm + h * stride_wh,
+ mask=valid_m,
+ other=0.0,
+ )
+ logits += weighted * weight[:, None]
+
+ store_mask = valid_m[:, None] & valid_n[None, :]
+ logits = tl.where(seq_mask & store_mask, logits, float("-inf"))
+ tl.store(
+ logits_ptr + offs_m[:, None] * stride_lm + offs_n[None, :] * stride_ln,
+ logits,
+ mask=store_mask,
+ )
+
+
+def fp8_mqa_logits_triton(
+ q: torch.Tensor,
+ kv: tuple[torch.Tensor, torch.Tensor],
+ weights: torch.Tensor,
+ cu_seqlen_ks: torch.Tensor,
+ cu_seqlen_ke: torch.Tensor,
+) -> torch.Tensor:
+ k_fp8, scale = kv
+ num_q, num_heads, head_dim = q.shape
+ seq_len_kv = k_fp8.shape[0]
+ logits = torch.empty(
+ (num_q, seq_len_kv),
+ device=q.device,
+ dtype=torch.float32,
+ )
+ if num_q == 0 or seq_len_kv == 0:
+ return logits
+
+ grid = (triton.cdiv(num_q, 8), triton.cdiv(seq_len_kv, 64))
+ _fp8_mqa_logits_kernel[grid](
+ q,
+ k_fp8,
+ scale,
+ weights,
+ cu_seqlen_ks,
+ cu_seqlen_ke,
+ logits,
+ num_q,
+ seq_len_kv,
+ num_heads,
+ head_dim,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ k_fp8.stride(0),
+ k_fp8.stride(1),
+ weights.stride(0),
+ weights.stride(1),
+ logits.stride(0),
+ logits.stride(1),
+ BLOCK_M=8,
+ BLOCK_N=64,
+ BLOCK_D=64,
+ num_warps=4,
+ )
+ return logits
+
+
+@triton.jit
+def _fp8_paged_mqa_logits_kernel(
+ q_ptr,
+ kv_ptr,
+ scale_ptr,
+ weights_ptr,
+ context_lens_ptr,
+ block_tables_ptr,
+ logits_ptr,
+ token_start,
+ num_rows: tl.constexpr,
+ logits_width: tl.constexpr,
+ next_n: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ block_size: tl.constexpr,
+ stride_qb: tl.constexpr,
+ stride_qn: tl.constexpr,
+ stride_qh: tl.constexpr,
+ stride_qd: tl.constexpr,
+ stride_kvb: tl.constexpr,
+ stride_kvs: tl.constexpr,
+ stride_kvd: tl.constexpr,
+ stride_sb: tl.constexpr,
+ stride_ss: tl.constexpr,
+ stride_wm: tl.constexpr,
+ stride_wh: tl.constexpr,
+ stride_clb: tl.constexpr,
+ stride_cln: tl.constexpr,
+ stride_btb: tl.constexpr,
+ stride_btk: tl.constexpr,
+ stride_lm: tl.constexpr,
+ stride_ln: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_n = token_start + offs_local_n
+ offs_d = tl.arange(0, BLOCK_D)
+
+ valid_m = offs_m < num_rows
+ valid_n = offs_local_n < logits_width
+ batch = offs_m // next_n
+ q_pos = offs_m - batch * next_n
+ context_len = tl.load(
+ context_lens_ptr + batch * stride_clb + q_pos * stride_cln,
+ mask=valid_m,
+ other=0,
+ )
+ context_mask = valid_n[None, :] & (offs_n[None, :] < context_len[:, None])
+
+ block_rank = offs_n // block_size
+ block_offset = offs_n - block_rank * block_size
+ block_idx = tl.load(
+ block_tables_ptr
+ + batch[:, None] * stride_btb
+ + block_rank[None, :] * stride_btk,
+ mask=valid_m[:, None] & valid_n[None, :],
+ other=0,
+ )
+
+ logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ scale = tl.load(
+ scale_ptr + block_idx * stride_sb + block_offset[None, :] * stride_ss,
+ mask=context_mask,
+ other=0.0,
+ )
+ for h in tl.range(0, num_heads):
+ scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for d0 in tl.range(0, head_dim, BLOCK_D):
+ d = d0 + offs_d
+ q = tl.load(
+ q_ptr
+ + batch[:, None] * stride_qb
+ + q_pos[:, None] * stride_qn
+ + h * stride_qh
+ + d[None, :] * stride_qd,
+ mask=valid_m[:, None] & (d[None, :] < head_dim),
+ other=0.0,
+ ).to(tl.float32)
+ k = tl.load(
+ kv_ptr
+ + block_idx[:, :, None] * stride_kvb
+ + block_offset[None, :, None] * stride_kvs
+ + d[None, None, :] * stride_kvd,
+ mask=context_mask[:, :, None] & (d[None, None, :] < head_dim),
+ other=0.0,
+ ).to(tl.float32)
+ scores += tl.sum(q[:, None, :] * k, axis=2)
+ weighted = tl.maximum(scores * scale, 0.0)
+ weight = tl.load(
+ weights_ptr + offs_m * stride_wm + h * stride_wh,
+ mask=valid_m,
+ other=0.0,
+ )
+ logits += weighted * weight[:, None]
+
+ store_mask = valid_m[:, None] & valid_n[None, :]
+ logits = tl.where(context_mask & store_mask, logits, float("-inf"))
+ tl.store(
+ logits_ptr + offs_m[:, None] * stride_lm + offs_local_n[None, :] * stride_ln,
+ logits,
+ mask=store_mask,
+ )
+
+
+def fp8_paged_mqa_logits_triton(
+ q: torch.Tensor,
+ kv_cache: torch.Tensor,
+ weights: torch.Tensor,
+ context_lens: torch.Tensor,
+ block_tables: torch.Tensor,
+ max_model_len: int,
+ token_start: int = 0,
+ token_count: int | None = None,
+) -> torch.Tensor:
+ batch_size, next_n, num_heads, head_dim = q.size()
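+ # Decode-only fast path: one query row per request and dot-friendly shapes,
+ # so use the rowwise kernel that reduces heads with tl.dot.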
+ if next_n == 1 and head_dim % 64 == 0 and num_heads % 4 == 0:
+ return fp8_paged_mqa_logits_rowwise_triton(
+ q,
+ kv_cache,
+ weights,
+ context_lens,
+ block_tables,
+ max_model_len,
+ token_start=token_start,
+ token_count=token_count,
+ )
+
+ kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim)
+ _, block_size, _, _ = kv_values.size()
+ num_rows = batch_size * next_n
+ if token_count is None:
+ token_count = max_model_len - token_start
+ assert token_start >= 0
+ assert token_count >= 0
+ assert token_start + token_count <= max_model_len
+ logits = torch.empty(
+ (num_rows, token_count),
+ device=q.device,
+ dtype=torch.float32,
+ )
+ if num_rows == 0 or token_count == 0:
+ return logits
+
+ context_lens_2d = context_lens.reshape(batch_size, -1)
+ if context_lens_2d.shape[1] == 1 and next_n != 1:
+ context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous()
+ grid = (triton.cdiv(num_rows, 4), triton.cdiv(token_count, 64))
+ _fp8_paged_mqa_logits_kernel[grid](
+ q,
+ kv_values,
+ kv_scale,
+ weights,
+ context_lens_2d,
+ block_tables,
+ logits,
+ token_start,
+ num_rows,
+ token_count,
+ next_n,
+ num_heads,
+ head_dim,
+ block_size,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ q.stride(3),
+ kv_values.stride(0),
+ kv_values.stride(1),
+ kv_values.stride(3),
+ kv_scale.stride(0),
+ kv_scale.stride(1),
+ weights.stride(0),
+ weights.stride(1),
+ context_lens_2d.stride(0),
+ context_lens_2d.stride(1),
+ block_tables.stride(0),
+ block_tables.stride(1),
+ logits.stride(0),
+ logits.stride(1),
+ BLOCK_M=4,
+ BLOCK_N=64,
+ BLOCK_D=64,
+ num_warps=4,
+ )
+ return logits
+
+
+@triton.jit
+def _fp8_paged_mqa_logits_rowwise_kernel(
+ q_ptr,
+ kv_ptr,
+ scale_ptr,
+ weights_ptr,
+ context_lens_ptr,
+ block_tables_ptr,
+ logits_ptr,
+ token_start,
+ num_rows: tl.constexpr,
+ logits_width: tl.constexpr,
+ next_n: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ block_size: tl.constexpr,
+ stride_qb: tl.constexpr,
+ stride_qn: tl.constexpr,
+ stride_qh: tl.constexpr,
+ stride_qd: tl.constexpr,
+ stride_kvb: tl.constexpr,
+ stride_kvs: tl.constexpr,
+ stride_kvd: tl.constexpr,
+ stride_sb: tl.constexpr,
+ stride_ss: tl.constexpr,
+ stride_wm: tl.constexpr,
+ stride_wh: tl.constexpr,
+ stride_clb: tl.constexpr,
+ stride_cln: tl.constexpr,
+ stride_btb: tl.constexpr,
+ stride_btk: tl.constexpr,
+ stride_lm: tl.constexpr,
+ stride_ln: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+ BLOCK_H: tl.constexpr,
+):
+ row = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_n = token_start + offs_local_n
+ offs_d = tl.arange(0, BLOCK_D)
+
+ valid_row = row < num_rows
+ valid_n = offs_local_n < logits_width
+ batch = row // next_n
+ q_pos = row - batch * next_n
+ context_len = tl.load(
+ context_lens_ptr + batch * stride_clb + q_pos * stride_cln,
+ mask=valid_row,
+ other=0,
+ )
+ context_mask = valid_n & (offs_n < context_len)
+
+ block_rank = offs_n // block_size
+ block_offset = offs_n - block_rank * block_size
+ block_idx = tl.load(
+ block_tables_ptr + batch * stride_btb + block_rank * stride_btk,
+ mask=valid_row & valid_n,
+ other=0,
+ )
+
+ scale = tl.load(
+ scale_ptr + block_idx * stride_sb + block_offset * stride_ss,
+ mask=context_mask,
+ other=0.0,
+ )
+ logits = tl.zeros((BLOCK_N,), dtype=tl.float32)
+
+ for h0 in tl.range(0, num_heads, BLOCK_H):
+ heads = h0 + tl.arange(0, BLOCK_H)
+ valid_h = heads < num_heads
+ scores = tl.zeros((BLOCK_H, BLOCK_N), dtype=tl.float32)
+ for d0 in tl.range(0, head_dim, BLOCK_D):
+ d = d0 + offs_d
+ q = tl.load(
+ q_ptr
+ + batch * stride_qb
+ + q_pos * stride_qn
+ + heads[:, None] * stride_qh
+ + d[None, :] * stride_qd,
+ mask=valid_row & valid_h[:, None] & (d[None, :] < head_dim),
+ other=0.0,
+ ).to(tl.float32)
+ k = tl.load(
+ kv_ptr
+ + block_idx[None, :] * stride_kvb
+ + block_offset[None, :] * stride_kvs
+ + d[:, None] * stride_kvd,
+ mask=context_mask[None, :] & (d[:, None] < head_dim),
+ other=0.0,
+ ).to(tl.float32)
+ scores += tl.dot(q, k, input_precision="tf32")
+
+ weighted = tl.maximum(scores * scale[None, :], 0.0)
+ weight = tl.load(
+ weights_ptr + row * stride_wm + heads * stride_wh,
+ mask=valid_row & valid_h,
+ other=0.0,
+ )
+ logits += tl.sum(weighted * weight[:, None], axis=0)
+
+ logits = tl.where(context_mask & valid_row, logits, float("-inf"))
+ tl.store(
+ logits_ptr + row * stride_lm + offs_local_n * stride_ln,
+ logits,
+ mask=valid_row & valid_n,
+ )
+
+
+def fp8_paged_mqa_logits_rowwise_triton(
+ q: torch.Tensor,
+ kv_cache: torch.Tensor,
+ weights: torch.Tensor,
+ context_lens: torch.Tensor,
+ block_tables: torch.Tensor,
+ max_model_len: int,
+ token_start: int = 0,
+ token_count: int | None = None,
+) -> torch.Tensor:
+ batch_size, next_n, num_heads, head_dim = q.size()
+ kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim)
+ _, block_size, _, _ = kv_values.size()
+ num_rows = batch_size * next_n
+ if token_count is None:
+ token_count = max_model_len - token_start
+ assert token_start >= 0
+ assert token_count >= 0
+ assert token_start + token_count <= max_model_len
+ logits = torch.empty(
+ (num_rows, token_count),
+ device=q.device,
+ dtype=torch.float32,
+ )
+ if num_rows == 0 or token_count == 0:
+ return logits
+
+ context_lens_2d = context_lens.reshape(batch_size, -1)
+ if context_lens_2d.shape[1] == 1 and next_n != 1:
+ context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous()
+ block_n = 128
+ grid = (num_rows, triton.cdiv(token_count, block_n))
+ _fp8_paged_mqa_logits_rowwise_kernel[grid](
+ q,
+ kv_values,
+ kv_scale,
+ weights,
+ context_lens_2d,
+ block_tables,
+ logits,
+ token_start,
+ num_rows,
+ token_count,
+ next_n,
+ num_heads,
+ head_dim,
+ block_size,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ q.stride(3),
+ kv_values.stride(0),
+ kv_values.stride(1),
+ kv_values.stride(3),
+ kv_scale.stride(0),
+ kv_scale.stride(1),
+ weights.stride(0),
+ weights.stride(1),
+ context_lens_2d.stride(0),
+ context_lens_2d.stride(1),
+ block_tables.stride(0),
+ block_tables.stride(1),
+ logits.stride(0),
+ logits.stride(1),
+ BLOCK_N=block_n,
+ BLOCK_D=64,
+ BLOCK_H=8,
+ num_warps=4,
+ )
+ return logits
+
+
+@triton.jit
+def _tf32_hc_prenorm_gemm_kernel(
+ x_ptr,
+ fn_ptr,
+ out_ptr,
+ sqrsum_ptr,
+ M: tl.constexpr,
+ K: tl.constexpr,
+ N: tl.constexpr,
+ stride_xm: tl.constexpr,
+ stride_xk: tl.constexpr,
+ stride_fnn: tl.constexpr,
+ stride_fnk: tl.constexpr,
+ stride_outs: tl.constexpr,
+ stride_outm: tl.constexpr,
+ stride_outn: tl.constexpr,
+ stride_sqs: tl.constexpr,
+ stride_sqm: tl.constexpr,
+ NUM_SPLIT: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+):
+ pid_m = tl.program_id(0)
+ pid_n = tl.program_id(1)
+ pid_s = tl.program_id(2)
+
+ offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ offs_k = tl.arange(0, BLOCK_K)
+
+ split_k = tl.cdiv(K, NUM_SPLIT)
+ split_begin = pid_s * split_k
+ split_end = tl.minimum(split_begin + split_k, K)
+
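+ # Split-K: this program covers K slice [split_begin, split_end), writing a
+ # partial GEMM tile to out[pid_s] and the partial row-wise sum(x * x) for
+ # the prenorm to sqrsum[pid_s], one partial per split along the leading dim.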
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ sq = tl.zeros((BLOCK_M,), dtype=tl.float32)
+
+ for k0 in tl.range(0, split_k, BLOCK_K):
+ k = split_begin + k0 + offs_k
+ k_mask = k < split_end
+ x = tl.load(
+ x_ptr + offs_m[:, None] * stride_xm + k[None, :] * stride_xk,
+ mask=(offs_m[:, None] < M) & k_mask[None, :],
+ other=0.0,
+ ).to(tl.float32)
+ fn = tl.load(
+ fn_ptr + offs_n[None, :] * stride_fnn + k[:, None] * stride_fnk,
+ mask=(offs_n[None, :] < N) & k_mask[:, None],
+ other=0.0,
+ ).to(tl.float32)
+
+ acc += tl.dot(x, fn, input_precision="tf32", out_dtype=tl.float32)
+ sq += tl.sum(x * x, axis=1)
+
+ tl.store(
+ out_ptr
+ + pid_s * stride_outs
+ + offs_m[:, None] * stride_outm
+ + offs_n[None, :] * stride_outn,
+ acc,
+ mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
+ )
+
+ if pid_n == 0:
+ tl.store(
+ sqrsum_ptr + pid_s * stride_sqs + offs_m * stride_sqm,
+ sq,
+ mask=offs_m < M,
+ )
+
+
+def tf32_hc_prenorm_gemm_triton(
+ x: torch.Tensor,
+ fn: torch.Tensor,
+ out: torch.Tensor,
+ sqrsum: torch.Tensor,
+ num_split: int,
+) -> None:
+ assert x.dim() == 2
+ assert fn.dim() == 2
+ assert out.dim() == 3
+ assert sqrsum.dim() == 2
+
+ m, k = x.shape
+ n = fn.shape[0]
+ assert fn.shape[1] == k
+ assert out.shape == (num_split, m, n)
+ assert sqrsum.shape == (num_split, m)
+
+ if m == 0:
+ return
+
+ block_m = 16
+ block_n = triton.next_power_of_2(n)
+ block_n = min(max(block_n, 16), 32)
+ block_k = 64
+ grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n), num_split)
+ _tf32_hc_prenorm_gemm_kernel[grid](
+ x,
+ fn,
+ out,
+ sqrsum,
+ m,
+ k,
+ n,
+ x.stride(0),
+ x.stride(1),
+ fn.stride(0),
+ fn.stride(1),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ sqrsum.stride(0),
+ sqrsum.stride(1),
+ num_split,
+ BLOCK_M=block_m,
+ BLOCK_N=block_n,
+ BLOCK_K=block_k,
+ num_warps=4,
+ )
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index 3487ac1766e6..416d871e24f4 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -769,6 +769,7 @@ def apply(
sort_indices2=self.w2_g_idx_sort_indices,
is_k_full=self.is_k_full,
input_dtype=self.input_dtype,
+ clamp_limit=self.gemm1_clamp_limit,
)
return
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 456f40bbf7a3..79edfa6f2d92 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -786,6 +786,19 @@ def update_expert_map(self):
dp_size=get_dp_group().world_size,
)
+ @staticmethod
+ def _normalize_loaded_weight_for_copy(
+ expert_data: torch.Tensor, loaded_weight: torch.Tensor
+ ) -> torch.Tensor:
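+ # When the destination buffer is uint8 but the checkpoint ships
+ # float8_e8m0fnu scales, reinterpret the bits as uint8 so copy_ stays a
+ # byte copy rather than a dtype cast.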
+ e8m0_dtype = getattr(torch, "float8_e8m0fnu", None)
+ if (
+ e8m0_dtype is not None
+ and expert_data.dtype == torch.uint8
+ and loaded_weight.dtype == e8m0_dtype
+ ):
+ return loaded_weight.view(torch.uint8)
+ return loaded_weight
+
def _load_per_tensor_weight_scale(
self,
shard_id: str,
@@ -799,10 +812,12 @@ def _load_per_tensor_weight_scale(
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
- param_data[expert_id][idx] = loaded_weight
+ target = param_data[expert_id][idx]
+ target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight))
# If we are in the row parallel case (down_proj)
elif shard_id == "w2":
- param_data[expert_id] = loaded_weight
+ target = param_data[expert_id]
+ target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight))
def _load_combined_w13_weight_scale(
self,
@@ -819,7 +834,7 @@ def _load_combined_w13_weight_scale(
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
)
- param.copy_(loaded_weight)
+ param.copy_(self._normalize_loaded_weight_for_copy(param, loaded_weight))
def _load_model_weight_or_group_weight_scale(
self,
@@ -986,7 +1001,9 @@ def _load_w13(
hidden_dim=hidden_dim,
shard_dim=shard_dim,
)
- expert_data.copy_(loaded_weight)
+ expert_data.copy_(
+ self._normalize_loaded_weight_for_copy(expert_data, loaded_weight)
+ )
def _load_w2(
self,
@@ -1022,7 +1039,9 @@ def _load_w2(
hidden_dim=hidden_dim,
shard_dim=shard_dim,
)
- expert_data.copy_(loaded_weight)
+ expert_data.copy_(
+ self._normalize_loaded_weight_for_copy(expert_data, loaded_weight)
+ )
def _load_single_value(
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index d9aab35c25f4..e2df12488652 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -817,6 +817,35 @@ def get_w8a8_block_fp8_configs(
return None
+def _get_default_w8a8_block_fp8_config(
+ M: int,
+ block_n: int,
+ block_k: int,
+) -> dict[str, Any]:
+ # Block-wise quant: BLOCK_SIZE_N must be divisible by block_n and
+ # BLOCK_SIZE_K must be divisible by block_k.
+ # M-aware tuning for low-M decode: BLOCK_SIZE_M=64 wastes most of the M
+ # tile for single-request decode and short MTP-style draft batches. On
+ # SM12x the low-M tile remains beneficial up to M=32 on DeepSeek V4 shapes.
+ capability = current_platform.get_device_capability()
+ capability_major = getattr(capability, "major", None)
+ if capability_major is None and capability is not None:
+ capability_major = capability[0]
+ low_m_limit = 32 if capability_major == 12 else 8
+ if low_m_limit >= M:
+ block_m, num_stages = 16, (2 if current_platform.is_rocm() else 3)
+ else:
+ block_m, num_stages = 64, 2
+ return {
+ "BLOCK_SIZE_M": block_m,
+ "BLOCK_SIZE_N": block_n,
+ "BLOCK_SIZE_K": block_k,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": num_stages,
+ }
+
+
def w8a8_triton_block_scaled_mm(
A: torch.Tensor,
B: torch.Tensor,
@@ -861,6 +890,12 @@ def w8a8_triton_block_scaled_mm(
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
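+ # e8m0 scales are exponent-only; the Triton kernel consumes fp32 scales,
+ # so upcast them here when present.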
+ e8m0_dtype = getattr(torch, "float8_e8m0fnu", None)
+ if e8m0_dtype is not None:
+ if As.dtype == e8m0_dtype:
+ As = _upcast_e8m0_to_fp32(As)
+ if Bs.dtype == e8m0_dtype:
+ Bs = _upcast_e8m0_to_fp32(Bs)
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
@@ -870,17 +905,7 @@ def w8a8_triton_block_scaled_mm(
# Get the optimal config if there is one
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
- # Default config
- # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
- # BLOCK_SIZE_K must be divisible by block_size[1]
- config = {
- "BLOCK_SIZE_M": 64,
- "BLOCK_SIZE_N": block_size[0],
- "BLOCK_SIZE_K": block_size[1],
- "GROUP_SIZE_M": 32,
- "num_warps": 4,
- "num_stages": 2,
- }
+ config = _get_default_w8a8_block_fp8_config(M, block_size[0], block_size[1])
def grid(META):
return (
@@ -1215,6 +1240,8 @@ def create_fp8_scale_parameter(
if dtype == torch.float32:
scale[:] = torch.finfo(torch.float32).min
+ elif dtype == getattr(torch, "float8_e8m0fnu", None):
+ scale[:] = 0
set_weight_attrs(scale, {"scale_type": "weight_scale"})
return scale
diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py
index 4bf52a49c43f..f974c873fdc1 100644
--- a/vllm/model_executor/layers/sparse_attn_indexer.py
+++ b/vllm/model_executor/layers/sparse_attn_indexer.py
@@ -4,7 +4,6 @@
import torch
-import vllm.envs as envs
from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
@@ -12,7 +11,9 @@
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
fp8_fp4_mqa_logits,
+ fp8_fp4_mqa_topk_indices,
fp8_fp4_paged_mqa_logits,
+ fp8_fp4_paged_mqa_topk_indices,
has_deep_gemm,
)
from vllm.utils.torch_utils import (
@@ -23,6 +24,7 @@
)
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadata,
+ sparse_indexer_max_logits_bytes,
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager
@@ -35,11 +37,60 @@
logger = init_logger(__name__)
RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024
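+# Row widths gating the SM120 short-row decode top-k (topk_tokens == 512
+# only): always taken up to ALWAYS_WIDTH, still preferred for rows narrower
+# than MAX_WIDTH. See _should_use_sm120_short_row_topk_decode.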
+SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH = 4096
+SM120_SHORT_ROW_TOPK_MAX_WIDTH = 12288
# MXFP4 layout: 2 values packed per byte, ue8m0 (1-byte) scale per block of 32.
MXFP4_BLOCK_SIZE = 32
+def _should_use_sm120_short_row_topk_decode(
+ topk_tokens: int,
+ logits_width: int,
+ num_rows: int,
+ is_cuda_sm120: bool,
+) -> bool:
+ if not is_cuda_sm120 or topk_tokens != 512:
+ return False
+ if logits_width <= SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH:
+ return True
+ return logits_width < SM120_SHORT_ROW_TOPK_MAX_WIDTH
+
+
+def _use_sm120_short_row_topk_decode(
+ logits: torch.Tensor,
+ topk_tokens: int,
+) -> bool:
+ return _should_use_sm120_short_row_topk_decode(
+ topk_tokens,
+ logits.shape[1],
+ logits.shape[0],
+ current_platform.is_cuda()
+ and current_platform.is_device_capability_family(120),
+ )
+
+
+def _decode_logits_width(max_model_len: int, max_seq_len: int) -> int:
+ if max_model_len <= 0:
+ return 0
+ if max_seq_len <= 0:
+ return max_model_len
+ return min(max_model_len, max_seq_len)
+
+
+def _decode_topk_logits_width(
+ max_model_len: int, max_seq_len: int, topk_tokens: int
+) -> int:
+ logits_width = _decode_logits_width(max_model_len, max_seq_len)
+ return min(max_model_len, max(logits_width, topk_tokens))
+
+
+def _sparse_indexer_requires_deep_gemm() -> bool:
+ return current_platform.is_cuda() and not (
+ current_platform.is_device_capability_family(120)
+ )
+
+
def _gather_workspace_shapes(
total_seq_lens: int,
head_dim: int,
@@ -118,7 +169,7 @@ def sparse_attn_indexer(
# Dummy allocation to simulate for peak logits tensor memory during inference.
# FP8 elements so elements == bytes
- max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
+ max_logits_elems = sparse_indexer_max_logits_bytes()
_ = torch.empty(
max_logits_elems, dtype=torch.uint8, device=hidden_states.device
)
@@ -220,6 +271,19 @@ def sparse_attn_indexer(
q_slice_cast = q_slice
k_quant_cast = k_quant
k_scale_cast = k_scale.view(torch.float32).squeeze(-1)
+ topk_indices = topk_indices_buffer[
+ chunk.token_start : chunk.token_end, :topk_tokens
+ ]
+ if fp8_fp4_mqa_topk_indices(
+ (q_slice_cast, q_scale_slice),
+ (k_quant_cast, k_scale_cast),
+ weights[chunk.token_start : chunk.token_end],
+ chunk.cu_seqlen_ks,
+ chunk.cu_seqlen_ke,
+ topk_indices,
+ ):
+ continue
+
logits = fp8_fp4_mqa_logits(
(q_slice_cast, q_scale_slice),
(k_quant_cast, k_scale_cast),
@@ -230,10 +294,6 @@ def sparse_attn_indexer(
)
num_rows = logits.shape[0]
- topk_indices = topk_indices_buffer[
- chunk.token_start : chunk.token_end, :topk_tokens
- ]
-
if current_platform.is_xpu():
xpu_ops.top_k_per_row_prefill( # type: ignore[attr-defined]
logits,
@@ -307,35 +367,38 @@ def sparse_attn_indexer(
if use_fp4_cache
else padded_q_quant_decode_tokens
)
- logits = fp8_fp4_paged_mqa_logits(
- (padded_q_quant_cast, padded_q_scale),
- kv_cache,
- weights[:num_padded_tokens],
- seq_lens,
- decode_metadata.block_table,
- decode_metadata.schedule_metadata,
- max_model_len=max_model_len,
- clean_logits=False,
- )
- num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
-
- if current_platform.is_cuda() and topk_tokens in (512, 1024, 2048):
- workspace_manager = current_workspace_manager()
- (topk_workspace,) = workspace_manager.get_simultaneous(
- ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
- )
- torch.ops._C.persistent_topk(
- logits,
+ logits_width = _decode_topk_logits_width(
+ max_model_len, attn_metadata_narrowed.max_seq_len, topk_tokens
+ )
+ logits_bytes = num_padded_tokens * logits_width * torch.float32.itemsize
+ used_direct_topk = False
+ if logits_bytes > sparse_indexer_max_logits_bytes():
+ used_direct_topk = fp8_fp4_paged_mqa_topk_indices(
+ (padded_q_quant_cast, padded_q_scale),
+ kv_cache,
+ weights[:num_padded_tokens],
seq_lens,
+ decode_metadata.block_table,
+ logits_width,
topk_indices,
- topk_workspace,
- topk_tokens,
- attn_metadata_narrowed.max_seq_len,
)
- else:
- if current_platform.is_xpu():
- xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined]
+
+ if not used_direct_topk:
+ logits = fp8_fp4_paged_mqa_logits(
+ (padded_q_quant_cast, padded_q_scale),
+ kv_cache,
+ weights[:num_padded_tokens],
+ seq_lens,
+ decode_metadata.block_table,
+ decode_metadata.schedule_metadata,
+ max_model_len=logits_width,
+ clean_logits=False,
+ )
+ num_rows = logits.shape[0]
+
+ if _use_sm120_short_row_topk_decode(logits, topk_tokens):
+ torch.ops._C.top_k_per_row_decode(
logits,
next_n,
seq_lens,
@@ -345,17 +408,42 @@ def sparse_attn_indexer(
logits.stride(1),
topk_tokens,
)
- else:
- torch.ops._C.top_k_per_row_decode(
+ elif current_platform.is_cuda() and topk_tokens in (512, 2048):
+ workspace_manager = current_workspace_manager()
+ (topk_workspace,) = workspace_manager.get_simultaneous(
+ ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
+ )
+ torch.ops._C.persistent_topk(
logits,
- next_n,
seq_lens,
topk_indices,
- num_rows,
- logits.stride(0),
- logits.stride(1),
+ topk_workspace,
topk_tokens,
+ logits_width,
)
+ else:
+ if current_platform.is_xpu():
+ xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined]
+ logits,
+ next_n,
+ seq_lens,
+ topk_indices,
+ num_rows,
+ logits.stride(0),
+ logits.stride(1),
+ topk_tokens,
+ )
+ else:
+ torch.ops._C.top_k_per_row_decode(
+ logits,
+ next_n,
+ seq_lens,
+ topk_indices,
+ num_rows,
+ logits.stride(0),
+ logits.stride(1),
+ topk_tokens,
+ )
if decode_metadata.requires_padding:
# if padded, we need to unpack
@@ -438,7 +526,7 @@ def __init__(
self.topk_indices_buffer = topk_indices_buffer
self.skip_k_cache_insert = skip_k_cache_insert
self.use_fp4_cache = use_fp4_cache
- if current_platform.is_cuda() and not has_deep_gemm():
+ if _sparse_indexer_requires_deep_gemm() and not has_deep_gemm():
raise RuntimeError(
"Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
)
diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py
index cef4038dc2e6..1c3adb3ac4b0 100644
--- a/vllm/model_executor/models/deepseek_v4.py
+++ b/vllm/model_executor/models/deepseek_v4.py
@@ -8,10 +8,12 @@
import torch
import torch.nn as nn
+from vllm import envs
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (
get_ep_group,
+ get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
@@ -55,10 +57,14 @@
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
+from .interfaces import SupportsPP
from .utils import (
AutoWeightsLoader,
+ PPMissingLayer,
WeightsMapper,
extract_layer_index,
+ is_pp_missing_parameter,
+ make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
@@ -703,6 +709,25 @@ def _deepseek_v4_mega_moe_experts_op_fake(
)
+def _use_deepseek_v4_mega_moe(vllm_config: VllmConfig) -> bool:
+ use_mega_moe = (
+ vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
+ )
+
+ env_name = "VLLM_DEEPSEEK_V4_USE_MEGA_MOE"
+ if envs.is_set(env_name):
+ use_mega_moe = envs.VLLM_DEEPSEEK_V4_USE_MEGA_MOE
+
+ if use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel:
+ raise NotImplementedError(
+ "DeepSeek V4 MegaMoE currently requires expert parallel. "
+ "Enable it with --enable-expert-parallel, or pick a different "
+ "moe backend."
+ )
+
+ return use_mega_moe
+
+
class DeepseekV4MoE(nn.Module):
def __init__(
self,
@@ -715,15 +740,7 @@ def __init__(
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.prefix = prefix
- self.use_mega_moe = (
- vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
- )
- if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel:
- raise NotImplementedError(
- "DeepSeek V4 MegaMoE currently requires expert parallel. "
- "Enable it with --enable-expert-parallel, or pick a different "
- "moe backend."
- )
+ self.use_mega_moe = _use_deepseek_v4_mega_moe(vllm_config)
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
self.hidden_size = config.hidden_size
@@ -1226,15 +1243,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
- self.use_mega_moe = (
- vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
- )
- if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel:
- raise NotImplementedError(
- "DeepSeek V4 MegaMoE currently requires expert parallel. "
- "Enable it with --enable-expert-parallel, or pick a different "
- "moe backend."
- )
+ self.use_mega_moe = _use_deepseek_v4_mega_moe(vllm_config)
self.vocab_size = config.vocab_size
self.hc_eps = config.hc_eps
self.hc_mult = config.hc_mult
@@ -1261,12 +1270,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
device=self.device,
)
- self.embed_tokens = VocabParallelEmbedding(
- config.vocab_size,
- config.hidden_size,
- quant_config=quant_config,
- prefix=f"{prefix}.embed_tokens",
- )
+ if get_pp_group().is_first_rank:
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.embed_tokens",
+ )
+ else:
+ self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
@@ -1279,7 +1291,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
prefix=f"{prefix}.layers",
)
- self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps)
+ if get_pp_group().is_last_rank:
+ self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps)
+ else:
+ self.norm = PPMissingLayer()
+
+ self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
+ ["hidden_states"], self.hc_dim
+ )
self.hc_head_fn = nn.Parameter(
torch.empty(
@@ -1304,26 +1323,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Pre-hc_head residual stream buffer for the MTP draft. Stable
# address (outside the cudagraph pool) so the copy_ in forward()
# refreshes it correctly across captured shapes.
- self._mtp_hidden_buffer = torch.empty(
- vllm_config.scheduler_config.max_num_batched_tokens,
- self.hc_dim,
- dtype=vllm_config.model_config.dtype,
- device=self.device,
- )
+ if get_pp_group().is_last_rank:
+ self._mtp_hidden_buffer = torch.empty(
+ vllm_config.scheduler_config.max_num_batched_tokens,
+ self.hc_dim,
+ dtype=vllm_config.model_config.dtype,
+ device=self.device,
+ )
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
- input_ids: torch.Tensor,
+ input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
- hidden_states = self.embed_input_ids(input_ids)
- hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1)
- if self.use_mega_moe:
+ if get_pp_group().is_first_rank:
+ if inputs_embeds is not None:
+ hidden_states = inputs_embeds
+ else:
+ hidden_states = self.embed_input_ids(input_ids)
+ hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1)
+ else:
+ assert intermediate_tensors is not None
+ hidden_states = intermediate_tensors["hidden_states"].view(
+ -1, self.hc_mult, self.config.hidden_size
+ )
+
+ if self.use_mega_moe and input_ids is not None:
input_ids = input_ids.to(torch.int64)
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(
@@ -1332,6 +1362,9 @@ def forward(
input_ids,
)
+ if not get_pp_group().is_last_rank:
+ return IntermediateTensors({"hidden_states": hidden_states.flatten(1)})
+
# Stash pre-hc_head residual for the MTP draft (captured copy_).
num_tokens = hidden_states.shape[0]
self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1))
@@ -1379,6 +1412,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
+ if is_pp_missing_parameter(name, self):
+ break
param = params_dict[name]
weight_loader = param.weight_loader
@@ -1396,11 +1431,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
and loaded_weight.dtype == torch.float8_e8m0fnu
):
loaded_weight = loaded_weight.view(torch.uint8)
+ skip_expert_weight = False
for mapping in expert_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name_mapped = name.replace(weight_name, param_name)
+ if is_pp_missing_parameter(name_mapped, self):
+ skip_expert_weight = True
+ break
param = params_dict[name_mapped]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
@@ -1419,15 +1458,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if success:
name = name_mapped
break
+ if skip_expert_weight:
+ continue
loaded_params.add(name_mapped)
continue
elif "attn_sink" in name:
+ if is_pp_missing_parameter(name, self):
+ continue
narrow_weight = loaded_weight[head_rank_start:head_rank_end]
n = narrow_weight.shape[0]
params_dict[name][:n].copy_(narrow_weight)
loaded_params.add(name)
continue
else:
+ if is_pp_missing_parameter(name, self):
+ continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
@@ -1525,7 +1570,7 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
)
-class DeepseekV4ForCausalLM(nn.Module):
+class DeepseekV4ForCausalLM(nn.Module, SupportsPP):
model_cls = DeepseekV4Model
# Default mapper assumes the original FP4-expert checkpoint layout.
@@ -1544,12 +1589,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
- self.lm_head = ParallelLMHead(
- config.vocab_size,
- config.hidden_size,
- prefix=maybe_prefix(prefix, "lm_head"),
- )
+ if get_pp_group().is_last_rank:
+ self.lm_head = ParallelLMHead(
+ config.vocab_size,
+ config.hidden_size,
+ prefix=maybe_prefix(prefix, "lm_head"),
+ )
+ else:
+ self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
+ self.make_empty_intermediate_tensors = (
+ self.model.make_empty_intermediate_tensors
+ )
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
@@ -1563,7 +1614,7 @@ def compute_logits(
def forward(
self,
- input_ids: torch.Tensor,
+ input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index cd51f106503a..e5962c7c0664 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -30,7 +30,7 @@
),
"deepseek_v4": (
"deepseek_v3_reasoning_parser",
- "DeepSeekV3ReasoningParser",
+ "DeepSeekV3ReasoningWithThinkingParser",
),
"poolside_v1": (
"poolside_v1_reasoning_parser",
diff --git a/vllm/tokenizers/deepseek_v4_encoding.py b/vllm/tokenizers/deepseek_v4_encoding.py
index 6895771e2f59..74f01ce017e4 100644
--- a/vllm/tokenizers/deepseek_v4_encoding.py
+++ b/vllm/tokenizers/deepseek_v4_encoding.py
@@ -155,10 +155,15 @@ def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str:
p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}{dsml_token}parameter>'
P_dsml_strs = []
- if isinstance(tool_call["arguments"], str):
- arguments = json.loads(tool_call["arguments"])
+ raw_arguments = tool_call.get("arguments")
+ if raw_arguments is None or raw_arguments == "":
+ arguments = {}
+ elif isinstance(raw_arguments, str):
+ arguments = json.loads(raw_arguments)
+ if arguments is None:
+ arguments = {}
else:
- arguments = tool_call["arguments"]
+ arguments = raw_arguments
for k, v in arguments.items():
p_dsml_str = p_dsml_template.format(
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 6b89f5c33203..7df618e58e1a 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -338,6 +338,253 @@ def transform_sf_into_required_layout(*args, **kwargs):
)
+_SM120_MQA_LOGITS_MAX_SCORE_BYTES = 64 * 1024 * 1024
+_SM120_PAGED_MQA_TOPK_CHUNK_SIZE = 8192
+
+
+def _fp8_mqa_logits_head_chunk_size(
+ seq_len: int,
+ seq_len_kv: int,
+ num_heads: int,
+) -> int:
+ # The SM120 torch path is exercised on long prefills, where materializing
+ # [head_chunk, M, N] scores can allocate multiple GiB at once. Keep the
+ # transient score tensor bounded, while still using larger head chunks
+ # for short prompts, where they are faster.
+ score_elems_per_head = max(1, seq_len * seq_len_kv)
+ max_heads = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_head * 4)
+ return max(1, min(8, num_heads, max_heads))
+
+
+def _fp8_mqa_logits_k_chunk_size(
+ seq_len: int,
+ seq_len_kv: int,
+ head_chunk_size: int,
+) -> int:
+ score_elems_per_key = max(1, seq_len * head_chunk_size)
+ max_keys = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_key * 4)
+ return max(1, min(seq_len_kv, max_keys))
+
+
+def _fp8_mqa_logits_torch(
+ q: tuple[torch.Tensor, torch.Tensor | None],
+ kv: tuple[torch.Tensor, torch.Tensor],
+ weights: torch.Tensor,
+ cu_seqlen_ks: torch.Tensor,
+ cu_seqlen_ke: torch.Tensor,
+ clean_logits: bool,
+) -> torch.Tensor:
+ q_values, q_scale = q
+ if q_scale is not None:
+ raise NotImplementedError("SM120 MQA logits torch path only supports FP8 Q")
+
+ k_values, k_scales = kv
+ k_f32 = k_values.to(torch.float32)
+ k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32))
+ k_t = k_f32.transpose(0, 1).contiguous()
+
+ seq_len, num_heads, _ = q_values.shape
+ seq_len_kv = k_f32.shape[0]
+ logits = torch.zeros(
+ (seq_len, seq_len_kv), device=q_values.device, dtype=torch.float32
+ )
+ head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads)
+
+ for head_start in range(0, num_heads, head_chunk_size):
+ head_end = min(head_start + head_chunk_size, num_heads)
+ q_chunk = q_values[:, head_start:head_end, :].to(torch.float32)
+ q_chunk = q_chunk.transpose(0, 1).contiguous()
+ head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1)
+ k_chunk_size = _fp8_mqa_logits_k_chunk_size(
+ seq_len, seq_len_kv, head_end - head_start
+ )
+ for k_start in range(0, seq_len_kv, k_chunk_size):
+ k_end = min(k_start + k_chunk_size, seq_len_kv)
+ scores = torch.matmul(q_chunk, k_t[:, k_start:k_end])
+ scores.relu_()
+ scores.mul_(head_weights)
+ logits[:, k_start:k_end].add_(
+ scores[0] if scores.shape[0] == 1 else scores.sum(dim=0)
+ )
+
+ if clean_logits:
+ offsets = torch.arange(seq_len_kv, device=q_values.device)
+ valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & (
+ offsets[None, :] < cu_seqlen_ke[:, None]
+ )
+ logits = logits.masked_fill(~valid, float("-inf"))
+
+ return logits
+
+
+def _fp8_mqa_logits_topk_torch(
+ q: tuple[torch.Tensor, torch.Tensor | None],
+ kv: tuple[torch.Tensor, torch.Tensor],
+ weights: torch.Tensor,
+ cu_seqlen_ks: torch.Tensor,
+ cu_seqlen_ke: torch.Tensor,
+ topk_tokens: int,
+ out: torch.Tensor | None = None,
+) -> torch.Tensor:
+ q_values, q_scale = q
+ if q_scale is not None:
+ raise NotImplementedError("SM120 MQA top-k torch path only supports FP8 Q")
+
+ k_values, k_scales = kv
+ k_f32 = k_values.to(torch.float32)
+ k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32))
+ k_t = k_f32.transpose(0, 1).contiguous()
+
+ seq_len, num_heads, _ = q_values.shape
+ seq_len_kv = k_f32.shape[0]
+ if out is None:
+ out = torch.empty(
+ (seq_len, topk_tokens), device=q_values.device, dtype=torch.int32
+ )
+ else:
+ assert out.shape == (seq_len, topk_tokens)
+ assert out.dtype == torch.int32
+ out.fill_(-1)
+
+ best_values = torch.full(
+ (seq_len, topk_tokens),
+ float("-inf"),
+ device=q_values.device,
+ dtype=torch.float32,
+ )
+ head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads)
+ k_chunk_size = _fp8_mqa_logits_k_chunk_size(seq_len, seq_len_kv, head_chunk_size)
+ max_chunk_topk = min(topk_tokens, k_chunk_size)
+ chunk_values_buf = torch.empty(
+ (seq_len, max_chunk_topk),
+ device=q_values.device,
+ dtype=torch.float32,
+ )
+ chunk_indices_buf = torch.empty(
+ (seq_len, max_chunk_topk),
+ device=q_values.device,
+ dtype=torch.int64,
+ )
+ chunk_indices_i32 = torch.empty(
+ (seq_len, max_chunk_topk),
+ device=q_values.device,
+ dtype=torch.int32,
+ )
+ candidate_values = torch.empty(
+ (seq_len, topk_tokens + max_chunk_topk),
+ device=q_values.device,
+ dtype=torch.float32,
+ )
+ candidate_indices = torch.empty(
+ (seq_len, topk_tokens + max_chunk_topk),
+ device=q_values.device,
+ dtype=torch.int32,
+ )
+ next_best_values = torch.empty_like(best_values)
+ selected = torch.empty(
+ (seq_len, topk_tokens),
+ device=q_values.device,
+ dtype=torch.int64,
+ )
+
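+ # Streaming top-k over key chunks: take each chunk's local top-k, merge it
+ # with the running best values/indices, and re-select the global top-k so
+ # the full [seq_len, seq_len_kv] logits are never materialized.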
+ for k_start in range(0, seq_len_kv, k_chunk_size):
+ k_end = min(k_start + k_chunk_size, seq_len_kv)
+ chunk_logits = torch.zeros(
+ (seq_len, k_end - k_start),
+ device=q_values.device,
+ dtype=torch.float32,
+ )
+ for head_start in range(0, num_heads, head_chunk_size):
+ head_end = min(head_start + head_chunk_size, num_heads)
+ q_chunk = q_values[:, head_start:head_end, :].to(torch.float32)
+ q_chunk = q_chunk.transpose(0, 1).contiguous()
+ head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1)
+ scores = torch.matmul(q_chunk, k_t[:, k_start:k_end])
+ scores.relu_()
+ scores.mul_(head_weights)
+ chunk_logits.add_(scores[0] if scores.shape[0] == 1 else scores.sum(dim=0))
+
+ offsets = torch.arange(k_start, k_end, device=q_values.device)
+ valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & (
+ offsets[None, :] < cu_seqlen_ke[:, None]
+ )
+ chunk_logits.masked_fill_(~valid, float("-inf"))
+
+ chunk_topk = min(topk_tokens, k_end - k_start)
+ chunk_values = chunk_values_buf[:, :chunk_topk]
+ chunk_indices = chunk_indices_buf[:, :chunk_topk]
+ torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices))
+ chunk_indices_out = chunk_indices_i32[:, :chunk_topk]
+ chunk_indices_out.copy_(chunk_indices)
+ chunk_indices_out.add_(k_start)
+
+ candidate_cols = topk_tokens + chunk_topk
+ candidate_values_view = candidate_values[:, :candidate_cols]
+ candidate_indices_view = candidate_indices[:, :candidate_cols]
+ candidate_values_view[:, :topk_tokens].copy_(best_values)
+ candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values)
+ candidate_indices_view[:, :topk_tokens].copy_(out)
+ candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out)
+ torch.topk(
+ candidate_values_view,
+ topk_tokens,
+ dim=1,
+ out=(next_best_values, selected),
+ )
+ torch.gather(candidate_indices_view, 1, selected, out=out)
+ best_values, next_best_values = next_best_values, best_values
+ out.masked_fill_(~torch.isfinite(best_values), -1)
+
+ return out
+
+
+def fp8_fp4_mqa_topk_indices(
+ q: tuple[torch.Tensor, torch.Tensor | None],
+ kv: tuple[torch.Tensor, torch.Tensor],
+ weights: torch.Tensor,
+ cu_seqlen_ks: torch.Tensor,
+ cu_seqlen_ke: torch.Tensor,
+ topk_indices: torch.Tensor,
+) -> bool:
+ """Write SM120 FP8 MQA top-k indices without materializing full logits."""
+ if not (
+ current_platform.is_cuda()
+ and current_platform.is_device_capability_family(120)
+ and q[1] is None
+ ):
+ return False
+ _fp8_mqa_logits_topk_torch(
+ q,
+ kv,
+ weights,
+ cu_seqlen_ks,
+ cu_seqlen_ke,
+ topk_indices.shape[1],
+ out=topk_indices,
+ )
+ return True
+
+
+def _fp8_mqa_logits_sm12x(
+ q: tuple[torch.Tensor, torch.Tensor | None],
+ kv: tuple[torch.Tensor, torch.Tensor],
+ weights: torch.Tensor,
+ cu_seqlen_ks: torch.Tensor,
+ cu_seqlen_ke: torch.Tensor,
+ clean_logits: bool,
+) -> torch.Tensor:
+ q_values, q_scale = q
+ if clean_logits and q_scale is None and q_values.dim() == 3 and kv[0].dim() == 2:
+ from vllm.model_executor.layers.deepseek_v4_triton_kernels import (
+ fp8_mqa_logits_triton,
+ )
+
+ return fp8_mqa_logits_triton(q_values, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
+ return _fp8_mqa_logits_torch(
+ q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits
+ )
+
+
def fp8_fp4_mqa_logits(
q: tuple[torch.Tensor, torch.Tensor | None],
kv: tuple[torch.Tensor, torch.Tensor],
@@ -370,6 +617,10 @@ def fp8_fp4_mqa_logits(
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
+ if current_platform.is_device_capability_family(120) and q[1] is None:
+ return _fp8_mqa_logits_sm12x(
+ q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits
+ )
_lazy_init()
if _fp8_fp4_mqa_logits_impl is None:
return _missing()
@@ -404,6 +655,215 @@ def get_paged_mqa_logits_metadata(
return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms)
+def _fp8_paged_mqa_logits_torch(
+ q: tuple[torch.Tensor, torch.Tensor | None],
+ kv_cache: torch.Tensor,
+ weights: torch.Tensor,
+ context_lens: torch.Tensor,
+ block_tables: torch.Tensor,
+ max_model_len: int,
+) -> torch.Tensor:
+ q_values, q_scale = q
+ if q_scale is not None:
+ raise NotImplementedError("SM120 paged MQA torch path only supports FP8 Q")
+
+ batch_size, next_n, num_heads, head_dim = q_values.shape
+ head_dim_with_scale = kv_cache.shape[-1]
+ assert head_dim_with_scale > head_dim
+ assert weights.shape == (batch_size * next_n, num_heads)
+ assert context_lens.shape == (batch_size, next_n)
+
+ from vllm.model_executor.layers.deepseek_v4_triton_kernels import (
+ _view_packed_fp8_paged_mqa_kv_cache,
+ )
+
+ kv_values, kv_scales = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim)
+ _, block_kv, _, _ = kv_values.shape
+ logits = torch.full(
+ (batch_size * next_n, max_model_len),
+ float("-inf"),
+ device=q_values.device,
+ dtype=torch.float32,
+ )
+
+ q_f32 = q_values.float()
+ score_bytes = _SM120_MQA_LOGITS_MAX_SCORE_BYTES
+ max_tokens_per_chunk = max(1, score_bytes // max(1, num_heads * 4))
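+    # Each KV token contributes num_heads fp32 scores, so cap the chunk length
+    # to keep the per-chunk score buffer within the byte budget above.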
+ token_offsets_cache: dict[int, torch.Tensor] = {}
+
+ for batch_idx in range(batch_size):
+ for next_idx in range(next_n):
+ row = batch_idx * next_n + next_idx
+ context_len = int(context_lens[batch_idx, next_idx].item())
+ if context_len <= 0:
+ continue
+
+ q_row = q_f32[batch_idx, next_idx]
+ row_weights = weights[row]
+ for token_start in range(0, context_len, max_tokens_per_chunk):
+ token_end = min(context_len, token_start + max_tokens_per_chunk)
+ chunk_len = token_end - token_start
+ token_offsets = token_offsets_cache.get(chunk_len)
+ if token_offsets is None or token_offsets.device != q_values.device:
+ token_offsets = torch.arange(
+ chunk_len, device=q_values.device, dtype=torch.long
+ )
+ token_offsets_cache[chunk_len] = token_offsets
+ token_ids = token_start + token_offsets
+ logical_blocks = token_ids // block_kv
+ token_in_block = token_ids - logical_blocks * block_kv
+ physical_blocks = block_tables[batch_idx, logical_blocks]
+ kv_chunk = kv_values[physical_blocks, token_in_block, 0].float()
+ scale_chunk = kv_scales[physical_blocks, token_in_block, 0].squeeze(-1)
+ kv_chunk.mul_(scale_chunk[:, None])
+ scores = torch.matmul(q_row, kv_chunk.T)
+ scores.relu_()
+ scores.mul_(row_weights[:, None])
+ logits[row, token_start:token_end] = scores.sum(dim=0)
+
+ return logits
+
+
+def _fp8_paged_mqa_logits_sm12x(
+ q: tuple[torch.Tensor, torch.Tensor | None],
+ kv_cache: torch.Tensor,
+ weights: torch.Tensor,
+ context_lens: torch.Tensor,
+ block_tables: torch.Tensor,
+ max_model_len: int,
+) -> torch.Tensor:
+ q_values, q_scale = q
+ if (
+ q_scale is None
+ and q_values.dim() == 4
+ and kv_cache.dtype == torch.uint8
+ and kv_cache.shape[-1] == q_values.shape[-1] + 4
+ ):
+ from vllm.model_executor.layers.deepseek_v4_triton_kernels import (
+ fp8_paged_mqa_logits_triton,
+ )
+
+ return fp8_paged_mqa_logits_triton(
+ q_values, kv_cache, weights, context_lens, block_tables, max_model_len
+ )
+ return _fp8_paged_mqa_logits_torch(
+ q, kv_cache, weights, context_lens, block_tables, max_model_len
+ )
+
+
+def fp8_fp4_paged_mqa_topk_indices(
+ q: tuple[torch.Tensor, torch.Tensor | None],
+ kv_cache: torch.Tensor,
+ weights: torch.Tensor,
+ context_lens: torch.Tensor,
+ block_tables: torch.Tensor,
+ max_model_len: int,
+ topk_indices: torch.Tensor,
+) -> bool:
+ """Write SM120 FP8 paged MQA top-k indices without full logits."""
+ q_values, q_scale = q
+ if not (
+ current_platform.is_cuda()
+ and current_platform.is_device_capability_family(120)
+ and q_scale is None
+ and q_values.dim() == 4
+ and kv_cache.dtype == torch.uint8
+ and kv_cache.shape[-1] == q_values.shape[-1] + 4
+ ):
+ return False
+
+ num_rows = q_values.shape[0] * q_values.shape[1]
+ topk_tokens = topk_indices.shape[1]
+ assert topk_indices.shape == (num_rows, topk_tokens)
+ assert topk_indices.dtype == torch.int32
+ topk_indices.fill_(-1)
+ if num_rows == 0 or topk_tokens == 0 or max_model_len == 0:
+ return True
+
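+    # Running best values start at -inf; the chunk/candidate buffers below are
+    # allocated once and re-sliced per chunk to avoid per-iteration allocations.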
+ best_values = torch.full(
+ (num_rows, topk_tokens),
+ float("-inf"),
+ device=q_values.device,
+ dtype=torch.float32,
+ )
+ chunk_size = max(1, _SM120_PAGED_MQA_TOPK_CHUNK_SIZE)
+ max_chunk_topk = min(topk_tokens, chunk_size)
+ chunk_values_buf = torch.empty(
+ (num_rows, max_chunk_topk),
+ device=q_values.device,
+ dtype=torch.float32,
+ )
+ chunk_indices_buf = torch.empty(
+ (num_rows, max_chunk_topk),
+ device=q_values.device,
+ dtype=torch.int64,
+ )
+ chunk_indices_i32 = torch.empty(
+ (num_rows, max_chunk_topk),
+ device=q_values.device,
+ dtype=torch.int32,
+ )
+ candidate_values = torch.empty(
+ (num_rows, topk_tokens + max_chunk_topk),
+ device=q_values.device,
+ dtype=torch.float32,
+ )
+ candidate_indices = torch.empty(
+ (num_rows, topk_tokens + max_chunk_topk),
+ device=q_values.device,
+ dtype=torch.int32,
+ )
+ next_best_values = torch.empty_like(best_values)
+ selected = torch.empty(
+ (num_rows, topk_tokens),
+ device=q_values.device,
+ dtype=torch.int64,
+ )
+
+ from vllm.model_executor.layers.deepseek_v4_triton_kernels import (
+ fp8_paged_mqa_logits_triton,
+ )
+
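+    # Scan the KV range in fixed-size chunks; token_start/token_count restrict
+    # the Triton logits kernel to the current window, so only a
+    # [num_rows, chunk_size] slice of logits is live at any time.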
+ for token_start in range(0, max_model_len, chunk_size):
+ token_count = min(chunk_size, max_model_len - token_start)
+ chunk_logits = fp8_paged_mqa_logits_triton(
+ q_values,
+ kv_cache,
+ weights,
+ context_lens,
+ block_tables,
+ max_model_len,
+ token_start=token_start,
+ token_count=token_count,
+ )
+ chunk_topk = min(topk_tokens, token_count)
+ chunk_values = chunk_values_buf[:, :chunk_topk]
+ chunk_indices = chunk_indices_buf[:, :chunk_topk]
+ torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices))
+ chunk_indices_out = chunk_indices_i32[:, :chunk_topk]
+ chunk_indices_out.copy_(chunk_indices)
+ chunk_indices_out.add_(token_start)
+
+ candidate_cols = topk_tokens + chunk_topk
+ candidate_values_view = candidate_values[:, :candidate_cols]
+ candidate_indices_view = candidate_indices[:, :candidate_cols]
+ candidate_values_view[:, :topk_tokens].copy_(best_values)
+ candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values)
+ candidate_indices_view[:, :topk_tokens].copy_(topk_indices)
+ candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out)
+ torch.topk(
+ candidate_values_view,
+ topk_tokens,
+ dim=1,
+ out=(next_best_values, selected),
+ )
+ torch.gather(candidate_indices_view, 1, selected, out=topk_indices)
+ best_values, next_best_values = next_best_values, best_values
+ topk_indices.masked_fill_(~torch.isfinite(best_values), -1)
+
+ return True
+
+
def fp8_fp4_paged_mqa_logits(
q: tuple[torch.Tensor, torch.Tensor | None],
kv_cache: torch.Tensor,
@@ -425,9 +885,10 @@ def fp8_fp4_paged_mqa_logits(
[B, next_n, H, D] float8_e4m3fn and q_scale is None. FP4 path:
q_values is packed uint8 and q_scale is the companion
block-scale tensor.
- kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, 1,
- D+4], dtype `torch.uint8`, with the last 4 bytes per (block, pos)
- storing the float dequant scale.
+ kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, D+4]
+ or [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. Within
+ each block, the D-byte FP8 values for every token are stored first,
+ followed by per-token fp32 scale bytes.
weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
context_lens: Tensor of shape [B], dtype int32; effective context length
for each batch element.
@@ -442,6 +903,10 @@ def fp8_fp4_paged_mqa_logits(
Logits tensor of shape [B * next_n, max_model_len], dtype
`torch.float32`.
"""
+ if current_platform.is_device_capability_family(120) and q[1] is None:
+ return _fp8_paged_mqa_logits_sm12x(
+ q, kv_cache, weights, context_lens, block_tables, max_model_len
+ )
_lazy_init()
if _fp8_fp4_paged_mqa_logits_impl is None:
return _missing()
@@ -457,6 +922,52 @@ def fp8_fp4_paged_mqa_logits(
)
+def _tf32_hc_prenorm_gemm_torch(
+ x: torch.Tensor,
+ fn: torch.Tensor,
+ out: torch.Tensor,
+ sqrsum: torch.Tensor,
+ num_split: int,
+) -> torch.Tensor:
+ """Portable SM12x HyperConnection prenorm GEMM fallback.
+
+ DeepGEMM's split ABI only requires that downstream consumers recover the
+ full result by summing over the split dimension. Keep the implementation
+ simple by writing the full product to split zero and clearing the rest.
+ """
+ del num_split
+ product = x.float() @ fn.float().T
+ norm = x.float().square().sum(dim=-1)
+
+ if out.dim() == 3:
+ out.zero_()
+ sqrsum.zero_()
+ out[0].copy_(product)
+ sqrsum[0].copy_(norm)
+ else:
+ out.copy_(product)
+ sqrsum.copy_(norm)
+ return out
+
+
+def _tf32_hc_prenorm_gemm_sm12x(
+ x: torch.Tensor,
+ fn: torch.Tensor,
+ out: torch.Tensor,
+ sqrsum: torch.Tensor,
+ num_split: int,
+) -> torch.Tensor:
+ if out.dim() == 3 and sqrsum.dim() == 2:
+ from vllm.model_executor.layers.deepseek_v4_triton_kernels import (
+ tf32_hc_prenorm_gemm_triton,
+ )
+
+ tf32_hc_prenorm_gemm_triton(x, fn, out, sqrsum, num_split)
+ return out
+
+ return _tf32_hc_prenorm_gemm_torch(x, fn, out, sqrsum, num_split)
+
+
def tf32_hc_prenorm_gemm(
x: torch.Tensor,
fn: torch.Tensor,
@@ -471,6 +982,8 @@ def tf32_hc_prenorm_gemm(
See the caller function for shape requirement
"""
+ if current_platform.is_device_capability_family(120):
+ return _tf32_hc_prenorm_gemm_sm12x(x, fn, out, sqrsum, num_split)
_lazy_init()
if _tf32_hc_prenorm_gemm_impl is None:
return _missing()
@@ -570,7 +1083,9 @@ def should_use_deepgemm_for_fp8_linear(
"m_grouped_fp8_fp4_gemm_nt_contiguous",
"fp8_m_grouped_gemm_nt_masked",
"fp8_fp4_mqa_logits",
+ "fp8_fp4_mqa_topk_indices",
"fp8_fp4_paged_mqa_logits",
+ "fp8_fp4_paged_mqa_topk_indices",
"get_paged_mqa_logits_metadata",
"per_block_cast_to_fp8",
"is_deep_gemm_e8m0_used",
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 474a5b2d421e..b2399c683eb5 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -30,6 +30,10 @@
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.compressor_utils import get_compressed_slot_mapping
+from vllm.v1.attention.backends.mla.sparse_mla_env import (
+ is_triton_sparse_mla_enabled_for_platform,
+ triton_sparse_mla_cudagraphs_allowed,
+)
from vllm.v1.attention.backends.mla.sparse_utils import (
triton_convert_req_index_to_global_index,
)
@@ -266,6 +270,20 @@ def get_prefill_workspace_size(max_model_len: int):
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+ @classmethod
+ def get_cudagraph_support(
+ cls,
+ vllm_config: VllmConfig,
+ kv_cache_spec: AttentionSpec,
+ ) -> AttentionCGSupport:
+ if (
+ getattr(kv_cache_spec, "model_version", None) == "deepseek_v4"
+ and is_triton_sparse_mla_enabled_for_platform()
+ and not triton_sparse_mla_cudagraphs_allowed(vllm_config)
+ ):
+ return AttentionCGSupport.NEVER
+ return cls._cudagraph_support
+
def __init__(
self,
kv_cache_spec: AttentionSpec,
diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index 7c0715a9e8b6..3f43b65d26fb 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
from dataclasses import dataclass
import torch
-import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -32,6 +32,28 @@
logger = init_logger(__name__)
+def sparse_indexer_max_logits_bytes(is_sm12x: bool | None = None) -> int:
+ configured_mb = os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB")
+ if configured_mb is not None:
+ return int(configured_mb) * 1024 * 1024
+
+ if is_sm12x is None:
+ is_sm12x = (
+ current_platform.is_cuda()
+ and current_platform.is_device_capability_family(120)
+ )
+ default_mb = 256 if is_sm12x else 512
+ return default_mb * 1024 * 1024
+
+
+def _uses_deep_gemm_scheduler_metadata() -> bool:
+ return (
+ current_platform.is_cuda()
+ and has_deep_gemm()
+ and not current_platform.is_device_capability_family(120)
+ )
+
+
@triton.jit
def _prepare_uniform_decode_kernel(
seq_lens_ptr,
@@ -269,13 +291,12 @@ def __init__(self, *args, **kwargs):
self.reorder_batch_threshold += self.num_speculative_tokens
# NOTE(zyongye) fp4 indexer cache only natively supports next_n in
# natively_supported_next_n_fp4; for other next_n values we fall back
- # to the flattening path. Outside the SM100 datacenter family the FP8
- # paged MQA logits kernel has the same [1, 2] constraint (deepgemm
- # smxx_fp8_fp4_paged_mqa_logits.hpp:233), so flatten there too.
+ # to the flattening path. When fp4 indexer cache is disabled, the
+ # native (non-flattening) path handles all next_n values.
self.use_flattening = (
self.use_fp4_indexer_cache
- or not current_platform.is_device_capability_family(100)
- ) and next_n not in self.natively_supported_next_n_fp4
+ and next_n not in self.natively_supported_next_n_fp4
+ )
sm_count = num_compute_units(self.device.index)
self.num_sms = sm_count
@@ -520,7 +541,7 @@ def build(
prefill_query_lens_cpu = torch.diff(
query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1]
)
- max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
+ max_logits_bytes = sparse_indexer_max_logits_bytes()
# Upper bound is exact for prefill rows (the `[num_decodes:]`
# slice below).
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
@@ -610,8 +631,7 @@ def build(
if seq_lens.dim() == 1:
seq_lens = seq_lens.unsqueeze(-1)
- # DeepGEMM is required for the paged MQA logits on CUDA devices
- if current_platform.is_cuda() and has_deep_gemm():
+ if _uses_deep_gemm_scheduler_metadata():
self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
seq_lens,
self.kv_cache_spec.storage_block_size,
diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py
new file mode 100644
index 000000000000..52af38e4cea3
--- /dev/null
+++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Environment controls for the portable Triton sparse MLA path."""
+
+import os
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.platforms import current_platform
+
+_TRITON_MLA_SPARSE_ENV = "VLLM_TRITON_MLA_SPARSE"
+_TRITON_MLA_SPARSE_TOPK_CHUNK_ENV = "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE"
+_TRITON_MLA_SPARSE_QUERY_CHUNK_ENV = "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE"
+_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV = (
+ "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH"
+)
+_TRITON_MLA_SPARSE_HEAD_BLOCK_ENV = "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE"
+_TRITON_MLA_SPARSE_MATMUL_DECODE_ENV = "VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE"
+
+_ENV_TRUE_VALUES = {"1", "true", "yes", "on"}
+_ENV_FALSE_VALUES = {"0", "false", "no", "off"}
+
+logger = init_logger(__name__)
+
+
+def _optional_env_flag(name: str) -> bool | None:
+ raw_value = os.getenv(name)
+ if raw_value is None:
+ return None
+ value = raw_value.lower()
+ if value in _ENV_TRUE_VALUES:
+ return True
+ if value in _ENV_FALSE_VALUES:
+ return False
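+    # Unrecognized values fall through to None and are treated as unset.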
+ return None
+
+
+def _is_sm12x_device(device: torch.device) -> bool:
+ if not torch.cuda.is_available():
+ return False
+ index = device.index if device.index is not None else torch.cuda.current_device()
+ return torch.cuda.get_device_capability(index)[0] == 12
+
+
+def triton_sparse_mla_configured() -> bool | None:
+ return _optional_env_flag(_TRITON_MLA_SPARSE_ENV)
+
+
+def is_triton_sparse_mla_enabled_for_platform() -> bool:
+ configured = triton_sparse_mla_configured()
+ if configured is not None:
+ return configured
+ return current_platform.is_device_capability_family(120)
+
+
+def is_triton_sparse_mla_enabled(device: torch.device) -> bool:
+ configured = triton_sparse_mla_configured()
+ if configured is not None:
+ return configured
+ return _is_sm12x_device(device)
+
+
+def _uses_speculative_decoding(vllm_config) -> bool:
+ return bool(getattr(vllm_config, "speculative_config", None))
+
+
+def triton_sparse_mla_cudagraphs_allowed(vllm_config=None) -> bool:
+ configured = _optional_env_flag(_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV)
+ if configured is not None:
+ return configured
+ return not (
+ vllm_config is not None and _uses_speculative_decoding(vllm_config)
+ )
+
+
+def disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) -> None:
+ if not is_triton_sparse_mla_enabled_for_platform():
+ return
+ if triton_sparse_mla_cudagraphs_allowed(vllm_config):
+        logger.warning_once(
+            "Keeping vLLM compile and CUDA graphs enabled for the DeepSeek V4 "
+            "Triton sparse MLA path because "
+            f"{_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV} allows it or "
+            "speculative decoding is not configured. This is an experimental "
+            "performance mode."
+        )
+ return
+
+ from vllm.config.compilation import CompilationMode, CUDAGraphMode
+
+ compilation_config = vllm_config.compilation_config
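+    # Nothing to do if compile and CUDA graphs are already disabled.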
+ if (
+ compilation_config.mode == CompilationMode.NONE
+ and compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+ ):
+ return
+
+ logger.warning_once(
+ "Disabling vLLM compile and CUDA graphs for the DeepSeek V4 Triton "
+ "sparse MLA path because the current Triton sparse MLA path is not "
+ "compile/graph-safe yet, or because speculative decoding uses "
+ "multi-token sparse MLA decode."
+ )
+ compilation_config.mode = CompilationMode.NONE
+ compilation_config.compile_sizes = []
+ compilation_config.compile_ranges_endpoints = []
+ compilation_config.cudagraph_mode = CUDAGraphMode.NONE
+ compilation_config.cudagraph_capture_sizes = []
+ compilation_config.max_cudagraph_capture_size = 0
+
+
+def triton_sparse_mla_topk_chunk_size() -> int:
+ raw_value = os.getenv(_TRITON_MLA_SPARSE_TOPK_CHUNK_ENV)
+ if raw_value is None:
+ return 512
+ try:
+ return max(1, int(raw_value))
+ except ValueError:
+ return 512
+
+
+def triton_sparse_mla_query_chunk_size() -> int:
+ raw_value = os.getenv(_TRITON_MLA_SPARSE_QUERY_CHUNK_ENV)
+ if raw_value is None:
+ return 256
+ try:
+ return max(1, int(raw_value))
+ except ValueError:
+ return 256
+
+
+def triton_sparse_mla_head_block_size() -> int | None:
+ raw_value = os.getenv(_TRITON_MLA_SPARSE_HEAD_BLOCK_ENV)
+ if raw_value is None:
+ return None
+ try:
+ value = int(raw_value)
+ except ValueError:
+ return None
+ if value in (1, 2, 4):
+ return value
+ return None
+
+
+def triton_sparse_mla_matmul_decode_enabled() -> bool:
+ configured = _optional_env_flag(_TRITON_MLA_SPARSE_MATMUL_DECODE_ENV)
+ if configured is not None:
+ return configured
+ return current_platform.is_device_capability_family(120)
diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py
new file mode 100644
index 000000000000..834ecda43032
--- /dev/null
+++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py
@@ -0,0 +1,2694 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Portable sparse MLA Triton kernels."""
+
+import torch
+
+from vllm.triton_utils import tl, triton
+from vllm.v1.attention.backends.mla.sparse_mla_env import (
+ triton_sparse_mla_head_block_size,
+)
+
+
+def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int:
+ """Choose the SM12x sparse MLA head grouping for decode kernels.
+
+ Single-token decode is latency sensitive and does best with one head per
+ program. Once there are enough query tokens, grouping heads lets the kernel
+ reuse each dequantized KV row across multiple heads.
+ """
+
+ configured_head_block_size = triton_sparse_mla_head_block_size()
+ if configured_head_block_size is not None:
+ return configured_head_block_size
+ if num_decode_tokens <= 4:
+ return 1
+ if num_decode_tokens < 16:
+ return 2
+ return 4
+
+
+@triton.jit
+def _merge_two_subsets_with_sink_kernel(
+ out0_ptr,
+ lse0_ptr,
+ out1_ptr,
+ lse1_ptr,
+ sink_ptr,
+ output_ptr,
+ stride_out0_t: tl.constexpr,
+ stride_out0_h: tl.constexpr,
+ stride_out0_d: tl.constexpr,
+ stride_lse0_t: tl.constexpr,
+ stride_lse0_h: tl.constexpr,
+ stride_out1_t: tl.constexpr,
+ stride_out1_h: tl.constexpr,
+ stride_out1_d: tl.constexpr,
+ stride_lse1_t: tl.constexpr,
+ stride_lse1_h: tl.constexpr,
+ stride_output_t: tl.constexpr,
+ stride_output_h: tl.constexpr,
+ stride_output_d: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_head = tl.program_id(0)
+ block_d = tl.program_id(1)
+ token_idx = token_head // num_heads
+ head_idx = token_head - token_idx * num_heads
+ offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D)
+ mask = offsets < head_dim
+
+ lse0 = tl.load(lse0_ptr + token_idx * stride_lse0_t + head_idx * stride_lse0_h)
+ lse1 = tl.load(lse1_ptr + token_idx * stride_lse1_t + head_idx * stride_lse1_h)
+ sink = tl.load(sink_ptr + head_idx)
+ merge_max = tl.maximum(tl.maximum(lse0, lse1), sink)
+
+ weight0 = tl.exp(lse0 - merge_max)
+ weight1 = tl.exp(lse1 - merge_max)
+ weight_sink = tl.exp(sink - merge_max)
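+    # The sink adds probability mass to the denominator but no value vector,
+    # so it only scales down both subsets' contributions.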
+ denom = weight0 + weight1 + weight_sink
+
+ out0 = tl.load(
+ out0_ptr
+ + token_idx * stride_out0_t
+ + head_idx * stride_out0_h
+ + offsets * stride_out0_d,
+ mask=mask,
+ other=0.0,
+ ).to(tl.float32)
+ out1 = tl.load(
+ out1_ptr
+ + token_idx * stride_out1_t
+ + head_idx * stride_out1_h
+ + offsets * stride_out1_d,
+ mask=mask,
+ other=0.0,
+ ).to(tl.float32)
+ merged = (out0 * weight0 + out1 * weight1) / denom
+ tl.store(
+ output_ptr
+ + token_idx * stride_output_t
+ + head_idx * stride_output_h
+ + offsets * stride_output_d,
+ merged,
+ mask=mask,
+ )
+
+
+def merge_two_sparse_mla_subsets_with_sink(
+ subset0_output: torch.Tensor,
+ subset0_lse: torch.Tensor,
+ subset1_output: torch.Tensor,
+ subset1_lse: torch.Tensor,
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+) -> None:
+ assert subset0_output.shape == subset1_output.shape
+ assert subset0_output.shape == output.shape
+ assert subset0_lse.shape == subset1_lse.shape
+ assert subset0_lse.shape == subset0_output.shape[:2]
+ assert attn_sink.shape[0] == subset0_output.shape[1]
+ assert subset0_output.is_cuda
+ assert subset1_output.is_cuda
+ assert output.is_cuda
+
+ num_tokens, num_heads, head_dim = subset0_output.shape
+ block_d = min(128, triton.next_power_of_2(head_dim))
+ grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d))
+ _merge_two_subsets_with_sink_kernel[grid](
+ subset0_output,
+ subset0_lse,
+ subset1_output,
+ subset1_lse,
+ attn_sink,
+ output,
+ subset0_output.stride(0),
+ subset0_output.stride(1),
+ subset0_output.stride(2),
+ subset0_lse.stride(0),
+ subset0_lse.stride(1),
+ subset1_output.stride(0),
+ subset1_output.stride(1),
+ subset1_output.stride(2),
+ subset1_lse.stride(0),
+ subset1_lse.stride(1),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ num_heads,
+ head_dim,
+ BLOCK_D=block_d,
+ num_warps=4,
+ )
+
+
+@triton.jit
+def _merge_single_subset_with_sink_kernel(
+ subset_output_ptr,
+ subset_lse_ptr,
+ sink_ptr,
+ output_ptr,
+ stride_subset_t: tl.constexpr,
+ stride_subset_h: tl.constexpr,
+ stride_subset_d: tl.constexpr,
+ stride_lse_t: tl.constexpr,
+ stride_lse_h: tl.constexpr,
+ stride_output_t: tl.constexpr,
+ stride_output_h: tl.constexpr,
+ stride_output_d: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_head = tl.program_id(0)
+ block_d = tl.program_id(1)
+ token_idx = token_head // num_heads
+ head_idx = token_head - token_idx * num_heads
+ offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D)
+ mask = offsets < head_dim
+
+ subset_lse = tl.load(
+ subset_lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h
+ )
+ sink = tl.load(sink_ptr + head_idx)
+ merge_max = tl.maximum(subset_lse, sink)
+
+ subset_weight = tl.exp(subset_lse - merge_max)
+ sink_weight = tl.exp(sink - merge_max)
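+    # As in the two-subset merge, the sink contributes only to the denominator.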
+ denom = subset_weight + sink_weight
+ subset_output = tl.load(
+ subset_output_ptr
+ + token_idx * stride_subset_t
+ + head_idx * stride_subset_h
+ + offsets * stride_subset_d,
+ mask=mask,
+ other=0.0,
+ ).to(tl.float32)
+ merged = subset_output * subset_weight / denom
+ tl.store(
+ output_ptr
+ + token_idx * stride_output_t
+ + head_idx * stride_output_h
+ + offsets * stride_output_d,
+ merged,
+ mask=mask,
+ )
+
+
+def merge_sparse_mla_subset_with_sink(
+ subset_output: torch.Tensor,
+ subset_lse: torch.Tensor,
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+) -> None:
+ assert subset_output.shape == output.shape
+ assert subset_lse.shape == subset_output.shape[:2]
+ assert attn_sink.shape[0] == subset_output.shape[1]
+ assert subset_output.is_cuda
+ assert subset_lse.is_cuda
+ assert attn_sink.is_cuda
+ assert output.is_cuda
+
+ num_tokens, num_heads, head_dim = subset_output.shape
+ block_d = min(128, triton.next_power_of_2(head_dim))
+ grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d))
+ _merge_single_subset_with_sink_kernel[grid](
+ subset_output,
+ subset_lse,
+ attn_sink,
+ output,
+ subset_output.stride(0),
+ subset_output.stride(1),
+ subset_output.stride(2),
+ subset_lse.stride(0),
+ subset_lse.stride(1),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ num_heads,
+ head_dim,
+ BLOCK_D=block_d,
+ num_warps=4,
+ )
+
+
+@triton.jit
+def _build_combined_decode_valid_mask_kernel(
+ output_ptr,
+ slot_ids_ptr,
+ topk_lens_ptr,
+ swa_lens_ptr,
+ stride_output_t: tl.constexpr,
+ stride_output_c: tl.constexpr,
+ stride_slot_t: tl.constexpr,
+ stride_slot_c: tl.constexpr,
+ num_compressed_candidates: tl.constexpr,
+ num_candidates: tl.constexpr,
+ BLOCK_C: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ offsets = tl.arange(0, BLOCK_C)
+ candidate_mask = offsets < num_candidates
+
+ topk_lens = tl.load(topk_lens_ptr + token_idx)
+ swa_lens = tl.load(swa_lens_ptr + token_idx)
+ is_compressed = offsets < num_compressed_candidates
+ swa_offsets = offsets - num_compressed_candidates
+ slot_ids = tl.load(
+ slot_ids_ptr + token_idx * stride_slot_t + offsets * stride_slot_c,
+ mask=is_compressed,
+ other=-1,
+ )
+ valid_compressed = is_compressed & (offsets < topk_lens) & (slot_ids >= 0)
+ valid_swa = (~is_compressed) & (swa_offsets < swa_lens)
+ valid = valid_compressed | valid_swa
+ tl.store(
+ output_ptr + token_idx * stride_output_t + offsets * stride_output_c,
+ valid,
+ mask=candidate_mask,
+ )
+
+
+def build_combined_sparse_mla_decode_valid_mask(
+ output: torch.Tensor,
+ compressed_slot_ids: torch.Tensor,
+ topk_lens: torch.Tensor,
+ swa_lens: torch.Tensor,
+) -> None:
+ """Build `[compressed, SWA]` validity mask for SM12x decode."""
+ if compressed_slot_ids.dim() == 3:
+ assert compressed_slot_ids.shape[1] == 1
+ compressed_slot_ids = compressed_slot_ids[:, 0, :]
+
+ assert output.dim() == 2
+ assert output.dtype == torch.bool
+ assert compressed_slot_ids.dim() == 2
+ assert output.shape[0] == compressed_slot_ids.shape[0]
+ assert output.shape[0] == topk_lens.shape[0]
+ assert output.shape[0] == swa_lens.shape[0]
+ assert output.shape[1] >= compressed_slot_ids.shape[1]
+ assert output.is_cuda
+ assert compressed_slot_ids.is_cuda
+ assert topk_lens.is_cuda
+ assert swa_lens.is_cuda
+
+ num_candidates = output.shape[1]
+ block_c = triton.next_power_of_2(num_candidates)
+ _build_combined_decode_valid_mask_kernel[(output.shape[0],)](
+ output,
+ compressed_slot_ids,
+ topk_lens,
+ swa_lens,
+ output.stride(0),
+ output.stride(1),
+ compressed_slot_ids.stride(0),
+ compressed_slot_ids.stride(1),
+ compressed_slot_ids.shape[1],
+ num_candidates,
+ BLOCK_C=block_c,
+ num_warps=4,
+ )
+
+
+def matmul_sparse_mla_attention_with_sink(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ valid_tokens: torch.Tensor,
+ scale: float,
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+ num_heads: int | None = None,
+ score_buffer: torch.Tensor | None = None,
+ head_block_size: int = 1,
+ value_block_size: int | None = None,
+ candidate_block_size: int | None = None,
+) -> None:
+ """Compute sink-aware sparse MLA over materialized BF16 KV.
+
+ This path intentionally dequantizes/gathers KV once, computes scores with
+ batched matrix multiplication, and finishes the sink-aware value reduction
+ in Triton. It is useful for the SM12x decode path where the direct Triton
+ kernel otherwise repeats fp8_ds_mla dequantization once per head group.
+ """
+ if q.dim() == 4:
+ assert q.shape[1] == 1
+ q = q[:, 0]
+
+ assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}"
+ assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}"
+ assert valid_tokens.shape == kv.shape[:2]
+ assert q.shape[0] == kv.shape[0]
+ assert q.shape[-1] == kv.shape[-1]
+ assert output.shape[0] == q.shape[0]
+ assert output.shape[2] == q.shape[-1]
+ assert q.is_cuda and kv.is_cuda and valid_tokens.is_cuda
+ assert attn_sink.is_cuda and output.is_cuda
+
+ active_heads = num_heads if num_heads is not None else output.shape[1]
+ assert active_heads <= q.shape[1]
+ assert active_heads <= output.shape[1]
+ assert active_heads <= attn_sink.shape[0]
+
+ q_active = q[:, :active_heads]
+ num_tokens = q.shape[0]
+ num_candidates = kv.shape[1]
+ if score_buffer is None:
+ score_buffer = torch.empty(
+ (num_tokens, active_heads, num_candidates),
+ dtype=torch.float32,
+ device=q.device,
+ )
+ assert score_buffer.shape == (num_tokens, active_heads, num_candidates)
+ assert score_buffer.device == q.device
+ assert score_buffer.dtype in (torch.float32, torch.bfloat16)
+ if score_buffer.dtype == torch.float32:
+ q_score = q_active.float()
+ kv_score = kv.float()
+ else:
+ q_score = q_active.to(score_buffer.dtype)
+ kv_score = kv.to(score_buffer.dtype)
+ torch.bmm(q_score, kv_score.transpose(1, 2), out=score_buffer)
+ score_buffer.mul_(scale)
+ finish_materialized_sparse_mla_scores_with_sink(
+ score_buffer,
+ kv,
+ valid_tokens,
+ attn_sink,
+ output,
+ num_heads=active_heads,
+ head_block_size=head_block_size,
+ value_block_size=value_block_size,
+ candidate_block_size=candidate_block_size,
+ )
+
+
+@triton.jit
+def _finish_materialized_scores_with_sink_kernel(
+ scores_ptr,
+ kv_ptr,
+ valid_tokens_ptr,
+ attn_sink_ptr,
+ output_ptr,
+ stride_scores_t: tl.constexpr,
+ stride_scores_h: tl.constexpr,
+ stride_scores_c: tl.constexpr,
+ stride_kv_t: tl.constexpr,
+ stride_kv_c: tl.constexpr,
+ stride_kv_d: tl.constexpr,
+ stride_valid_t: tl.constexpr,
+ stride_valid_c: tl.constexpr,
+ stride_out_t: tl.constexpr,
+ stride_out_h: tl.constexpr,
+ stride_out_d: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ num_candidates: tl.constexpr,
+ HEAD_BLOCK: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_block_idx = tl.program_id(1)
+ head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK)
+ dim_offsets = tl.arange(0, BLOCK_D)
+ head_mask = head_offsets < num_heads
+ dim_mask = dim_offsets < head_dim
+ matrix_mask = head_mask[:, None] & dim_mask[None, :]
+
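+    # Online softmax: the sink seeds the running max with weight 1 and no value
+    # contribution; each valid candidate rescales this state.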
+ running_max = tl.load(attn_sink_ptr + head_offsets, mask=head_mask, other=0.0).to(
+ tl.float32
+ )
+ running_denom = tl.full((HEAD_BLOCK,), 1.0, tl.float32)
+ running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32)
+
+ for candidate_idx in range(0, num_candidates):
+ is_valid = tl.load(
+ valid_tokens_ptr
+ + token_idx * stride_valid_t
+ + candidate_idx * stride_valid_c
+ )
+ if is_valid:
+ score = tl.load(
+ scores_ptr
+ + token_idx * stride_scores_t
+ + head_offsets * stride_scores_h
+ + candidate_idx * stride_scores_c,
+ mask=head_mask,
+ other=-float("inf"),
+ ).to(tl.float32)
+ kv = tl.load(
+ kv_ptr
+ + token_idx * stride_kv_t
+ + candidate_idx * stride_kv_c
+ + dim_offsets * stride_kv_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = (
+ running_acc * previous_weight[:, None]
+ + kv[None, :] * candidate_weight[:, None]
+ )
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
+ result = running_acc / running_denom[:, None]
+ tl.store(
+ output_ptr
+ + token_idx * stride_out_t
+ + head_offsets[:, None] * stride_out_h
+ + dim_offsets[None, :] * stride_out_d,
+ result,
+ mask=matrix_mask,
+ )
+
+
+@triton.jit
+def _finish_materialized_scores_with_sink_candidate_block_kernel(
+ scores_ptr,
+ kv_ptr,
+ valid_tokens_ptr,
+ attn_sink_ptr,
+ output_ptr,
+ stride_scores_t: tl.constexpr,
+ stride_scores_h: tl.constexpr,
+ stride_scores_c: tl.constexpr,
+ stride_kv_t: tl.constexpr,
+ stride_kv_c: tl.constexpr,
+ stride_kv_d: tl.constexpr,
+ stride_valid_t: tl.constexpr,
+ stride_valid_c: tl.constexpr,
+ stride_out_t: tl.constexpr,
+ stride_out_h: tl.constexpr,
+ stride_out_d: tl.constexpr,
+ head_dim: tl.constexpr,
+ num_candidates: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_idx = tl.program_id(1)
+ dim_block_idx = tl.program_id(2)
+ candidate_offsets = tl.arange(0, BLOCK_K)
+ dim_offsets = dim_block_idx * BLOCK_D + tl.arange(0, BLOCK_D)
+ dim_mask = dim_offsets < head_dim
+
+ max_score = tl.load(attn_sink_ptr + head_idx).to(tl.float32)
+ for candidate_start in range(0, num_candidates, BLOCK_K):
+ candidates = candidate_start + candidate_offsets
+ candidate_mask = candidates < num_candidates
+ is_valid = tl.load(
+ valid_tokens_ptr + token_idx * stride_valid_t + candidates * stride_valid_c,
+ mask=candidate_mask,
+ other=0,
+ ).to(tl.int1)
+ scores = tl.load(
+ scores_ptr
+ + token_idx * stride_scores_t
+ + head_idx * stride_scores_h
+ + candidates * stride_scores_c,
+ mask=candidate_mask & is_valid,
+ other=-float("inf"),
+ ).to(tl.float32)
+ max_score = tl.maximum(max_score, tl.max(scores, axis=0))
+
+ denom = tl.exp(tl.load(attn_sink_ptr + head_idx).to(tl.float32) - max_score)
+ acc = tl.zeros((BLOCK_D,), tl.float32)
+ for candidate_start in range(0, num_candidates, BLOCK_K):
+ candidates = candidate_start + candidate_offsets
+ candidate_mask = candidates < num_candidates
+ is_valid = tl.load(
+ valid_tokens_ptr + token_idx * stride_valid_t + candidates * stride_valid_c,
+ mask=candidate_mask,
+ other=0,
+ ).to(tl.int1)
+ scores = tl.load(
+ scores_ptr
+ + token_idx * stride_scores_t
+ + head_idx * stride_scores_h
+ + candidates * stride_scores_c,
+ mask=candidate_mask & is_valid,
+ other=-float("inf"),
+ ).to(tl.float32)
+ weights = tl.exp(scores - max_score)
+ denom += tl.sum(weights, axis=0)
+ kv = tl.load(
+ kv_ptr
+ + token_idx * stride_kv_t
+ + candidates[:, None] * stride_kv_c
+ + dim_offsets[None, :] * stride_kv_d,
+ mask=(candidate_mask & is_valid)[:, None] & dim_mask[None, :],
+ other=0.0,
+ )
+ acc += tl.sum(kv.to(tl.float32) * weights[:, None], axis=0)
+
+ tl.store(
+ output_ptr
+ + token_idx * stride_out_t
+ + head_idx * stride_out_h
+ + dim_offsets * stride_out_d,
+ acc / denom,
+ mask=dim_mask,
+ )
+
+
+@triton.jit
+def _finish_materialized_scores_with_sink_value_block_kernel(
+ scores_ptr,
+ kv_ptr,
+ valid_tokens_ptr,
+ attn_sink_ptr,
+ output_ptr,
+ stride_scores_t: tl.constexpr,
+ stride_scores_h: tl.constexpr,
+ stride_scores_c: tl.constexpr,
+ stride_kv_t: tl.constexpr,
+ stride_kv_c: tl.constexpr,
+ stride_kv_d: tl.constexpr,
+ stride_valid_t: tl.constexpr,
+ stride_valid_c: tl.constexpr,
+ stride_out_t: tl.constexpr,
+ stride_out_h: tl.constexpr,
+ stride_out_d: tl.constexpr,
+ head_dim: tl.constexpr,
+ num_candidates: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_idx = tl.program_id(1)
+ dim_block_idx = tl.program_id(2)
+ dim_offsets = dim_block_idx * BLOCK_D + tl.arange(0, BLOCK_D)
+ dim_mask = dim_offsets < head_dim
+
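+    # Same sink-seeded online-softmax recurrence as above, restricted to one
+    # head and one value-dimension block per program.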
+ running_max = tl.load(attn_sink_ptr + head_idx).to(tl.float32)
+ running_denom = tl.full((), 1.0, tl.float32)
+ running_acc = tl.zeros((BLOCK_D,), tl.float32)
+
+ for candidate_idx in range(0, num_candidates):
+ is_valid = tl.load(
+ valid_tokens_ptr
+ + token_idx * stride_valid_t
+ + candidate_idx * stride_valid_c
+ )
+ if is_valid:
+ score = tl.load(
+ scores_ptr
+ + token_idx * stride_scores_t
+ + head_idx * stride_scores_h
+ + candidate_idx * stride_scores_c
+ ).to(tl.float32)
+ kv = tl.load(
+ kv_ptr
+ + token_idx * stride_kv_t
+ + candidate_idx * stride_kv_c
+ + dim_offsets * stride_kv_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = running_acc * previous_weight + kv * candidate_weight
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
+ result = running_acc / running_denom
+ tl.store(
+ output_ptr
+ + token_idx * stride_out_t
+ + head_idx * stride_out_h
+ + dim_offsets * stride_out_d,
+ result,
+ mask=dim_mask,
+ )
+
+
+def finish_materialized_sparse_mla_scores_with_sink(
+ scores: torch.Tensor,
+ kv: torch.Tensor,
+ valid_tokens: torch.Tensor,
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+ num_heads: int | None = None,
+ head_block_size: int = 1,
+ value_block_size: int | None = None,
+ candidate_block_size: int | None = None,
+) -> None:
+ assert scores.dim() == 3
+ assert kv.dim() == 3
+ assert valid_tokens.shape == kv.shape[:2]
+ assert scores.shape[0] == kv.shape[0]
+ assert scores.shape[2] == kv.shape[1]
+ assert output.shape[0] == kv.shape[0]
+ assert output.shape[2] == kv.shape[2]
+ assert scores.dtype in (torch.float32, torch.bfloat16)
+ assert head_block_size in (1, 2, 4)
+ if value_block_size is not None:
+ assert value_block_size in (64, 128, 256, 512)
+ if candidate_block_size is not None:
+ assert candidate_block_size in (16, 32, 64, 128)
+ assert scores.is_cuda and kv.is_cuda and valid_tokens.is_cuda
+ assert attn_sink.is_cuda and output.is_cuda
+
+ active_heads = num_heads if num_heads is not None else output.shape[1]
+ assert active_heads <= scores.shape[1]
+ assert active_heads <= output.shape[1]
+ assert active_heads <= attn_sink.shape[0]
+
+ num_tokens, _, num_candidates = scores.shape
+ head_dim = kv.shape[2]
+ if candidate_block_size is not None:
+ block_d = value_block_size if value_block_size is not None else 128
+ candidate_grid = (num_tokens, active_heads, triton.cdiv(head_dim, block_d))
+ _finish_materialized_scores_with_sink_candidate_block_kernel[candidate_grid](
+ scores,
+ kv,
+ valid_tokens,
+ attn_sink,
+ output,
+ scores.stride(0),
+ scores.stride(1),
+ scores.stride(2),
+ kv.stride(0),
+ kv.stride(1),
+ kv.stride(2),
+ valid_tokens.stride(0),
+ valid_tokens.stride(1),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ head_dim,
+ num_candidates,
+ BLOCK_K=candidate_block_size,
+ BLOCK_D=block_d,
+ num_warps=8,
+ )
+ if output.shape[1] > active_heads:
+ output[:, active_heads:].zero_()
+ return
+
+ if value_block_size is not None and value_block_size < head_dim:
+ value_grid = (
+ num_tokens,
+ active_heads,
+ triton.cdiv(head_dim, value_block_size),
+ )
+ _finish_materialized_scores_with_sink_value_block_kernel[value_grid](
+ scores,
+ kv,
+ valid_tokens,
+ attn_sink,
+ output,
+ scores.stride(0),
+ scores.stride(1),
+ scores.stride(2),
+ kv.stride(0),
+ kv.stride(1),
+ kv.stride(2),
+ valid_tokens.stride(0),
+ valid_tokens.stride(1),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ head_dim,
+ num_candidates,
+ BLOCK_D=value_block_size,
+ num_warps=4,
+ )
+ if output.shape[1] > active_heads:
+ output[:, active_heads:].zero_()
+ return
+
+ block_d = min(1024, triton.next_power_of_2(head_dim))
+ head_grid = (num_tokens, triton.cdiv(active_heads, head_block_size))
+ _finish_materialized_scores_with_sink_kernel[head_grid](
+ scores,
+ kv,
+ valid_tokens,
+ attn_sink,
+ output,
+ scores.stride(0),
+ scores.stride(1),
+ scores.stride(2),
+ kv.stride(0),
+ kv.stride(1),
+ kv.stride(2),
+ valid_tokens.stride(0),
+ valid_tokens.stride(1),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ active_heads,
+ head_dim,
+ num_candidates,
+ HEAD_BLOCK=head_block_size,
+ BLOCK_D=block_d,
+ num_warps=8,
+ )
+ if output.shape[1] > active_heads:
+ output[:, active_heads:].zero_()
+
+
+@triton.jit
+def _accumulate_gathered_attention_chunk_kernel(
+ q_ptr,
+ kv_ptr,
+ slot_ids_ptr,
+ lens_ptr,
+ max_score_ptr,
+ denom_ptr,
+ acc_ptr,
+ stride_q_t: tl.constexpr,
+ stride_q_h: tl.constexpr,
+ stride_q_d: tl.constexpr,
+ stride_kv_t: tl.constexpr,
+ stride_kv_c: tl.constexpr,
+ stride_kv_d: tl.constexpr,
+ stride_slot_t: tl.constexpr,
+ stride_slot_c: tl.constexpr,
+ stride_state_t: tl.constexpr,
+ stride_state_h: tl.constexpr,
+ stride_acc_t: tl.constexpr,
+ stride_acc_h: tl.constexpr,
+ stride_acc_d: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ num_candidates,
+ candidate_offset,
+ scale: tl.constexpr,
+ HAS_SLOT_IDS: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_idx = tl.program_id(1)
+ offsets = tl.arange(0, BLOCK_D)
+ dim_mask = offsets < head_dim
+
+ q = tl.load(
+ q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ state_offset = token_idx * stride_state_t + head_idx * stride_state_h
+ acc_offset = (
+ token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d
+ )
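+    # The (max, denom, acc) state is persisted in global memory so successive
+    # candidate chunks can resume this online-softmax reduction.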
+ running_max = tl.load(max_score_ptr + state_offset)
+ running_denom = tl.load(denom_ptr + state_offset)
+ running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32)
+ valid_len = tl.load(lens_ptr + token_idx)
+
+ for candidate_idx in range(0, num_candidates):
+ is_valid = (candidate_offset + candidate_idx) < valid_len
+ if HAS_SLOT_IDS:
+ slot_id = tl.load(
+ slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c
+ )
+ is_valid = is_valid & (slot_id >= 0)
+
+ if is_valid:
+ kv = tl.load(
+ kv_ptr
+ + token_idx * stride_kv_t
+ + candidate_idx * stride_kv_c
+ + offsets * stride_kv_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+ score = tl.sum(q * kv, axis=0) * scale
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = running_acc * previous_weight + kv * candidate_weight
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
+ tl.store(max_score_ptr + state_offset, running_max)
+ tl.store(denom_ptr + state_offset, running_denom)
+ tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask)
+
+
+def accumulate_gathered_sparse_mla_attention_chunk(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ lens: torch.Tensor,
+ scale: float,
+ max_score: torch.Tensor,
+ denom: torch.Tensor,
+ acc: torch.Tensor,
+ candidate_offset: int = 0,
+ slot_ids: torch.Tensor | None = None,
+) -> None:
+ if q.dim() == 4:
+ assert q.shape[1] == 1
+ q = q[:, 0]
+ assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}"
+ assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}"
+ assert q.shape[0] == kv.shape[0]
+ assert q.shape[-1] == kv.shape[-1]
+ assert lens.shape[0] == q.shape[0]
+ assert max_score.shape[0] == q.shape[0]
+ assert max_score.shape[1] <= q.shape[1]
+ assert denom.shape == max_score.shape
+ assert acc.shape == (*max_score.shape, q.shape[-1])
+ assert max_score.dtype == torch.float32
+ assert denom.dtype == torch.float32
+ assert acc.dtype == torch.float32
+ assert q.is_cuda and kv.is_cuda and lens.is_cuda
+ assert max_score.is_cuda and denom.is_cuda and acc.is_cuda
+
+ if slot_ids is not None:
+ if slot_ids.dim() == 3:
+ assert slot_ids.shape[1] == 1
+ slot_ids = slot_ids[:, 0]
+ assert slot_ids.dim() == 2
+ assert slot_ids.shape == kv.shape[:2]
+ assert slot_ids.is_cuda
+
+ num_tokens, _, head_dim = q.shape
+ num_heads = max_score.shape[1]
+ num_candidates = kv.shape[1]
+ block_d = min(1024, triton.next_power_of_2(head_dim))
+ grid = (num_tokens, num_heads)
+ _accumulate_gathered_attention_chunk_kernel[grid](
+ q,
+ kv,
+ slot_ids,
+ lens,
+ max_score,
+ denom,
+ acc,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ kv.stride(0),
+ kv.stride(1),
+ kv.stride(2),
+ slot_ids.stride(0) if slot_ids is not None else 0,
+ slot_ids.stride(1) if slot_ids is not None else 0,
+ max_score.stride(0),
+ max_score.stride(1),
+ acc.stride(0),
+ acc.stride(1),
+ acc.stride(2),
+ num_heads,
+ head_dim,
+ num_candidates,
+ candidate_offset,
+ scale,
+ HAS_SLOT_IDS=slot_ids is not None,
+ BLOCK_D=block_d,
+ num_warps=8,
+ )
+
+
+@triton.jit
+def _accumulate_indexed_attention_chunk_kernel(
+ q_ptr,
+ kv_flat_ptr,
+ indices_ptr,
+ lens_ptr,
+ max_score_ptr,
+ denom_ptr,
+ acc_ptr,
+ stride_q_t: tl.constexpr,
+ stride_q_h: tl.constexpr,
+ stride_q_d: tl.constexpr,
+ stride_kv_t,
+ stride_kv_d: tl.constexpr,
+ stride_indices_t: tl.constexpr,
+ stride_indices_c: tl.constexpr,
+ stride_state_t: tl.constexpr,
+ stride_state_h: tl.constexpr,
+ stride_acc_t: tl.constexpr,
+ stride_acc_h: tl.constexpr,
+ stride_acc_d: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ num_candidates,
+ candidate_offset,
+ scale: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_idx = tl.program_id(1)
+ offsets = tl.arange(0, BLOCK_D)
+ dim_mask = offsets < head_dim
+
+ q = tl.load(
+ q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ state_offset = token_idx * stride_state_t + head_idx * stride_state_h
+ acc_offset = (
+ token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d
+ )
+ running_max = tl.load(max_score_ptr + state_offset)
+ running_denom = tl.load(denom_ptr + state_offset)
+ running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32)
+ valid_len = tl.load(lens_ptr + token_idx)
+
+ for candidate_idx in range(0, num_candidates):
+ kv_index = tl.load(
+ indices_ptr
+ + token_idx * stride_indices_t
+ + candidate_idx * stride_indices_c
+ )
+ is_valid = ((candidate_offset + candidate_idx) < valid_len) & (kv_index >= 0)
+
+ if is_valid:
+ kv = tl.load(
+ kv_flat_ptr
+ + kv_index.to(tl.int64) * stride_kv_t
+ + offsets * stride_kv_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+ score = tl.sum(q * kv, axis=0) * scale
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = running_acc * previous_weight + kv * candidate_weight
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
+ tl.store(max_score_ptr + state_offset, running_max)
+ tl.store(denom_ptr + state_offset, running_denom)
+ tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask)
+
+
+def accumulate_indexed_sparse_mla_attention_chunk(
+ q: torch.Tensor,
+ kv_flat: torch.Tensor,
+ indices: torch.Tensor,
+ lens: torch.Tensor,
+ scale: float,
+ max_score: torch.Tensor,
+ denom: torch.Tensor,
+ acc: torch.Tensor,
+ candidate_offset: int = 0,
+) -> None:
+ if q.dim() == 4:
+ assert q.shape[1] == 1
+ q = q[:, 0]
+
+ assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}"
+ assert kv_flat.dim() == 2
+ assert indices.dim() == 2
+ assert indices.shape[0] == q.shape[0]
+ assert kv_flat.shape[-1] == q.shape[-1]
+ assert lens.shape[0] == q.shape[0]
+ assert max_score.shape[0] == q.shape[0]
+ assert max_score.shape[1] <= q.shape[1]
+ assert denom.shape == max_score.shape
+ assert acc.shape == (*max_score.shape, q.shape[-1])
+ assert max_score.dtype == torch.float32
+ assert denom.dtype == torch.float32
+ assert acc.dtype == torch.float32
+ assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda
+ assert max_score.is_cuda and denom.is_cuda and acc.is_cuda
+
+ num_tokens, _, head_dim = q.shape
+ num_heads = max_score.shape[1]
+ num_candidates = indices.shape[1]
+ block_d = min(1024, triton.next_power_of_2(head_dim))
+ grid = (num_tokens, num_heads)
+ _accumulate_indexed_attention_chunk_kernel[grid](
+ q,
+ kv_flat,
+ indices,
+ lens,
+ max_score,
+ denom,
+ acc,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ kv_flat.stride(0),
+ kv_flat.stride(1),
+ indices.stride(0),
+ indices.stride(1),
+ max_score.stride(0),
+ max_score.stride(1),
+ acc.stride(0),
+ acc.stride(1),
+ acc.stride(2),
+ num_heads,
+ head_dim,
+ num_candidates,
+ candidate_offset,
+ scale,
+ BLOCK_D=block_d,
+ num_warps=8,
+ )
+
+
+@triton.jit
+def _accumulate_fp8ds_global_slots_attention_chunk_kernel(
+ q_ptr,
+ k_cache_ptr,
+ slot_ids_ptr,
+ lens_ptr,
+ max_score_ptr,
+ denom_ptr,
+ acc_ptr,
+ stride_q_t: tl.constexpr,
+ stride_q_h: tl.constexpr,
+ stride_q_d: tl.constexpr,
+ stride_slot_t: tl.constexpr,
+ stride_slot_c: tl.constexpr,
+ stride_state_t: tl.constexpr,
+ stride_state_h: tl.constexpr,
+ stride_acc_t: tl.constexpr,
+ stride_acc_h: tl.constexpr,
+ stride_acc_d: tl.constexpr,
+ cache_block_size: tl.constexpr,
+ token_data_size: tl.constexpr,
+ block_stride: tl.constexpr,
+ fp8_dim: tl.constexpr,
+ scale_dim: tl.constexpr,
+ quant_block: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ num_candidates,
+ candidate_offset,
+ scale: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_idx = tl.program_id(1)
+ offsets = tl.arange(0, BLOCK_D)
+ dim_mask = offsets < head_dim
+
+ q = tl.load(
+ q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ state_offset = token_idx * stride_state_t + head_idx * stride_state_h
+ acc_offset = (
+ token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d
+ )
+ running_max = tl.load(max_score_ptr + state_offset)
+ running_denom = tl.load(denom_ptr + state_offset)
+ running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32)
+ valid_len = tl.load(lens_ptr + token_idx)
+
+ fp8_mask = offsets < fp8_dim
+ rope_mask = (offsets >= fp8_dim) & dim_mask
+ rope_offsets = tl.maximum(offsets - fp8_dim, 0)
+
+ for candidate_idx in range(0, num_candidates):
+ slot_id = tl.load(
+ slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c
+ )
+ is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0)
+
+ if is_valid:
+ block_idx = slot_id // cache_block_size
+ pos_in_block = slot_id % cache_block_size
+ cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride
+ token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
+ token_scale_ptr = (
+ cache_block_ptr
+ + cache_block_size * token_data_size
+ + pos_in_block * scale_dim
+ )
+
+ x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0)
+ x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
+ x_float = x_fp8.to(tl.float32)
+ scale_offsets = offsets // quant_block
+ encoded_scale = tl.load(
+ token_scale_ptr + scale_offsets,
+ mask=fp8_mask,
+ other=127,
+ )
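+            # Scale bytes are biased exponents: exp2(e - 127) gives the fp32 scale.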
+ dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0)
+ x_dequant = x_float * dequant_scale
+
+ rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16))
+ rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to(
+ tl.float32
+ )
+ kv = tl.where(fp8_mask, x_dequant, rope)
+ kv = tl.where(dim_mask, kv, 0.0)
+
+ score = tl.sum(q * kv, axis=0) * scale
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = running_acc * previous_weight + kv * candidate_weight
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
+ tl.store(max_score_ptr + state_offset, running_max)
+ tl.store(denom_ptr + state_offset, running_denom)
+ tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask)
+
+
+def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk(
+ q: torch.Tensor,
+ k_cache: torch.Tensor,
+ slot_ids: torch.Tensor,
+ lens: torch.Tensor,
+ block_size: int,
+ scale: float,
+ max_score: torch.Tensor,
+ denom: torch.Tensor,
+ acc: torch.Tensor,
+ candidate_offset: int = 0,
+) -> None:
+ if q.dim() == 4:
+ assert q.shape[1] == 1
+ q = q[:, 0]
+ if slot_ids.dim() == 3:
+ assert slot_ids.shape[1] == 1
+ slot_ids = slot_ids[:, 0]
+
+ assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}"
+ assert q.shape[-1] == 512
+ assert slot_ids.dim() == 2
+ assert slot_ids.shape[0] == q.shape[0]
+ assert lens.shape[0] == q.shape[0]
+ assert max_score.shape[0] == q.shape[0]
+ assert max_score.shape[1] <= q.shape[1]
+ assert denom.shape == max_score.shape
+ assert acc.shape == (*max_score.shape, q.shape[-1])
+ assert max_score.dtype == torch.float32
+ assert denom.dtype == torch.float32
+ assert acc.dtype == torch.float32
+ assert k_cache.dtype == torch.uint8
+ assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda
+ assert max_score.is_cuda and denom.is_cuda and acc.is_cuda
+
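+    # fp8_ds_mla token layout (as read by the kernel above): 448 FP8 bytes plus
+    # 64 bf16 RoPE values per token (576 bytes), with per-token scale bytes
+    # packed after all token data in each cache block.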
+ token_fp8_dim = 448
+ token_bf16_dim = 64
+ token_scale_dim = 8
+ quant_block_size = 64
+ token_data_size = token_fp8_dim + token_bf16_dim * 2
+
+ num_tokens, _, head_dim = q.shape
+ num_heads = max_score.shape[1]
+ num_candidates = slot_ids.shape[1]
+ block_d = min(1024, triton.next_power_of_2(head_dim))
+ grid = (num_tokens, num_heads)
+ _accumulate_fp8ds_global_slots_attention_chunk_kernel[grid](
+ q,
+ k_cache,
+ slot_ids,
+ lens,
+ max_score,
+ denom,
+ acc,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ slot_ids.stride(0),
+ slot_ids.stride(1),
+ max_score.stride(0),
+ max_score.stride(1),
+ acc.stride(0),
+ acc.stride(1),
+ acc.stride(2),
+ block_size,
+ token_data_size,
+ k_cache.stride(0),
+ token_fp8_dim,
+ token_scale_dim,
+ quant_block_size,
+ num_heads,
+ head_dim,
+ num_candidates,
+ candidate_offset,
+ scale,
+ BLOCK_D=block_d,
+ num_warps=8,
+ )
+
+
+@triton.jit
+def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel(
+ q_ptr,
+ k_cache_ptr,
+ slot_ids_ptr,
+ lens_ptr,
+ max_score_ptr,
+ denom_ptr,
+ acc_ptr,
+ stride_q_t: tl.constexpr,
+ stride_q_h: tl.constexpr,
+ stride_q_d: tl.constexpr,
+ stride_slot_t: tl.constexpr,
+ stride_slot_c: tl.constexpr,
+ stride_state_t: tl.constexpr,
+ stride_state_h: tl.constexpr,
+ stride_acc_t: tl.constexpr,
+ stride_acc_h: tl.constexpr,
+ stride_acc_d: tl.constexpr,
+ cache_block_size: tl.constexpr,
+ token_data_size: tl.constexpr,
+ block_stride: tl.constexpr,
+ fp8_dim: tl.constexpr,
+ scale_dim: tl.constexpr,
+ quant_block: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ num_candidates,
+ candidate_offset,
+ scale: tl.constexpr,
+ HEAD_BLOCK: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_block_idx = tl.program_id(1)
+ head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK)
+ dim_offsets = tl.arange(0, BLOCK_D)
+ head_mask = head_offsets < num_heads
+ dim_mask = dim_offsets < head_dim
+ matrix_mask = head_mask[:, None] & dim_mask[None, :]
+
+ q = tl.load(
+ q_ptr
+ + token_idx * stride_q_t
+ + head_offsets[:, None] * stride_q_h
+ + dim_offsets[None, :] * stride_q_d,
+ mask=matrix_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h
+ acc_offsets = (
+ token_idx * stride_acc_t
+ + head_offsets[:, None] * stride_acc_h
+ + dim_offsets[None, :] * stride_acc_d
+ )
+ running_max = tl.load(
+ max_score_ptr + state_offsets,
+ mask=head_mask,
+ other=-float("inf"),
+ )
+ running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0)
+ running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to(
+ tl.float32
+ )
+ valid_len = tl.load(lens_ptr + token_idx)
+
+ fp8_mask = dim_offsets < fp8_dim
+ rope_mask = (dim_offsets >= fp8_dim) & dim_mask
+ rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0)
+
+ for candidate_idx in range(0, num_candidates):
+ slot_id = tl.load(
+ slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c
+ )
+ is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0)
+
+ if is_valid:
+ block_idx = slot_id // cache_block_size
+ pos_in_block = slot_id % cache_block_size
+ cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride
+ token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
+ token_scale_ptr = (
+ cache_block_ptr
+ + cache_block_size * token_data_size
+ + pos_in_block * scale_dim
+ )
+
+ x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0)
+ x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
+ x_float = x_fp8.to(tl.float32)
+ scale_offsets = dim_offsets // quant_block
+ encoded_scale = tl.load(
+ token_scale_ptr + scale_offsets,
+ mask=fp8_mask,
+ other=127,
+ )
+ dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0)
+ x_dequant = x_float * dequant_scale
+
+ rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16))
+ rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to(
+ tl.float32
+ )
+ kv = tl.where(fp8_mask, x_dequant, rope)
+ kv = tl.where(dim_mask, kv, 0.0)
+
+ score = tl.sum(q * kv[None, :], axis=1) * scale
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = (
+ running_acc * previous_weight[:, None]
+ + kv[None, :] * candidate_weight[:, None]
+ )
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
+ tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask)
+ tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask)
+ tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask)
+
+
+def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead(
+ q: torch.Tensor,
+ k_cache: torch.Tensor,
+ slot_ids: torch.Tensor,
+ lens: torch.Tensor,
+ block_size: int,
+ scale: float,
+ max_score: torch.Tensor,
+ denom: torch.Tensor,
+ acc: torch.Tensor,
+ candidate_offset: int = 0,
+ head_block_size: int = 2,
+) -> None:
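+    """Accumulate a chunk of candidates addressed by global cache slot ids
+    into the running online-softmax state (max_score, denom, acc), processing
+    `head_block_size` heads per Triton program.
+    """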
+ if q.dim() == 4:
+ assert q.shape[1] == 1
+ q = q[:, 0]
+ if slot_ids.dim() == 3:
+ assert slot_ids.shape[1] == 1
+ slot_ids = slot_ids[:, 0]
+
+ assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}"
+ assert q.shape[-1] == 512
+ assert slot_ids.dim() == 2
+ assert slot_ids.shape[0] == q.shape[0]
+ assert lens.shape[0] == q.shape[0]
+ assert max_score.shape[0] == q.shape[0]
+ assert max_score.shape[1] <= q.shape[1]
+ assert denom.shape == max_score.shape
+ assert acc.shape == (*max_score.shape, q.shape[-1])
+ assert head_block_size in (1, 2, 4)
+ assert max_score.dtype == torch.float32
+ assert denom.dtype == torch.float32
+ assert acc.dtype == torch.float32
+ assert k_cache.dtype == torch.uint8
+ assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda
+ assert max_score.is_cuda and denom.is_cuda and acc.is_cuda
+
+ token_fp8_dim = 448
+ token_bf16_dim = 64
+ token_scale_dim = 8
+ quant_block_size = 64
+ token_data_size = token_fp8_dim + token_bf16_dim * 2
+
+ num_tokens, _, head_dim = q.shape
+ num_heads = max_score.shape[1]
+ num_candidates = slot_ids.shape[1]
+ block_d = min(1024, triton.next_power_of_2(head_dim))
+ grid = (num_tokens, triton.cdiv(num_heads, head_block_size))
+ _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel[grid](
+ q,
+ k_cache,
+ slot_ids,
+ lens,
+ max_score,
+ denom,
+ acc,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ slot_ids.stride(0),
+ slot_ids.stride(1),
+ max_score.stride(0),
+ max_score.stride(1),
+ acc.stride(0),
+ acc.stride(1),
+ acc.stride(2),
+ block_size,
+ token_data_size,
+ k_cache.stride(0),
+ token_fp8_dim,
+ token_scale_dim,
+ quant_block_size,
+ num_heads,
+ head_dim,
+ num_candidates,
+ candidate_offset,
+ scale,
+ HEAD_BLOCK=head_block_size,
+ BLOCK_D=block_d,
+ num_warps=8,
+ )
+
+
+@triton.jit
+def _accumulate_fp8ds_paged_attention_chunk_kernel(
+ q_ptr,
+ k_cache_ptr,
+ seq_lens_ptr,
+ gather_lens_ptr,
+ block_table_ptr,
+ max_score_ptr,
+ denom_ptr,
+ acc_ptr,
+ stride_q_t: tl.constexpr,
+ stride_q_h: tl.constexpr,
+ stride_q_d: tl.constexpr,
+ stride_block_table_t,
+ stride_state_t: tl.constexpr,
+ stride_state_h: tl.constexpr,
+ stride_acc_t: tl.constexpr,
+ stride_acc_h: tl.constexpr,
+ stride_acc_d: tl.constexpr,
+ cache_block_size: tl.constexpr,
+ token_data_size: tl.constexpr,
+ block_stride: tl.constexpr,
+ fp8_dim: tl.constexpr,
+ scale_dim: tl.constexpr,
+ quant_block: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ num_candidates,
+ candidate_offset,
+ scale: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_idx = tl.program_id(1)
+ offsets = tl.arange(0, BLOCK_D)
+ dim_mask = offsets < head_dim
+
+ q = tl.load(
+ q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ state_offset = token_idx * stride_state_t + head_idx * stride_state_h
+ acc_offset = (
+ token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d
+ )
+ running_max = tl.load(max_score_ptr + state_offset)
+ running_denom = tl.load(denom_ptr + state_offset)
+ running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32)
+
+ seq_len = tl.load(seq_lens_ptr + token_idx)
+ gather_len = tl.load(gather_lens_ptr + token_idx)
+ start_pos = seq_len - gather_len
+ fp8_mask = offsets < fp8_dim
+ rope_mask = (offsets >= fp8_dim) & dim_mask
+ rope_offsets = tl.maximum(offsets - fp8_dim, 0)
+
+ for candidate_idx in range(0, num_candidates):
+ gather_idx = candidate_offset + candidate_idx
+ is_valid = gather_idx < gather_len
+
+ if is_valid:
+ pos = start_pos + gather_idx
+ block_in_seq = pos // cache_block_size
+ pos_in_block = pos % cache_block_size
+ physical_block = tl.load(
+ block_table_ptr + token_idx * stride_block_table_t + block_in_seq
+ )
+ cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride
+ token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
+ token_scale_ptr = (
+ cache_block_ptr
+ + cache_block_size * token_data_size
+ + pos_in_block * scale_dim
+ )
+
+ x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0)
+ x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
+ x_float = x_fp8.to(tl.float32)
+ scale_offsets = offsets // quant_block
+ encoded_scale = tl.load(
+ token_scale_ptr + scale_offsets,
+ mask=fp8_mask,
+ other=127,
+ )
+ dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0)
+ x_dequant = x_float * dequant_scale
+
+ rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16))
+ rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to(
+ tl.float32
+ )
+ kv = tl.where(fp8_mask, x_dequant, rope)
+ kv = tl.where(dim_mask, kv, 0.0)
+
+ score = tl.sum(q * kv, axis=0) * scale
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = running_acc * previous_weight + kv * candidate_weight
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
+ tl.store(max_score_ptr + state_offset, running_max)
+ tl.store(denom_ptr + state_offset, running_denom)
+ tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask)
+
+
+def accumulate_fp8ds_paged_sparse_mla_attention_chunk(
+ q: torch.Tensor,
+ k_cache: torch.Tensor,
+ seq_lens: torch.Tensor,
+ gather_lens: torch.Tensor,
+ block_table: torch.Tensor,
+ block_size: int,
+ scale: float,
+ max_score: torch.Tensor,
+ denom: torch.Tensor,
+ acc: torch.Tensor,
+ candidate_offset: int,
+ num_candidates: int,
+) -> None:
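+    """Accumulate a chunk of candidates gathered from the tail of each
+    sequence's paged KV cache into the running online-softmax state, one
+    Triton program per (token, head).
+    """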
+ if q.dim() == 4:
+ assert q.shape[1] == 1
+ q = q[:, 0]
+
+ assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}"
+ assert q.shape[-1] == 512
+ assert seq_lens.shape[0] == q.shape[0]
+ assert gather_lens.shape[0] == q.shape[0]
+ assert block_table.shape[0] == q.shape[0]
+ assert max_score.shape[0] == q.shape[0]
+ assert max_score.shape[1] <= q.shape[1]
+ assert denom.shape == max_score.shape
+ assert acc.shape == (*max_score.shape, q.shape[-1])
+ assert max_score.dtype == torch.float32
+ assert denom.dtype == torch.float32
+ assert acc.dtype == torch.float32
+ assert k_cache.dtype == torch.uint8
+ assert q.is_cuda and k_cache.is_cuda
+ assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda
+ assert max_score.is_cuda and denom.is_cuda and acc.is_cuda
+
+ token_fp8_dim = 448
+ token_bf16_dim = 64
+ token_scale_dim = 8
+ quant_block_size = 64
+ token_data_size = token_fp8_dim + token_bf16_dim * 2
+
+ num_tokens, _, head_dim = q.shape
+ num_heads = max_score.shape[1]
+ block_d = min(1024, triton.next_power_of_2(head_dim))
+ grid = (num_tokens, num_heads)
+ _accumulate_fp8ds_paged_attention_chunk_kernel[grid](
+ q,
+ k_cache,
+ seq_lens,
+ gather_lens,
+ block_table,
+ max_score,
+ denom,
+ acc,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ block_table.stride(0),
+ max_score.stride(0),
+ max_score.stride(1),
+ acc.stride(0),
+ acc.stride(1),
+ acc.stride(2),
+ block_size,
+ token_data_size,
+ k_cache.stride(0),
+ token_fp8_dim,
+ token_scale_dim,
+ quant_block_size,
+ num_heads,
+ head_dim,
+ num_candidates,
+ candidate_offset,
+ scale,
+ BLOCK_D=block_d,
+ num_warps=8,
+ )
+
+
+@triton.jit
+def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel(
+ q_ptr,
+ k_cache_ptr,
+ seq_lens_ptr,
+ gather_lens_ptr,
+ block_table_ptr,
+ max_score_ptr,
+ denom_ptr,
+ acc_ptr,
+ stride_q_t: tl.constexpr,
+ stride_q_h: tl.constexpr,
+ stride_q_d: tl.constexpr,
+ stride_block_table_t,
+ stride_state_t: tl.constexpr,
+ stride_state_h: tl.constexpr,
+ stride_acc_t: tl.constexpr,
+ stride_acc_h: tl.constexpr,
+ stride_acc_d: tl.constexpr,
+ cache_block_size: tl.constexpr,
+ token_data_size: tl.constexpr,
+ block_stride: tl.constexpr,
+ fp8_dim: tl.constexpr,
+ scale_dim: tl.constexpr,
+ quant_block: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ num_candidates,
+ candidate_offset,
+ scale: tl.constexpr,
+ HEAD_BLOCK: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_block_idx = tl.program_id(1)
+ head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK)
+ dim_offsets = tl.arange(0, BLOCK_D)
+ head_mask = head_offsets < num_heads
+ dim_mask = dim_offsets < head_dim
+ matrix_mask = head_mask[:, None] & dim_mask[None, :]
+
+ q = tl.load(
+ q_ptr
+ + token_idx * stride_q_t
+ + head_offsets[:, None] * stride_q_h
+ + dim_offsets[None, :] * stride_q_d,
+ mask=matrix_mask,
+ other=0.0,
+ ).to(tl.float32)
+
+ state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h
+ acc_offsets = (
+ token_idx * stride_acc_t
+ + head_offsets[:, None] * stride_acc_h
+ + dim_offsets[None, :] * stride_acc_d
+ )
+ running_max = tl.load(
+ max_score_ptr + state_offsets,
+ mask=head_mask,
+ other=-float("inf"),
+ )
+ running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0)
+ running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to(
+ tl.float32
+ )
+
+ seq_len = tl.load(seq_lens_ptr + token_idx)
+ gather_len = tl.load(gather_lens_ptr + token_idx)
+ start_pos = seq_len - gather_len
+ fp8_mask = dim_offsets < fp8_dim
+ rope_mask = (dim_offsets >= fp8_dim) & dim_mask
+ rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0)
+
+ for candidate_idx in range(0, num_candidates):
+ gather_idx = candidate_offset + candidate_idx
+ is_valid = gather_idx < gather_len
+
+ if is_valid:
+ pos = start_pos + gather_idx
+ block_in_seq = pos // cache_block_size
+ pos_in_block = pos % cache_block_size
+ physical_block = tl.load(
+ block_table_ptr + token_idx * stride_block_table_t + block_in_seq
+ )
+ cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride
+ token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
+ token_scale_ptr = (
+ cache_block_ptr
+ + cache_block_size * token_data_size
+ + pos_in_block * scale_dim
+ )
+
+ x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0)
+ x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
+ x_float = x_fp8.to(tl.float32)
+ scale_offsets = dim_offsets // quant_block
+ encoded_scale = tl.load(
+ token_scale_ptr + scale_offsets,
+ mask=fp8_mask,
+ other=127,
+ )
+ dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0)
+ x_dequant = x_float * dequant_scale
+
+ rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16))
+ rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to(
+ tl.float32
+ )
+ kv = tl.where(fp8_mask, x_dequant, rope)
+ kv = tl.where(dim_mask, kv, 0.0)
+
+ score = tl.sum(q * kv[None, :], axis=1) * scale
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = (
+ running_acc * previous_weight[:, None]
+ + kv[None, :] * candidate_weight[:, None]
+ )
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
+ tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask)
+ tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask)
+ tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask)
+
+
+def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead(
+ q: torch.Tensor,
+ k_cache: torch.Tensor,
+ seq_lens: torch.Tensor,
+ gather_lens: torch.Tensor,
+ block_table: torch.Tensor,
+ block_size: int,
+ scale: float,
+ max_score: torch.Tensor,
+ denom: torch.Tensor,
+ acc: torch.Tensor,
+ candidate_offset: int,
+ num_candidates: int,
+ head_block_size: int = 2,
+) -> None:
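+    """Multi-head-per-program variant of
+    accumulate_fp8ds_paged_sparse_mla_attention_chunk; candidate addressing is
+    identical.
+    """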
+ if q.dim() == 4:
+ assert q.shape[1] == 1
+ q = q[:, 0]
+
+ assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}"
+ assert q.shape[-1] == 512
+ assert seq_lens.shape[0] == q.shape[0]
+ assert gather_lens.shape[0] == q.shape[0]
+ assert block_table.shape[0] == q.shape[0]
+ assert max_score.shape[0] == q.shape[0]
+ assert max_score.shape[1] <= q.shape[1]
+ assert denom.shape == max_score.shape
+ assert acc.shape == (*max_score.shape, q.shape[-1])
+ assert head_block_size in (1, 2, 4)
+ assert max_score.dtype == torch.float32
+ assert denom.dtype == torch.float32
+ assert acc.dtype == torch.float32
+ assert k_cache.dtype == torch.uint8
+ assert q.is_cuda and k_cache.is_cuda
+ assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda
+ assert max_score.is_cuda and denom.is_cuda and acc.is_cuda
+
+ token_fp8_dim = 448
+ token_bf16_dim = 64
+ token_scale_dim = 8
+ quant_block_size = 64
+ token_data_size = token_fp8_dim + token_bf16_dim * 2
+
+ num_tokens, _, head_dim = q.shape
+ num_heads = max_score.shape[1]
+ block_d = min(1024, triton.next_power_of_2(head_dim))
+ grid = (num_tokens, triton.cdiv(num_heads, head_block_size))
+ _accumulate_fp8ds_paged_attention_chunk_multihead_kernel[grid](
+ q,
+ k_cache,
+ seq_lens,
+ gather_lens,
+ block_table,
+ max_score,
+ denom,
+ acc,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ block_table.stride(0),
+ max_score.stride(0),
+ max_score.stride(1),
+ acc.stride(0),
+ acc.stride(1),
+ acc.stride(2),
+ block_size,
+ token_data_size,
+ k_cache.stride(0),
+ token_fp8_dim,
+ token_scale_dim,
+ quant_block_size,
+ num_heads,
+ head_dim,
+ num_candidates,
+ candidate_offset,
+ scale,
+ HEAD_BLOCK=head_block_size,
+ BLOCK_D=block_d,
+ num_warps=8,
+ )
+
+
+@triton.jit
+def _fp8ds_paged_attention_with_sink_multihead_kernel(
+ q_ptr,
+ k_cache_ptr,
+ seq_lens_ptr,
+ gather_lens_ptr,
+ block_table_ptr,
+ sink_ptr,
+ output_ptr,
+ stride_q_t: tl.constexpr,
+ stride_q_h: tl.constexpr,
+ stride_q_d: tl.constexpr,
+ stride_block_table_t,
+ stride_output_t: tl.constexpr,
+ stride_output_h: tl.constexpr,
+ stride_output_d: tl.constexpr,
+ cache_block_size: tl.constexpr,
+ token_data_size: tl.constexpr,
+ block_stride: tl.constexpr,
+ fp8_dim: tl.constexpr,
+ scale_dim: tl.constexpr,
+ quant_block: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ candidate_offset: tl.constexpr,
+ num_candidates: tl.constexpr,
+ scale: tl.constexpr,
+ HEAD_BLOCK: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_block_idx = tl.program_id(1)
+ head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK)
+ dim_offsets = tl.arange(0, BLOCK_D)
+ head_mask = head_offsets < num_heads
+ dim_mask = dim_offsets < head_dim
+ matrix_mask = head_mask[:, None] & dim_mask[None, :]
+
+ q = tl.load(
+ q_ptr
+ + token_idx * stride_q_t
+ + head_offsets[:, None] * stride_q_h
+ + dim_offsets[None, :] * stride_q_d,
+ mask=matrix_mask,
+ other=0.0,
+ ).to(tl.float32)
+ running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32)
+ running_denom = tl.zeros((HEAD_BLOCK,), tl.float32)
+ running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32)
+
+ seq_len = tl.load(seq_lens_ptr + token_idx)
+ gather_len = tl.load(gather_lens_ptr + token_idx)
+ start_pos = seq_len - gather_len
+ fp8_mask = dim_offsets < fp8_dim
+ rope_mask = (dim_offsets >= fp8_dim) & dim_mask
+ rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0)
+
+ for candidate_idx in range(0, num_candidates):
+ gather_idx = candidate_offset + candidate_idx
+ is_valid = gather_idx < gather_len
+ if is_valid:
+ pos = start_pos + gather_idx
+ block_in_seq = pos // cache_block_size
+ pos_in_block = pos % cache_block_size
+ physical_block = tl.load(
+ block_table_ptr + token_idx * stride_block_table_t + block_in_seq
+ )
+ cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride
+ token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
+ token_scale_ptr = (
+ cache_block_ptr
+ + cache_block_size * token_data_size
+ + pos_in_block * scale_dim
+ )
+
+ x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0)
+ x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
+ x_float = x_fp8.to(tl.float32)
+ scale_offsets = dim_offsets // quant_block
+ encoded_scale = tl.load(
+ token_scale_ptr + scale_offsets,
+ mask=fp8_mask,
+ other=127,
+ )
+ dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0)
+ x_dequant = x_float * dequant_scale
+
+ rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16))
+ rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to(
+ tl.float32
+ )
+ kv = tl.where(fp8_mask, x_dequant, rope)
+ kv = tl.where(dim_mask, kv, 0.0)
+
+ score = tl.sum(q * kv[None, :], axis=1) * scale
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = (
+ running_acc * previous_weight[:, None]
+ + kv[None, :] * candidate_weight[:, None]
+ )
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
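+    # Fold the per-head attention sink into the softmax: merge the running
+    # state and the sink logit under a shared max; heads with no valid tokens
+    # and a -inf sink yield a zero output rather than NaN.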
+ sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf"))
+ has_tokens = running_denom > 0.0
+ has_sink = sink > -float("inf")
+ valid_max = tl.where(has_tokens, running_max, -float("inf"))
+ valid_sink = tl.where(has_sink, sink, -float("inf"))
+ merge_max = tl.maximum(valid_max, valid_sink)
+ has_any = has_tokens | has_sink
+ safe_merge_max = tl.where(has_any, merge_max, 0.0)
+ safe_running_max = tl.where(has_tokens, running_max, safe_merge_max)
+ safe_sink = tl.where(has_sink, sink, safe_merge_max)
+ subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0)
+ sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0)
+ total_weight = running_denom * subset_scale + sink_weight
+ inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0)
+ final = running_acc * subset_scale[:, None] * inv_total[:, None]
+
+ tl.store(
+ output_ptr
+ + token_idx * stride_output_t
+ + head_offsets[:, None] * stride_output_h
+ + dim_offsets[None, :] * stride_output_d,
+ final,
+ mask=matrix_mask,
+ )
+
+
+def fp8ds_paged_sparse_mla_attention_with_sink_multihead(
+ q: torch.Tensor,
+ k_cache: torch.Tensor,
+ seq_lens: torch.Tensor,
+ gather_lens: torch.Tensor,
+ block_table: torch.Tensor,
+ block_size: int,
+ candidate_offset: int,
+ num_candidates: int,
+ scale: float,
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+ head_block_size: int = 1,
+ num_heads: int | None = None,
+) -> None:
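+    """Single-pass paged sparse MLA attention over the gathered trailing
+    tokens, with the attention sink folded into the softmax normalization.
+    """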
+ if q.dim() == 4:
+ assert q.shape[1] == 1
+ q = q[:, 0]
+
+ assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}"
+ assert q.shape[-1] == 512
+ assert seq_lens.shape[0] == q.shape[0]
+ assert gather_lens.shape[0] == q.shape[0]
+ assert block_table.shape[0] == q.shape[0]
+ assert output.shape[0] == q.shape[0]
+ assert output.shape[2] == q.shape[-1]
+ assert head_block_size in (1, 2, 4)
+ assert k_cache.dtype == torch.uint8
+ assert q.is_cuda and k_cache.is_cuda
+ assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda
+ assert attn_sink.is_cuda and output.is_cuda
+
+ token_fp8_dim = 448
+ token_bf16_dim = 64
+ token_scale_dim = 8
+ quant_block_size = 64
+ token_data_size = token_fp8_dim + token_bf16_dim * 2
+
+ num_tokens, _, head_dim = q.shape
+ active_heads = num_heads if num_heads is not None else output.shape[1]
+ assert active_heads <= q.shape[1]
+ assert active_heads <= output.shape[1]
+ assert active_heads <= attn_sink.shape[0]
+ block_d = min(1024, triton.next_power_of_2(head_dim))
+ grid = (num_tokens, triton.cdiv(active_heads, head_block_size))
+ _fp8ds_paged_attention_with_sink_multihead_kernel[grid](
+ q,
+ k_cache,
+ seq_lens,
+ gather_lens,
+ block_table,
+ attn_sink,
+ output,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ block_table.stride(0),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ block_size,
+ token_data_size,
+ k_cache.stride(0),
+ token_fp8_dim,
+ token_scale_dim,
+ quant_block_size,
+ active_heads,
+ head_dim,
+ candidate_offset,
+ num_candidates,
+ scale,
+ HEAD_BLOCK=head_block_size,
+ BLOCK_D=block_d,
+ num_warps=8,
+ )
+
+
+@triton.jit
+def _fp8ds_global_paged_attention_with_sink_multihead_kernel(
+ q_ptr,
+ compressed_k_cache_ptr,
+ slot_ids_ptr,
+ topk_lens_ptr,
+ swa_k_cache_ptr,
+ seq_lens_ptr,
+ gather_lens_ptr,
+ block_table_ptr,
+ sink_ptr,
+ output_ptr,
+ stride_q_t: tl.constexpr,
+ stride_q_h: tl.constexpr,
+ stride_q_d: tl.constexpr,
+ stride_slot_t: tl.constexpr,
+ stride_slot_c: tl.constexpr,
+ stride_block_table_t,
+ stride_output_t: tl.constexpr,
+ stride_output_h: tl.constexpr,
+ stride_output_d: tl.constexpr,
+ compressed_cache_block_size: tl.constexpr,
+ compressed_block_stride: tl.constexpr,
+ swa_cache_block_size: tl.constexpr,
+ swa_block_stride: tl.constexpr,
+ token_data_size: tl.constexpr,
+ fp8_dim: tl.constexpr,
+ scale_dim: tl.constexpr,
+ quant_block: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ num_compressed_candidates: tl.constexpr,
+ num_swa_candidates: tl.constexpr,
+ scale: tl.constexpr,
+ HEAD_BLOCK: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ head_block_idx = tl.program_id(1)
+ head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK)
+ dim_offsets = tl.arange(0, BLOCK_D)
+ head_mask = head_offsets < num_heads
+ dim_mask = dim_offsets < head_dim
+ matrix_mask = head_mask[:, None] & dim_mask[None, :]
+
+ q = tl.load(
+ q_ptr
+ + token_idx * stride_q_t
+ + head_offsets[:, None] * stride_q_h
+ + dim_offsets[None, :] * stride_q_d,
+ mask=matrix_mask,
+ other=0.0,
+ ).to(tl.float32)
+ running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32)
+ running_denom = tl.zeros((HEAD_BLOCK,), tl.float32)
+ running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32)
+
+ fp8_mask = dim_offsets < fp8_dim
+ rope_mask = (dim_offsets >= fp8_dim) & dim_mask
+ rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0)
+ topk_len = tl.load(topk_lens_ptr + token_idx)
+
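+    # First pass: compressed top-k candidates addressed by global cache slot
+    # ids; negative slot ids and indices beyond topk_len are skipped.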
+ for candidate_idx in range(0, num_compressed_candidates):
+ slot_id = tl.load(
+ slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c
+ )
+ is_valid = (candidate_idx < topk_len) & (slot_id >= 0)
+ if is_valid:
+ block_idx = slot_id // compressed_cache_block_size
+ pos_in_block = slot_id % compressed_cache_block_size
+ cache_block_ptr = (
+ compressed_k_cache_ptr
+ + block_idx.to(tl.int64) * compressed_block_stride
+ )
+ token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
+ token_scale_ptr = (
+ cache_block_ptr
+ + compressed_cache_block_size * token_data_size
+ + pos_in_block * scale_dim
+ )
+
+ x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0)
+ x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
+ x_float = x_fp8.to(tl.float32)
+ scale_offsets = dim_offsets // quant_block
+ encoded_scale = tl.load(
+ token_scale_ptr + scale_offsets,
+ mask=fp8_mask,
+ other=127,
+ )
+ dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0)
+ x_dequant = x_float * dequant_scale
+ rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16))
+ rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to(
+ tl.float32
+ )
+ kv = tl.where(fp8_mask, x_dequant, rope)
+ kv = tl.where(dim_mask, kv, 0.0)
+
+ score = tl.sum(q * kv[None, :], axis=1) * scale
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = (
+ running_acc * previous_weight[:, None]
+ + kv[None, :] * candidate_weight[:, None]
+ )
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
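+    # Second pass: the trailing `gather_len` sliding-window tokens, addressed
+    # through the per-token block table of the SWA cache.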
+ seq_len = tl.load(seq_lens_ptr + token_idx)
+ gather_len = tl.load(gather_lens_ptr + token_idx)
+ start_pos = seq_len - gather_len
+ for candidate_idx in range(0, num_swa_candidates):
+ is_valid = candidate_idx < gather_len
+ if is_valid:
+ pos = start_pos + candidate_idx
+ block_in_seq = pos // swa_cache_block_size
+ pos_in_block = pos % swa_cache_block_size
+ physical_block = tl.load(
+ block_table_ptr + token_idx * stride_block_table_t + block_in_seq
+ )
+ cache_block_ptr = (
+ swa_k_cache_ptr + physical_block.to(tl.int64) * swa_block_stride
+ )
+ token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
+ token_scale_ptr = (
+ cache_block_ptr
+ + swa_cache_block_size * token_data_size
+ + pos_in_block * scale_dim
+ )
+
+ x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0)
+ x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
+ x_float = x_fp8.to(tl.float32)
+ scale_offsets = dim_offsets // quant_block
+ encoded_scale = tl.load(
+ token_scale_ptr + scale_offsets,
+ mask=fp8_mask,
+ other=127,
+ )
+ dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0)
+ x_dequant = x_float * dequant_scale
+ rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16))
+ rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to(
+ tl.float32
+ )
+ kv = tl.where(fp8_mask, x_dequant, rope)
+ kv = tl.where(dim_mask, kv, 0.0)
+
+ score = tl.sum(q * kv[None, :], axis=1) * scale
+ next_max = tl.maximum(running_max, score)
+ previous_weight = tl.exp(running_max - next_max)
+ candidate_weight = tl.exp(score - next_max)
+ running_acc = (
+ running_acc * previous_weight[:, None]
+ + kv[None, :] * candidate_weight[:, None]
+ )
+ running_denom = running_denom * previous_weight + candidate_weight
+ running_max = next_max
+
+ sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf"))
+ has_tokens = running_denom > 0.0
+ has_sink = sink > -float("inf")
+ valid_max = tl.where(has_tokens, running_max, -float("inf"))
+ valid_sink = tl.where(has_sink, sink, -float("inf"))
+ merge_max = tl.maximum(valid_max, valid_sink)
+ has_any = has_tokens | has_sink
+ safe_merge_max = tl.where(has_any, merge_max, 0.0)
+ safe_running_max = tl.where(has_tokens, running_max, safe_merge_max)
+ safe_sink = tl.where(has_sink, sink, safe_merge_max)
+ subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0)
+ sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0)
+ total_weight = running_denom * subset_scale + sink_weight
+ inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0)
+ final = running_acc * subset_scale[:, None] * inv_total[:, None]
+
+ tl.store(
+ output_ptr
+ + token_idx * stride_output_t
+ + head_offsets[:, None] * stride_output_h
+ + dim_offsets[None, :] * stride_output_d,
+ final,
+ mask=matrix_mask,
+ )
+
+
+def fp8ds_global_paged_sparse_mla_attention_with_sink_multihead(
+ q: torch.Tensor,
+ compressed_k_cache: torch.Tensor,
+ slot_ids: torch.Tensor,
+ topk_lens: torch.Tensor,
+ compressed_block_size: int,
+ swa_k_cache: torch.Tensor,
+ seq_lens: torch.Tensor,
+ gather_lens: torch.Tensor,
+ block_table: torch.Tensor,
+ swa_block_size: int,
+ num_compressed_candidates: int,
+ num_swa_candidates: int,
+ scale: float,
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+ head_block_size: int = 1,
+ num_heads: int | None = None,
+) -> None:
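+    """Fused sparse MLA attention over the compressed top-k candidates
+    (global slot ids) plus the trailing SWA window (paged block table), merged
+    with the attention sink in a single kernel launch.
+    """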
+ if q.dim() == 4:
+ assert q.shape[1] == 1
+ q = q[:, 0]
+ if slot_ids.dim() == 3:
+ assert slot_ids.shape[1] == 1
+ slot_ids = slot_ids[:, 0]
+
+ assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}"
+ assert q.shape[-1] == 512
+ assert slot_ids.dim() == 2
+ assert slot_ids.shape[0] == q.shape[0]
+ assert topk_lens.shape[0] == q.shape[0]
+ assert seq_lens.shape[0] == q.shape[0]
+ assert gather_lens.shape[0] == q.shape[0]
+ assert block_table.shape[0] == q.shape[0]
+ assert output.shape[0] == q.shape[0]
+ assert output.shape[2] == q.shape[-1]
+ assert head_block_size in (1, 2, 4)
+ assert compressed_k_cache.dtype == torch.uint8
+ assert swa_k_cache.dtype == torch.uint8
+ assert q.is_cuda and compressed_k_cache.is_cuda and swa_k_cache.is_cuda
+ assert slot_ids.is_cuda and topk_lens.is_cuda
+ assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda
+ assert attn_sink.is_cuda and output.is_cuda
+
+ token_fp8_dim = 448
+ token_bf16_dim = 64
+ token_scale_dim = 8
+ quant_block_size = 64
+ token_data_size = token_fp8_dim + token_bf16_dim * 2
+
+ num_tokens, _, head_dim = q.shape
+ active_heads = num_heads if num_heads is not None else output.shape[1]
+ assert active_heads <= q.shape[1]
+ assert active_heads <= output.shape[1]
+ assert active_heads <= attn_sink.shape[0]
+ block_d = min(1024, triton.next_power_of_2(head_dim))
+ grid = (num_tokens, triton.cdiv(active_heads, head_block_size))
+ _fp8ds_global_paged_attention_with_sink_multihead_kernel[grid](
+ q,
+ compressed_k_cache,
+ slot_ids,
+ topk_lens,
+ swa_k_cache,
+ seq_lens,
+ gather_lens,
+ block_table,
+ attn_sink,
+ output,
+ q.stride(0),
+ q.stride(1),
+ q.stride(2),
+ slot_ids.stride(0),
+ slot_ids.stride(1),
+ block_table.stride(0),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ compressed_block_size,
+ compressed_k_cache.stride(0),
+ swa_block_size,
+ swa_k_cache.stride(0),
+ token_data_size,
+ token_fp8_dim,
+ token_scale_dim,
+ quant_block_size,
+ active_heads,
+ head_dim,
+ num_compressed_candidates,
+ num_swa_candidates,
+ scale,
+ HEAD_BLOCK=head_block_size,
+ BLOCK_D=block_d,
+ num_warps=8,
+ )
+
+
+@triton.jit
+def _finish_attention_state_kernel(
+ max_score_ptr,
+ denom_ptr,
+ acc_ptr,
+ output_ptr,
+ lse_ptr,
+ stride_state_t: tl.constexpr,
+ stride_state_h: tl.constexpr,
+ stride_acc_t: tl.constexpr,
+ stride_acc_h: tl.constexpr,
+ stride_acc_d: tl.constexpr,
+ stride_output_t: tl.constexpr,
+ stride_output_h: tl.constexpr,
+ stride_output_d: tl.constexpr,
+ stride_lse_t: tl.constexpr,
+ stride_lse_h: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_head = tl.program_id(0)
+ block_d = tl.program_id(1)
+ token_idx = token_head // num_heads
+ head_idx = token_head - token_idx * num_heads
+ offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D)
+ dim_mask = offsets < head_dim
+
+ state_offset = token_idx * stride_state_t + head_idx * stride_state_h
+ running_max = tl.load(max_score_ptr + state_offset)
+ running_denom = tl.load(denom_ptr + state_offset)
+ is_valid = running_denom > 0.0
+ inv_denom = tl.where(is_valid, 1.0 / running_denom, 0.0)
+ subset_lse = tl.where(
+ is_valid,
+ running_max + tl.log(running_denom),
+ -float("inf"),
+ )
+
+ acc = tl.load(
+ acc_ptr
+ + token_idx * stride_acc_t
+ + head_idx * stride_acc_h
+ + offsets * stride_acc_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+ subset_output = acc * inv_denom
+ tl.store(
+ output_ptr
+ + token_idx * stride_output_t
+ + head_idx * stride_output_h
+ + offsets * stride_output_d,
+ subset_output,
+ mask=dim_mask,
+ )
+ if block_d == 0:
+ tl.store(
+ lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h,
+ subset_lse,
+ )
+
+
+def finish_gathered_sparse_mla_attention(
+ max_score: torch.Tensor,
+ denom: torch.Tensor,
+ acc: torch.Tensor,
+ output: torch.Tensor,
+ lse: torch.Tensor,
+) -> None:
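+    """Normalize the accumulated online-softmax state into `output` and write
+    the per-(token, head) log-sum-exp into `lse`; rows with an empty
+    denominator produce a zero output and an LSE of -inf.
+    """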
+ assert max_score.shape == denom.shape
+ assert acc.shape[:2] == max_score.shape
+ assert output.shape == acc.shape
+ assert lse.shape == max_score.shape
+ assert max_score.dtype == torch.float32
+ assert denom.dtype == torch.float32
+ assert acc.dtype == torch.float32
+ assert output.dtype == torch.float32
+ assert lse.dtype == torch.float32
+ assert max_score.is_cuda and denom.is_cuda and acc.is_cuda
+ assert output.is_cuda and lse.is_cuda
+
+ num_tokens, num_heads, head_dim = acc.shape
+ block_d = min(128, triton.next_power_of_2(head_dim))
+ grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d))
+ _finish_attention_state_kernel[grid](
+ max_score,
+ denom,
+ acc,
+ output,
+ lse,
+ max_score.stride(0),
+ max_score.stride(1),
+ acc.stride(0),
+ acc.stride(1),
+ acc.stride(2),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ lse.stride(0),
+ lse.stride(1),
+ num_heads,
+ head_dim,
+ BLOCK_D=block_d,
+ num_warps=4,
+ )
+
+
+@triton.jit
+def _finish_attention_state_with_sink_kernel(
+ max_score_ptr,
+ denom_ptr,
+ acc_ptr,
+ sink_ptr,
+ output_ptr,
+ stride_state_t: tl.constexpr,
+ stride_state_h: tl.constexpr,
+ stride_acc_t: tl.constexpr,
+ stride_acc_h: tl.constexpr,
+ stride_acc_d: tl.constexpr,
+ stride_output_t: tl.constexpr,
+ stride_output_h: tl.constexpr,
+ stride_output_d: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_head = tl.program_id(0)
+ block_d = tl.program_id(1)
+ token_idx = token_head // num_heads
+ head_idx = token_head - token_idx * num_heads
+ offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D)
+ dim_mask = offsets < head_dim
+
+ state_offset = token_idx * stride_state_t + head_idx * stride_state_h
+ running_max = tl.load(max_score_ptr + state_offset)
+ running_denom = tl.load(denom_ptr + state_offset)
+ sink = tl.load(sink_ptr + head_idx)
+ has_tokens = running_denom > 0.0
+ has_sink = sink > -float("inf")
+ valid_max = tl.where(has_tokens, running_max, -float("inf"))
+ valid_sink = tl.where(has_sink, sink, -float("inf"))
+ merge_max = tl.maximum(valid_max, valid_sink)
+ has_any = has_tokens | has_sink
+ safe_merge_max = tl.where(has_any, merge_max, 0.0)
+ safe_running_max = tl.where(has_tokens, running_max, safe_merge_max)
+ safe_sink = tl.where(has_sink, sink, safe_merge_max)
+ subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0)
+ subset_weight = running_denom * subset_scale
+ sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0)
+ total_weight = subset_weight + sink_weight
+ inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0)
+
+ acc_values = tl.load(
+ acc_ptr
+ + token_idx * stride_acc_t
+ + head_idx * stride_acc_h
+ + offsets * stride_acc_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+ acc_values = tl.where(has_tokens, acc_values, 0.0)
+ output = acc_values * subset_scale * inv_total
+ tl.store(
+ output_ptr
+ + token_idx * stride_output_t
+ + head_idx * stride_output_h
+ + offsets * stride_output_d,
+ output,
+ mask=dim_mask,
+ )
+
+
+@triton.jit
+def _finish_two_attention_states_with_sink_kernel(
+ max0_ptr,
+ denom0_ptr,
+ acc0_ptr,
+ max1_ptr,
+ denom1_ptr,
+ acc1_ptr,
+ sink_ptr,
+ output_ptr,
+ stride_state0_t: tl.constexpr,
+ stride_state0_h: tl.constexpr,
+ stride_acc0_t: tl.constexpr,
+ stride_acc0_h: tl.constexpr,
+ stride_acc0_d: tl.constexpr,
+ stride_state1_t: tl.constexpr,
+ stride_state1_h: tl.constexpr,
+ stride_acc1_t: tl.constexpr,
+ stride_acc1_h: tl.constexpr,
+ stride_acc1_d: tl.constexpr,
+ stride_output_t: tl.constexpr,
+ stride_output_h: tl.constexpr,
+ stride_output_d: tl.constexpr,
+ num_heads: tl.constexpr,
+ head_dim: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_head = tl.program_id(0)
+ block_d = tl.program_id(1)
+ token_idx = token_head // num_heads
+ head_idx = token_head - token_idx * num_heads
+ offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D)
+ dim_mask = offsets < head_dim
+
+ state0_offset = token_idx * stride_state0_t + head_idx * stride_state0_h
+ state1_offset = token_idx * stride_state1_t + head_idx * stride_state1_h
+ max0 = tl.load(max0_ptr + state0_offset)
+ denom0 = tl.load(denom0_ptr + state0_offset)
+ max1 = tl.load(max1_ptr + state1_offset)
+ denom1 = tl.load(denom1_ptr + state1_offset)
+ sink = tl.load(sink_ptr + head_idx)
+
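+    # Merge both partial states and the sink under a shared max so that empty
+    # states (denom == 0) and a -inf sink contribute nothing to the result.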
+ has0 = denom0 > 0.0
+ has1 = denom1 > 0.0
+ has_sink = sink > -float("inf")
+ valid_max0 = tl.where(has0, max0, -float("inf"))
+ valid_max1 = tl.where(has1, max1, -float("inf"))
+ valid_sink = tl.where(has_sink, sink, -float("inf"))
+ merge_max = tl.maximum(tl.maximum(valid_max0, valid_max1), valid_sink)
+ has_any = has0 | has1 | has_sink
+ safe_merge_max = tl.where(has_any, merge_max, 0.0)
+ safe_max0 = tl.where(has0, max0, safe_merge_max)
+ safe_max1 = tl.where(has1, max1, safe_merge_max)
+ safe_sink = tl.where(has_sink, sink, safe_merge_max)
+ scale0 = tl.where(has0, tl.exp(safe_max0 - safe_merge_max), 0.0)
+ scale1 = tl.where(has1, tl.exp(safe_max1 - safe_merge_max), 0.0)
+ sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0)
+ total_weight = denom0 * scale0 + denom1 * scale1 + sink_weight
+ inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0)
+
+ acc0 = tl.load(
+ acc0_ptr
+ + token_idx * stride_acc0_t
+ + head_idx * stride_acc0_h
+ + offsets * stride_acc0_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+ acc1 = tl.load(
+ acc1_ptr
+ + token_idx * stride_acc1_t
+ + head_idx * stride_acc1_h
+ + offsets * stride_acc1_d,
+ mask=dim_mask,
+ other=0.0,
+ ).to(tl.float32)
+ acc0 = tl.where(has0, acc0, 0.0)
+ acc1 = tl.where(has1, acc1, 0.0)
+ output = (acc0 * scale0 + acc1 * scale1) * inv_total
+ tl.store(
+ output_ptr
+ + token_idx * stride_output_t
+ + head_idx * stride_output_h
+ + offsets * stride_output_d,
+ output,
+ mask=dim_mask,
+ )
+
+
+def finish_two_sparse_mla_attention_states_with_sink(
+ max_score0: torch.Tensor,
+ denom0: torch.Tensor,
+ acc0: torch.Tensor,
+ max_score1: torch.Tensor,
+ denom1: torch.Tensor,
+ acc1: torch.Tensor,
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+) -> None:
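+    """Merge two partial online-softmax states (e.g. two separate candidate
+    passes) with the attention sink under a shared maximum and write the
+    normalized output.
+    """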
+ assert max_score0.shape == denom0.shape
+ assert max_score1.shape == denom1.shape
+ assert max_score0.shape == max_score1.shape
+ assert acc0.shape == acc1.shape
+ assert acc0.shape[:2] == max_score0.shape
+ assert output.shape[0] == acc0.shape[0]
+ assert output.shape[1] >= acc0.shape[1]
+ assert output.shape[2] == acc0.shape[2]
+ assert attn_sink.shape[0] >= acc0.shape[1]
+ assert max_score0.dtype == torch.float32
+ assert denom0.dtype == torch.float32
+ assert acc0.dtype == torch.float32
+ assert max_score1.dtype == torch.float32
+ assert denom1.dtype == torch.float32
+ assert acc1.dtype == torch.float32
+ assert max_score0.is_cuda and denom0.is_cuda and acc0.is_cuda
+ assert max_score1.is_cuda and denom1.is_cuda and acc1.is_cuda
+ assert attn_sink.is_cuda and output.is_cuda
+
+ num_tokens, num_heads, head_dim = acc0.shape
+ block_d = min(128, triton.next_power_of_2(head_dim))
+ grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d))
+ _finish_two_attention_states_with_sink_kernel[grid](
+ max_score0,
+ denom0,
+ acc0,
+ max_score1,
+ denom1,
+ acc1,
+ attn_sink,
+ output,
+ max_score0.stride(0),
+ max_score0.stride(1),
+ acc0.stride(0),
+ acc0.stride(1),
+ acc0.stride(2),
+ max_score1.stride(0),
+ max_score1.stride(1),
+ acc1.stride(0),
+ acc1.stride(1),
+ acc1.stride(2),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ num_heads,
+ head_dim,
+ BLOCK_D=block_d,
+ num_warps=4,
+ )
+
+
+def finish_sparse_mla_attention_with_sink(
+ max_score: torch.Tensor,
+ denom: torch.Tensor,
+ acc: torch.Tensor,
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+) -> None:
+ assert max_score.shape == denom.shape
+ assert acc.shape[:2] == max_score.shape
+ assert output.shape[0] == acc.shape[0]
+ assert output.shape[1] >= acc.shape[1]
+ assert output.shape[2] == acc.shape[2]
+ assert attn_sink.shape[0] >= acc.shape[1]
+ assert max_score.dtype == torch.float32
+ assert denom.dtype == torch.float32
+ assert acc.dtype == torch.float32
+ assert max_score.is_cuda and denom.is_cuda and acc.is_cuda
+ assert attn_sink.is_cuda and output.is_cuda
+
+ num_tokens, num_heads, head_dim = acc.shape
+ block_d = min(128, triton.next_power_of_2(head_dim))
+ grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d))
+ _finish_attention_state_with_sink_kernel[grid](
+ max_score,
+ denom,
+ acc,
+ attn_sink,
+ output,
+ max_score.stride(0),
+ max_score.stride(1),
+ acc.stride(0),
+ acc.stride(1),
+ acc.stride(2),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ num_heads,
+ head_dim,
+ BLOCK_D=block_d,
+ num_warps=4,
+ )
diff --git a/vllm/v1/attention/backends/mla/sparse_mla_reference.py b/vllm/v1/attention/backends/mla/sparse_mla_reference.py
new file mode 100644
index 000000000000..203b64188202
--- /dev/null
+++ b/vllm/v1/attention/backends/mla/sparse_mla_reference.py
@@ -0,0 +1,242 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Reference sparse MLA attention helpers.
+
+The helpers in this module intentionally use plain PyTorch tensor operations
+rather than hand-written kernels. They define the correctness-first contract
+for portable sparse MLA fallbacks and tests; optimized Triton/CUDA kernels
+should preserve these semantics.
+"""
+
+import torch
+
+
+def new_reference_attention_state(
+ q: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ if q.dim() == 4:
+ q_bhd = q[:, 0, :, :].float()
+ else:
+ assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}"
+ q_bhd = q.float()
+
+ num_tokens = q_bhd.shape[0]
+ num_heads = q_bhd.shape[1]
+ head_dim = q_bhd.shape[2]
+ max_score = torch.full(
+ (num_tokens, num_heads),
+ float("-inf"),
+ dtype=torch.float32,
+ device=q.device,
+ )
+ denom = torch.zeros_like(max_score)
+ acc = torch.zeros(
+ (num_tokens, num_heads, head_dim),
+ dtype=torch.float32,
+ device=q.device,
+ )
+ return q_bhd, max_score, denom, acc
+
+
+def accumulate_reference_attention_chunk(
+ q_bhd: torch.Tensor,
+ kv: torch.Tensor,
+ valid_tokens: torch.Tensor,
+ max_score: torch.Tensor,
+ denom: torch.Tensor,
+ acc: torch.Tensor,
+ scale: float,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
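+    # Online-softmax update: rescale the previous (denom, acc) by
+    # exp(old_max - new_max), then add exp(score - new_max)-weighted KV
+    # contributions for the tokens that are valid in this chunk.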
+ kv_btd = kv.float()
+ kv_btd = torch.where(
+ valid_tokens[:, :, None],
+ kv_btd,
+ torch.zeros((), dtype=kv_btd.dtype, device=kv_btd.device),
+ )
+ scores = torch.einsum("bhd,btd->bht", q_bhd, kv_btd) * scale
+ scores = scores.masked_fill(~valid_tokens[:, None, :], float("-inf"))
+
+ chunk_max = scores.amax(dim=-1)
+ next_max = torch.maximum(max_score, chunk_max)
+
+ previous_scale = torch.exp(max_score - next_max)
+ previous_scale = torch.nan_to_num(previous_scale)
+ weights = torch.exp(scores - next_max[:, :, None])
+ weights = torch.where(
+ valid_tokens[:, None, :],
+ weights,
+ torch.zeros((), dtype=weights.dtype, device=weights.device),
+ )
+ weights = torch.nan_to_num(weights)
+
+ acc = acc * previous_scale[:, :, None]
+ denom = denom * previous_scale
+ acc = acc + torch.einsum("bht,btd->bhd", weights, kv_btd)
+ denom = denom + weights.sum(dim=-1)
+ return next_max, denom, acc
+
+
+def finish_reference_attention_no_sink(
+ max_score: torch.Tensor,
+ denom: torch.Tensor,
+ acc: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ valid = denom > 0
+ safe_denom = torch.where(valid, denom, torch.ones_like(denom))
+ subset_output = acc / safe_denom[:, :, None]
+ subset_output = torch.where(
+ valid[:, :, None],
+ subset_output,
+ torch.zeros((), dtype=subset_output.dtype, device=subset_output.device),
+ )
+ subset_lse = torch.where(
+ valid,
+ max_score + torch.log(safe_denom),
+ torch.full_like(max_score, float("-inf")),
+ )
+ return subset_output, subset_lse
+
+
+def reference_attention_no_sink(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ valid_tokens: torch.Tensor,
+ scale: float,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ q_bhd, max_score, denom, acc = new_reference_attention_state(q)
+ max_score, denom, acc = accumulate_reference_attention_chunk(
+ q_bhd=q_bhd,
+ kv=kv,
+ valid_tokens=valid_tokens,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ scale=scale,
+ )
+ return finish_reference_attention_no_sink(max_score, denom, acc)
+
+
+def merge_reference_attention_with_sink(
+ subset_outputs: list[torch.Tensor],
+ subset_lses: list[torch.Tensor],
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+) -> None:
+ assert subset_outputs, "At least one attention subset is required"
+ assert len(subset_outputs) == len(subset_lses)
+
+ sink = attn_sink[None, :].float()
+ merge_max = sink
+ for subset_lse in subset_lses:
+ merge_max = torch.maximum(merge_max, subset_lse)
+
+ safe_merge_max = torch.where(
+ torch.isfinite(merge_max), merge_max, torch.zeros_like(merge_max)
+ )
+ merged_acc = torch.zeros_like(subset_outputs[0], dtype=torch.float32)
+ sink_weight = torch.exp(sink - safe_merge_max)
+ sink_weight = torch.nan_to_num(sink_weight)
+ merged_denom = sink_weight
+ for subset_output, subset_lse in zip(subset_outputs, subset_lses):
+ subset_weight = torch.exp(subset_lse - safe_merge_max)
+ subset_weight = torch.nan_to_num(subset_weight)
+ merged_acc = merged_acc + subset_output.float() * subset_weight[:, :, None]
+ merged_denom = merged_denom + subset_weight
+
+ safe_denom = torch.where(
+ merged_denom > 0, merged_denom, torch.ones_like(merged_denom)
+ )
+ reference_output = merged_acc / safe_denom[:, :, None]
+ reference_output = torch.where(
+ (merged_denom > 0)[:, :, None],
+ reference_output,
+ torch.zeros((), dtype=reference_output.dtype, device=reference_output.device),
+ )
+ output.copy_(reference_output.to(dtype=output.dtype))
+
+
+def sink_aware_reference_attention(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ valid_tokens: torch.Tensor,
+ scale: float,
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+) -> None:
+ subset_output, subset_lse = reference_attention_no_sink(
+ q=q,
+ kv=kv,
+ valid_tokens=valid_tokens,
+ scale=scale,
+ )
+ merge_reference_attention_with_sink(
+ subset_outputs=[subset_output],
+ subset_lses=[subset_lse],
+ attn_sink=attn_sink,
+ output=output,
+ )
+
+
+def reference_sparse_mla_prefill(
+ q: torch.Tensor,
+ kv: torch.Tensor,
+ combined_indices: torch.Tensor,
+ combined_lens: torch.Tensor,
+ scale: float,
+ attn_sink: torch.Tensor,
+ output: torch.Tensor,
+ topk_chunk_size: int,
+ query_chunk_size: int,
+) -> None:
+ kv_flat = kv.reshape(-1, q.shape[-1])
+ topk_chunk_size = min(combined_indices.shape[-1], topk_chunk_size)
+ query_chunk_size = min(q.shape[0], query_chunk_size)
+
+ for token_start in range(0, q.shape[0], query_chunk_size):
+ token_end = min(token_start + query_chunk_size, q.shape[0])
+ q_chunk = q[token_start:token_end]
+ lens_chunk = combined_lens[token_start:token_end]
+ indices_chunk_full = combined_indices[token_start:token_end]
+ q_bhd, max_score, denom, acc = new_reference_attention_state(q_chunk)
+
+ for index_start in range(0, combined_indices.shape[-1], topk_chunk_size):
+ index_end = min(
+ index_start + topk_chunk_size,
+ combined_indices.shape[-1],
+ )
+ indices_chunk = indices_chunk_full[:, index_start:index_end]
+ index_offsets = torch.arange(
+ index_start,
+ index_end,
+ device=q.device,
+ )
+ valid_tokens = (
+ (index_offsets[None, :] < lens_chunk[:, None])
+ & (indices_chunk >= 0)
+ )
+ safe_indices = torch.where(
+ valid_tokens,
+ indices_chunk,
+ torch.zeros((), dtype=indices_chunk.dtype, device=q.device),
+ ).long()
+ gathered_kv = kv_flat[safe_indices]
+ max_score, denom, acc = accumulate_reference_attention_chunk(
+ q_bhd=q_bhd,
+ kv=gathered_kv,
+ valid_tokens=valid_tokens,
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ scale=scale,
+ )
+
+ subset_output, subset_lse = finish_reference_attention_no_sink(
+ max_score,
+ denom,
+ acc,
+ )
+ merge_reference_attention_with_sink(
+ subset_outputs=[subset_output],
+ subset_lses=[subset_lse],
+ attn_sink=attn_sink,
+ output=output[token_start:token_end],
+ )
diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py
index 28564e6a97d3..7689cf9e155a 100644
--- a/vllm/v1/attention/backends/mla/sparse_swa.py
+++ b/vllm/v1/attention/backends/mla/sparse_swa.py
@@ -16,9 +16,15 @@
CommonAttentionMetadata,
MultipleOf,
)
+from vllm.v1.attention.backends.mla.sparse_mla_env import (
+ is_triton_sparse_mla_enabled,
+ is_triton_sparse_mla_enabled_for_platform,
+ triton_sparse_mla_cudagraphs_allowed,
+)
from vllm.v1.attention.backends.utils import split_decodes_and_prefills
from vllm.v1.attention.ops.flashmla import FlashMLASchedMeta, get_mla_metadata
from vllm.v1.kv_cache_interface import (
+ AttentionSpec,
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowMLASpec,
@@ -162,6 +168,8 @@ class DeepseekSparseSWAMetadata:
# Pre-computed prefill metadata shared across all DeepseekV4 attention layers.
prefill_seq_lens: torch.Tensor | None = None
prefill_gather_lens: torch.Tensor | None = None
+ prefill_seq_lens_cpu: torch.Tensor | None = None
+ prefill_gather_lens_cpu: torch.Tensor | None = None
# Per-layer-type FlashMLA tile-scheduler metadata. One FlashMLASchedMeta
# per present DeepseekV4 layer type, shared across all ~60 layers of that type
@@ -195,6 +203,20 @@ class DeepseekSparseSWAMetadataBuilder(AttentionMetadataBuilder):
reorder_batch_threshold: int = 1
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+ @classmethod
+ def get_cudagraph_support(
+ cls,
+ vllm_config: VllmConfig,
+ kv_cache_spec: AttentionSpec,
+ ) -> AttentionCGSupport:
+ if (
+ getattr(kv_cache_spec, "model_version", None) == "deepseek_v4"
+ and is_triton_sparse_mla_enabled_for_platform()
+ and not triton_sparse_mla_cudagraphs_allowed(vllm_config)
+ ):
+ return AttentionCGSupport.NEVER
+ return cls._cudagraph_support
+
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec)
@@ -313,6 +335,8 @@ def build(
num_prefills,
seq_lens,
query_start_loc,
+ query_start_loc_cpu,
+ common_attn_metadata.seq_lens_cpu_upper_bound,
)
# Per-layer-type tile-scheduler plan holders. Empty FlashMLASchedMeta
@@ -363,6 +387,8 @@ def build_tile_scheduler(
}
if num_decode_tokens == 0 or current_platform.is_rocm():
return out
+ if is_triton_sparse_mla_enabled(self.device):
+ return out
for layer_type in self._layer_types:
# get_mla_metadata() is the official FlashMLA entry point that
# returns a fresh empty FlashMLASchedMeta; using it keeps this
@@ -377,6 +403,8 @@ def _build_deepseek_v4_metadata(
num_prefills: int,
seq_lens: torch.Tensor,
query_start_loc: torch.Tensor,
+ query_start_loc_cpu: torch.Tensor,
+ seq_lens_cpu_upper_bound: torch.Tensor | None,
) -> dict[str, torch.Tensor | None]:
"""Pre-compute DeepseekV4 prefill metadata during the metadata build phase.
@@ -403,8 +431,27 @@ def _build_deepseek_v4_metadata(
BLOCK_SIZE=triton.next_power_of_2(num_prefills),
)
+ assert seq_lens_cpu_upper_bound is not None
+ seq_lens_cpu = seq_lens_cpu_upper_bound
+ prefill_seq_lens_cpu = seq_lens_cpu[
+ num_decodes : num_decodes + num_prefills
+ ]
+ query_lens_cpu = (
+ query_start_loc_cpu[
+ num_decodes + 1 : num_decodes + num_prefills + 1
+ ]
+ - query_start_loc_cpu[num_decodes : num_decodes + num_prefills]
+ )
+ prefix_lens_cpu = prefill_seq_lens_cpu - query_lens_cpu
+ prefill_gather_lens_cpu = query_lens_cpu + torch.minimum(
+ prefix_lens_cpu,
+ torch.full_like(prefix_lens_cpu, self.window_size - 1),
+ )
+
result["prefill_seq_lens"] = seq_lens[num_decodes:]
result["prefill_gather_lens"] = pfx_gather_lens
+ result["prefill_seq_lens_cpu"] = prefill_seq_lens_cpu
+ result["prefill_gather_lens_cpu"] = prefill_gather_lens_cpu
return result
diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py
index 959a79f292a5..da04498f384f 100644
--- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py
+++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py
@@ -5,7 +5,10 @@
combine_topk_swa_indices,
compute_global_topk_indices_and_lens,
dequantize_and_gather_k_cache,
+ dequantize_combined_sparse_mla_decode_kv,
+ dequantize_global_slots_k_cache,
quantize_and_insert_k_cache,
+ sparse_prefill_combined_topk_size,
)
from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant
from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant
@@ -16,8 +19,11 @@
"combine_topk_swa_indices",
"compute_global_topk_indices_and_lens",
"dequantize_and_gather_k_cache",
+ "dequantize_combined_sparse_mla_decode_kv",
+ "dequantize_global_slots_k_cache",
"fused_indexer_q_rope_quant",
"fused_inv_rope_fp8_quant",
"fused_q_kv_rmsnorm",
"quantize_and_insert_k_cache",
+ "sparse_prefill_combined_topk_size",
]
diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py
index 69d20c107e11..33cfe699236f 100644
--- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py
+++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py
@@ -349,12 +349,166 @@ def dequantize_and_gather_k_cache(
)
+@triton.jit
+def _dequantize_global_slots_k_kernel(
+ out_ptr,
+ out_stride_token,
+ out_stride_slot,
+ k_cache_ptr,
+ slot_ids_ptr,
+ slot_ids_stride_token,
+ slot_ids_stride_slot,
+ cache_block_size: tl.constexpr,
+ token_data_size: tl.constexpr,
+ block_stride: tl.constexpr,
+ fp8_dim: tl.constexpr,
+ bf16_dim: tl.constexpr,
+ scale_dim: tl.constexpr,
+ quant_block: tl.constexpr,
+ output_dim: tl.constexpr,
+ BLOCK_D: tl.constexpr,
+):
+ token_idx = tl.program_id(0)
+ topk_idx = tl.program_id(1)
+
+ slot_id = tl.load(
+ slot_ids_ptr
+ + token_idx * slot_ids_stride_token
+ + topk_idx * slot_ids_stride_slot
+ )
+ offsets = tl.arange(0, BLOCK_D)
+ output_row = out_ptr + token_idx * out_stride_token + topk_idx * out_stride_slot
+
+ if slot_id < 0:
+ tl.store(
+ output_row + offsets,
+ tl.zeros((BLOCK_D,), dtype=tl.float32).to(tl.bfloat16),
+ mask=offsets < output_dim,
+ )
+ return
+
+ block_idx = slot_id // cache_block_size
+ pos_in_block = slot_id % cache_block_size
+ cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride
+ token_data_ptr = cache_block_ptr + pos_in_block * token_data_size
+ token_scale_ptr = (
+ cache_block_ptr + cache_block_size * token_data_size + pos_in_block * scale_dim
+ )
+
+ fp8_offsets = tl.arange(0, 512)
+ fp8_mask = fp8_offsets < fp8_dim
+ x_uint8 = tl.load(token_data_ptr + fp8_offsets, mask=fp8_mask, other=0)
+ x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True)
+ x_float = x_fp8.to(tl.float32)
+
+ scale_offsets = fp8_offsets // quant_block
+ encoded_scale = tl.load(token_scale_ptr + scale_offsets, mask=fp8_mask, other=127)
+ scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0)
+ x_dequant = x_float * scale
+ tl.store(output_row + fp8_offsets, x_dequant.to(tl.bfloat16), mask=fp8_mask)
+
+ bf16_offsets = tl.arange(0, 64)
+ bf16_cache_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16))
+ bf16_vals = tl.load(bf16_cache_ptr + bf16_offsets, mask=bf16_offsets < bf16_dim)
+ tl.store(
+ output_row + fp8_dim + bf16_offsets,
+ bf16_vals,
+ mask=bf16_offsets < bf16_dim,
+ )
+
+
+def dequantize_global_slots_k_cache(
+ out: torch.Tensor,
+ k_cache: torch.Tensor,
+ slot_ids: torch.Tensor,
+ block_size: int,
+) -> None:
+ """Dequantize fp8_ds_mla cache rows addressed by physical global slot ids."""
+ if slot_ids.dim() == 3:
+ assert slot_ids.shape[1] == 1
+ slot_ids = slot_ids[:, 0, :]
+ assert slot_ids.dim() == 2, (
+ f"slot_ids must be [num_tokens, topk], got {slot_ids.shape}"
+ )
+ assert out.shape[:2] == slot_ids.shape
+ assert out.shape[-1] == 512
+ assert out.dtype == torch.bfloat16
+ assert k_cache.dtype == torch.uint8
+
+ TOKEN_FP8_DIM = 448
+ TOKEN_BF16_DIM = 64
+ TOKEN_SCALE_DIM = 8
+ QUANT_BLOCK_SIZE = 64
+ TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2
+
+ grid = slot_ids.shape
+ _dequantize_global_slots_k_kernel[grid](
+ out,
+ out.stride(0),
+ out.stride(1),
+ k_cache,
+ slot_ids,
+ slot_ids.stride(0),
+ slot_ids.stride(1),
+ cache_block_size=block_size,
+ token_data_size=TOKEN_DATA_SIZE,
+ block_stride=k_cache.stride(0),
+ fp8_dim=TOKEN_FP8_DIM,
+ bf16_dim=TOKEN_BF16_DIM,
+ scale_dim=TOKEN_SCALE_DIM,
+ quant_block=QUANT_BLOCK_SIZE,
+ output_dim=512,
+ BLOCK_D=triton.next_power_of_2(512),
+ )
+
+
+def dequantize_combined_sparse_mla_decode_kv(
+ combined_kv: torch.Tensor,
+ compressed_k_cache: torch.Tensor,
+ compressed_slot_ids: torch.Tensor,
+ compressed_block_size: int,
+ swa_k_cache: torch.Tensor,
+ seq_lens: torch.Tensor,
+ swa_lens: torch.Tensor,
+ block_table: torch.Tensor,
+ swa_block_size: int,
+) -> None:
+ """Fill `[compressed, SWA]` decode candidates into one output buffer."""
+ assert combined_kv.dim() == 3
+ compressed_topk = compressed_slot_ids.shape[-1]
+ assert combined_kv.shape[0] == compressed_slot_ids.shape[0]
+ assert combined_kv.shape[-1] == 512
+ assert combined_kv.dtype == torch.bfloat16
+ assert combined_kv.shape[1] >= compressed_topk
+
+ dequantize_global_slots_k_cache(
+ combined_kv[:, :compressed_topk],
+ compressed_k_cache,
+ compressed_slot_ids,
+ compressed_block_size,
+ )
+ swa_out = combined_kv[:, compressed_topk:]
+ if swa_out.shape[1] == 0:
+ return
+ dequantize_and_gather_k_cache(
+ swa_out,
+ swa_k_cache,
+ seq_lens=seq_lens,
+ gather_lens=swa_lens,
+ block_table=block_table,
+ block_size=swa_block_size,
+ offset=0,
+ )
+
+
def compute_global_topk_indices_and_lens(
topk_indices: torch.Tensor,
token_to_req_indices: torch.Tensor,
block_table: torch.Tensor,
block_size: int,
is_valid_token: torch.Tensor,
+ global_topk_indices: torch.Tensor | None = None,
+ topk_lens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Map local topk indices to global KV cache slots and count valid entries.
@@ -364,8 +518,20 @@ def compute_global_topk_indices_and_lens(
3. Masking padding tokens to length 0
"""
num_tokens = topk_indices.shape[0]
- global_topk_indices = torch.empty_like(topk_indices)
- topk_lens = torch.empty(num_tokens, dtype=torch.int32, device=topk_indices.device)
+ if global_topk_indices is None:
+ global_topk_indices = torch.empty_like(topk_indices)
+ else:
+ assert global_topk_indices.shape == topk_indices.shape
+ assert global_topk_indices.dtype == topk_indices.dtype
+ assert global_topk_indices.device == topk_indices.device
+ if topk_lens is None:
+ topk_lens = torch.empty(
+ num_tokens, dtype=torch.int32, device=topk_indices.device
+ )
+ else:
+ assert topk_lens.shape == (num_tokens,)
+ assert topk_lens.dtype == torch.int32
+ assert topk_lens.device == topk_indices.device
_compute_global_topk_indices_and_lens_kernel[(num_tokens,)](
global_topk_indices,
global_topk_indices.stride(0),
@@ -412,7 +578,7 @@ def _compute_global_topk_indices_and_lens_kernel(
mask=mask,
other=-1,
)
- is_valid = local_idx >= 0
+ is_valid = (local_idx >= 0) & is_valid_token
block_indices = local_idx // block_size
block_numbers = tl.load(
@@ -442,6 +608,14 @@ def _compute_global_topk_indices_and_lens_kernel(
_SPARSE_PREFILL_TOPK_ALIGNMENT = 128
+def sparse_prefill_combined_topk_size(topk: int, window_size: int) -> int:
+ return (
+ (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1)
+ // _SPARSE_PREFILL_TOPK_ALIGNMENT
+ * _SPARSE_PREFILL_TOPK_ALIGNMENT
+ )
+
+
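A quick check of the rounding above, with illustrative values:

# topk plus the window size is padded up to the next multiple of 128.
assert sparse_prefill_combined_topk_size(2048, 127) == 2176
assert sparse_prefill_combined_topk_size(128, 0) == 128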
def combine_topk_swa_indices(
topk_indices: torch.Tensor,
query_start_loc: torch.Tensor,
@@ -452,23 +626,35 @@ def combine_topk_swa_indices(
topk: int,
M: int,
N: int,
+ combined_indices: torch.Tensor | None = None,
+ combined_lens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = topk_indices.shape[0]
num_reqs = seq_lens.shape[0]
- combined_topk = (
- (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1)
- // _SPARSE_PREFILL_TOPK_ALIGNMENT
- * _SPARSE_PREFILL_TOPK_ALIGNMENT
- )
- combined_indices = torch.full(
- (num_tokens, combined_topk),
- fill_value=-1,
- dtype=torch.int32,
- device=topk_indices.device,
- )
- combined_lens = torch.empty(
- num_tokens, dtype=torch.int32, device=topk_indices.device
- )
+ combined_topk = sparse_prefill_combined_topk_size(topk, window_size)
+ if combined_indices is None:
+ combined_indices = torch.full(
+ (num_tokens, combined_topk),
+ fill_value=-1,
+ dtype=torch.int32,
+ device=topk_indices.device,
+ )
+ else:
+ assert combined_indices.shape[0] >= num_tokens
+ assert combined_indices.shape[1] >= combined_topk
+ assert combined_indices.dtype == torch.int32
+ assert combined_indices.device == topk_indices.device
+ combined_indices = combined_indices[:num_tokens, :combined_topk]
+ combined_indices.fill_(-1)
+ if combined_lens is None:
+ combined_lens = torch.empty(
+ num_tokens, dtype=torch.int32, device=topk_indices.device
+ )
+ else:
+ assert combined_lens.shape[0] >= num_tokens
+ assert combined_lens.dtype == torch.int32
+ assert combined_lens.device == topk_indices.device
+ combined_lens = combined_lens[:num_tokens]
NUM_WORKERS = 128
_combine_topk_swa_indices_kernel[(num_reqs, NUM_WORKERS)](
diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py
new file mode 100644
index 000000000000..652dc7be7907
--- /dev/null
+++ b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py
@@ -0,0 +1,175 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""SM12x Triton FP8 einsum kernels for DeepSeek V4."""
+
+import torch
+
+from vllm.triton_utils import tl, triton
+
+
+def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor:
+ exp_bits = scale.view(torch.uint8).to(torch.int32)
+ fp32_bits = exp_bits << 23
+ return fp32_bits.view(torch.float32)
+
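A small sanity sketch for the bit trick above. It assumes the scale bytes are plain biased exponents; raw uint8 values stand in for the e8m0 tensor, and the 0/255 special encodings are avoided:

e = torch.tensor([120, 127, 128, 135], dtype=torch.uint8)
# Shifting the byte into the fp32 exponent field equals 2 ** (byte - 127).
assert torch.equal(_upcast_e8m0_to_fp32(e), torch.exp2(e.float() - 127.0))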
+
+@triton.jit
+def _deepseek_v4_sm12x_fp8_einsum_kernel(
+ a_ptr,
+ a_scale_ptr,
+ b_ptr,
+ b_scale_ptr,
+ out_ptr,
+ num_tokens: tl.constexpr,
+ num_groups: tl.constexpr,
+ out_rank: tl.constexpr,
+ hidden_size: tl.constexpr,
+ a_stride_token: tl.constexpr,
+ a_stride_group: tl.constexpr,
+ a_stride_hidden: tl.constexpr,
+ a_scale_stride_token: tl.constexpr,
+ a_scale_stride_group: tl.constexpr,
+ a_scale_stride_hidden: tl.constexpr,
+ b_stride_group: tl.constexpr,
+ b_stride_out: tl.constexpr,
+ b_stride_hidden: tl.constexpr,
+ b_scale_stride_group: tl.constexpr,
+ b_scale_stride_out: tl.constexpr,
+ b_scale_stride_hidden: tl.constexpr,
+ out_stride_token: tl.constexpr,
+ out_stride_group: tl.constexpr,
+ out_stride_rank: tl.constexpr,
+ BLOCK_TOKENS: tl.constexpr,
+ BLOCK_OUT: tl.constexpr,
+ BLOCK_HIDDEN: tl.constexpr,
+) -> None:
+ token_block = tl.program_id(0)
+ out_block = tl.program_id(1)
+ group = tl.program_id(2)
+
+ token_offsets = token_block * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)
+ out_offsets = out_block * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
+ hidden_offsets = tl.arange(0, BLOCK_HIDDEN)
+ accum = tl.zeros((BLOCK_TOKENS, BLOCK_OUT), dtype=tl.float32)
+
+ for hidden_start in range(0, hidden_size, BLOCK_HIDDEN):
+ hidden = hidden_start + hidden_offsets
+ a = tl.load(
+ a_ptr
+ + token_offsets[:, None] * a_stride_token
+ + group * a_stride_group
+ + hidden[None, :] * a_stride_hidden,
+ mask=(token_offsets[:, None] < num_tokens)
+ & (hidden[None, :] < hidden_size),
+ other=0.0,
+ )
+ b = tl.load(
+ b_ptr
+ + group * b_stride_group
+ + out_offsets[None, :] * b_stride_out
+ + hidden[:, None] * b_stride_hidden,
+ mask=(out_offsets[None, :] < out_rank) & (hidden[:, None] < hidden_size),
+ other=0.0,
+ )
+ raw = tl.dot(a, b, out_dtype=tl.float32)
+ hidden_scale_block = hidden_start // BLOCK_HIDDEN
+ a_scale = tl.load(
+ a_scale_ptr
+ + token_offsets * a_scale_stride_token
+ + group * a_scale_stride_group
+ + hidden_scale_block * a_scale_stride_hidden,
+ mask=token_offsets < num_tokens,
+ other=0.0,
+ )
+ b_scale = tl.load(
+ b_scale_ptr
+ + group * b_scale_stride_group
+ + (out_offsets // 128) * b_scale_stride_out
+ + hidden_scale_block * b_scale_stride_hidden,
+ mask=out_offsets < out_rank,
+ other=0.0,
+ )
+ accum += raw * a_scale[:, None] * b_scale[None, :]
+
+ tl.store(
+ out_ptr
+ + token_offsets[:, None] * out_stride_token
+ + group * out_stride_group
+ + out_offsets[None, :] * out_stride_rank,
+ accum,
+ mask=(token_offsets[:, None] < num_tokens) & (out_offsets[None, :] < out_rank),
+ )
+
+
+def deepseek_v4_sm12x_fp8_einsum(
+ a: torch.Tensor,
+ a_scale: torch.Tensor,
+ b: torch.Tensor,
+ b_scale: torch.Tensor,
+ out: torch.Tensor,
+) -> None:
+ """Compute ``bhr,hdr->bhd`` with FP32 block scales on SM12x.
+
+ ``a`` is the transposed output of ``fused_inv_rope_fp8_quant`` with shape
+ ``[tokens, groups, hidden]``. ``b`` is ``wo_a`` reshaped to
+ ``[groups, out_rank, hidden]``.
+ """
+ num_tokens, num_groups, hidden_size = a.shape
+ b_groups, out_rank, b_hidden_size = b.shape
+ assert b_groups == num_groups
+ assert b_hidden_size == hidden_size
+ assert out.shape == (num_tokens, num_groups, out_rank)
+ assert hidden_size % 128 == 0
+ assert out_rank % 128 == 0
+ assert a.dtype == torch.float8_e4m3fn
+ assert b.dtype == torch.float8_e4m3fn
+ e8m0_dtype = getattr(torch, "float8_e8m0fnu", None)
+ if a_scale.dtype == e8m0_dtype:
+ a_scale = _upcast_e8m0_to_fp32(a_scale)
+ if b_scale.dtype == e8m0_dtype:
+ b_scale = _upcast_e8m0_to_fp32(b_scale)
+ assert a_scale.dtype == torch.float32
+ assert b_scale.dtype == torch.float32
+
+ if num_tokens == 0:
+ return
+
+ block_tokens = 16
+ block_out = 128
+ block_hidden = 128
+ grid = (
+ triton.cdiv(num_tokens, block_tokens),
+ triton.cdiv(out_rank, block_out),
+ num_groups,
+ )
+ _deepseek_v4_sm12x_fp8_einsum_kernel[grid](
+ a,
+ a_scale,
+ b,
+ b_scale,
+ out,
+ num_tokens,
+ num_groups,
+ out_rank,
+ hidden_size,
+ a.stride(0),
+ a.stride(1),
+ a.stride(2),
+ a_scale.stride(0),
+ a_scale.stride(1),
+ a_scale.stride(2),
+ b.stride(0),
+ b.stride(1),
+ b.stride(2),
+ b_scale.stride(0),
+ b_scale.stride(1),
+ b_scale.stride(2),
+ out.stride(0),
+ out.stride(1),
+ out.stride(2),
+ BLOCK_TOKENS=block_tokens,
+ BLOCK_OUT=block_out,
+ BLOCK_HIDDEN=block_hidden,
+ num_warps=4,
+ num_stages=3,
+ )
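For testing, a dequantize-then-einsum oracle under the shapes stated in the docstring (128-wide block scales along both the hidden and out_rank axes). `ref_fp8_einsum` is hypothetical and only sketches what the Triton kernel computes:

def ref_fp8_einsum(a, a_scale, b, b_scale):
    # a: [tokens, groups, hidden] fp8, a_scale: [tokens, groups, hidden // 128]
    # b: [groups, out_rank, hidden] fp8, b_scale: [groups, out_rank // 128, hidden // 128]
    a_deq = a.float() * a_scale.repeat_interleave(128, dim=-1)
    b_deq = b.float() * b_scale.repeat_interleave(128, dim=1).repeat_interleave(128, dim=2)
    return torch.einsum("tgh,grh->tgr", a_deq, b_deq)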
diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py
index 65993e804153..3f2ac7cdedb7 100644
--- a/vllm/v1/core/kv_cache_coordinator.py
+++ b/vllm/v1/core/kv_cache_coordinator.py
@@ -250,6 +250,17 @@ def remove_skipped_blocks(
for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, total_computed_tokens)
+ def release_protected_prompt_blocks(
+ self, target_free_blocks: int | None = None
+ ) -> None:
+ for manager in self.single_type_managers:
+ if (
+ target_free_blocks is not None
+ and self.block_pool.get_num_free_blocks() >= target_free_blocks
+ ):
+ return
+ manager.release_protected_prompt_blocks(target_free_blocks)
+
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
"""
Get the blocks for the request.
@@ -475,6 +486,8 @@ def verify_and_split_kv_cache_groups(self) -> None:
# block cache hit yet.
block_sizes = [spec.block_size for spec, _, _ in attention_groups]
self.lcm_block_size = lcm(*block_sizes)
+ for manager in self.single_type_managers:
+ manager.cache_alignment_tokens = self.lcm_block_size
# Attention-group indices (into ``self.attention_groups``) that
# contain at least one EAGLE/MTP KV cache group.
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 431776870cf4..2f3c5abf3066 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -345,7 +345,7 @@ def allocate_slots(
num_tokens_main_model=full_num_tokens,
apply_admission_cap=True,
)
- if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
+ if not self._has_enough_free_blocks(num_blocks_to_allocate):
return None
num_tokens_main_model = total_computed_tokens + num_new_tokens
@@ -373,7 +373,7 @@ def allocate_slots(
num_tokens_main_model=num_tokens_main_model,
)
- if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
+ if not self._has_enough_free_blocks(num_blocks_to_allocate):
# Cannot allocate new blocks
return None
@@ -446,6 +446,12 @@ def evict_blocks(self, block_ids: set[int]) -> None:
"""
self.block_pool.evict_blocks(block_ids)
+ def _has_enough_free_blocks(self, num_blocks: int) -> bool:
+ if num_blocks <= self.block_pool.get_num_free_blocks():
+ return True
+ self.coordinator.release_protected_prompt_blocks(num_blocks)
+ return num_blocks <= self.block_pool.get_num_free_blocks()
+
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,
@@ -455,6 +461,7 @@ def reset_prefix_cache(self) -> bool:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
+ self.coordinator.release_protected_prompt_blocks()
if not self.block_pool.reset_prefix_cache():
return False
if self.log_stats:
diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py
index e8d3a6f75688..651f52f6b19d 100644
--- a/vllm/v1/core/single_type_kv_cache_manager.py
+++ b/vllm/v1/core/single_type_kv_cache_manager.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from abc import ABC, abstractmethod
-from collections import defaultdict
+from collections import defaultdict, deque
from collections.abc import Sequence
from vllm.utils.math_utils import cdiv
@@ -42,6 +42,7 @@ def __init__(
dcp_world_size: int = 1,
pcp_world_size: int = 1,
max_admission_blocks_per_request: int | None = None,
+ max_model_len: int | None = None,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
@@ -65,6 +66,8 @@ def __init__(
self.block_pool = block_pool
self.enable_caching = enable_caching
self._max_admission_blocks_per_request = max_admission_blocks_per_request
+ self.max_model_len = max_model_len
+ self.cache_alignment_tokens = self.block_size
self.new_block_ids: list[int] = []
# Mapping from request ID to blocks to track the blocks allocated
@@ -80,6 +83,8 @@ def __init__(
self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block
+ self._protected_prompt_block_ids: set[int] = set()
+ self._protected_prompt_block_queue: deque[int] = deque()
@classmethod
def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]):
@@ -274,6 +279,70 @@ def take_new_block_ids(self) -> list[int]:
self.new_block_ids = []
return ids
+ def _max_protected_prompt_blocks(self) -> int | None:
+ if self.max_model_len is None:
+ return None
+ return 2 * cdiv(max(1, self.max_model_len), self.block_size)
+
+ def _protect_prompt_blocks(self, blocks: Sequence[KVCacheBlock]) -> None:
+ if not self.enable_caching:
+ return
+
+ protected: list[KVCacheBlock] = []
+ for block in blocks:
+ if (
+ block.is_null
+ or block.block_hash is None
+ or block.block_id in self._protected_prompt_block_ids
+ ):
+ continue
+ protected.append(block)
+ self._protected_prompt_block_ids.add(block.block_id)
+ self._protected_prompt_block_queue.append(block.block_id)
+
+ if not protected:
+ return
+
+        # Hold an extra reference so these prompt blocks survive after their
+        # request releases its normal runtime reference. Requests that later
+        # reuse the blocks still increment/decrement the runtime reference as
+        # usual.
+ self.block_pool.touch(protected)
+ self._trim_protected_prompt_blocks()
+
+ def _trim_protected_prompt_blocks(self) -> None:
+ max_blocks = self._max_protected_prompt_blocks()
+ if max_blocks is None:
+ return
+
+ while len(self._protected_prompt_block_ids) > max_blocks:
+ if not self._release_one_protected_prompt_block():
+ return
+
+ def _release_one_protected_prompt_block(self) -> bool:
+ while self._protected_prompt_block_queue:
+ block_id = self._protected_prompt_block_queue.popleft()
+ if block_id not in self._protected_prompt_block_ids:
+ continue
+
+ self._protected_prompt_block_ids.remove(block_id)
+ block = self.block_pool.blocks[block_id]
+ if block.ref_cnt > 0:
+ self.block_pool.free_blocks([block])
+ return True
+ return False
+
+ def release_protected_prompt_blocks(
+ self, target_free_blocks: int | None = None
+ ) -> None:
+ while self._protected_prompt_block_ids:
+ if (
+ target_free_blocks is not None
+ and self.block_pool.get_num_free_blocks() >= target_free_blocks
+ ):
+ return
+ if not self._release_one_protected_prompt_block():
+ return
+
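A sketch of how the three helpers above fit together (block_pool semantics as implied by the calls):

# _protect_prompt_blocks(blocks):
#     block_pool.touch(blocks)            # one extra ref per cached prompt block
#     _trim_protected_prompt_blocks()     # FIFO cap at 2 * cdiv(max_model_len, block_size)
# release_protected_prompt_blocks(target_free_blocks):
#     while protected blocks remain and the pool is still below the target:
#         _release_one_protected_prompt_block()   # free_blocks() drops the extra ref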
def cache_blocks(self, request: Request, num_tokens: int) -> None:
"""
Cache the blocks for the request.
@@ -504,6 +573,54 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
return num_common_blocks
+class MLAAttentionManager(FullAttentionManager):
+ """KV cache manager for DeepSeek V4 compressed MLA cache."""
+
+ def _should_protect_prompt_blocks(self) -> bool:
+ return (
+ self.kv_cache_spec.model_version == "deepseek_v4"
+ or self.kv_cache_spec.cache_dtype_str == "fp8_ds_mla"
+ or self.kv_cache_spec.compress_ratio > 1
+ )
+
+ def cache_blocks(self, request: Request, num_tokens: int) -> None:
+ super().cache_blocks(request, num_tokens)
+ if (
+ not self._should_protect_prompt_blocks()
+ or num_tokens < request.num_prompt_tokens
+ or request.num_prompt_tokens <= 1
+ ):
+ return
+
+ max_cache_hit_length = request.num_prompt_tokens - 1
+ aligned_cache_hit_length = (
+ max_cache_hit_length
+ // self.cache_alignment_tokens
+ * self.cache_alignment_tokens
+ )
+ num_hit_blocks = aligned_cache_hit_length // self.block_size
+ if num_hit_blocks == 0:
+ return
+
+ self._protect_prompt_blocks(
+ self.req_to_blocks[request.request_id][:num_hit_blocks]
+ )
+
+ def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
+ blocks = self.req_to_blocks[running_request_id]
+ num_common_blocks = 0
+ expected_ref_cnt = len(self.req_to_blocks)
+ for block in blocks:
+ ref_cnt = block.ref_cnt
+ if block.block_id in self._protected_prompt_block_ids:
+ ref_cnt -= 1
+ if ref_cnt == expected_ref_cnt:
+ num_common_blocks += 1
+ else:
+ break
+ return num_common_blocks
+
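A small numeric illustration of the adjustment above (values are illustrative):

# With 3 running requests, expected_ref_cnt == 3. A shared prefix block that is
# also prompt-protected has ref_cnt == 4 (3 runtime refs + 1 protection ref);
# subtracting the protection ref lets it still count as a common-prefix block.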
+
class SlidingWindowManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None:
super().__init__(kv_cache_spec, **kwargs)
@@ -641,6 +758,42 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
return 0
+class SlidingWindowMLAManager(SlidingWindowManager):
+ """KV cache manager for DeepSeek V4's sliding-window MLA cache.
+
+    During decode, the live sliding window can slide past the prompt boundary,
+    but the blocks around the (hybrid-aligned) prompt boundary remain the
+    window suffix a future prefix-cache hit on the same prompt would need,
+    so they are protected here.
+ """
+
+ def cache_blocks(self, request: Request, num_tokens: int) -> None:
+ super().cache_blocks(request, num_tokens)
+ if not self.enable_caching or num_tokens < request.num_prompt_tokens:
+ return
+ if request.num_prompt_tokens <= 1:
+ return
+
+ max_cache_hit_length = request.num_prompt_tokens - 1
+ aligned_cache_hit_length = (
+ max_cache_hit_length
+ // self.cache_alignment_tokens
+ * self.cache_alignment_tokens
+ )
+ if aligned_cache_hit_length <= 0:
+ return
+
+ aligned_num_hit_blocks = aligned_cache_hit_length // self.block_size
+ last_full_prompt_block = max_cache_hit_length // self.block_size
+ contiguous_blocks = cdiv(self.sliding_window - 1, self.block_size)
+ first_protected_block = max(0, aligned_num_hit_blocks - contiguous_blocks)
+ last_protected_block = max(aligned_num_hit_blocks, last_full_prompt_block)
+ blocks = self.req_to_blocks[request.request_id]
+ protected_blocks = blocks[
+ first_protected_block : min(last_protected_block, len(blocks))
+ ]
+ self._protect_prompt_blocks(protected_blocks)
+
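A worked example of the block-range arithmetic above (numbers are illustrative):

# block_size=64, cache_alignment_tokens=128, sliding_window=512, prompt=2000 tokens
# max_cache_hit_length     = 1999
# aligned_cache_hit_length = 1999 // 128 * 128 = 1920
# aligned_num_hit_blocks   = 1920 // 64        = 30
# last_full_prompt_block   = 1999 // 64        = 31
# contiguous_blocks        = cdiv(511, 64)     = 8
# -> protect blocks[22:31]: the window tail a future 1920-token prefix hit needs,
#    plus the remaining full prompt blocks so the window stays contiguous.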
+
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None:
super().__init__(kv_cache_spec, **kwargs)
@@ -1124,6 +1277,7 @@ def __init__(
kv_cache_group_id: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
+ max_model_len: int | None = None,
):
super().__init__(
kv_cache_spec,
@@ -1132,6 +1286,7 @@ def __init__(
kv_cache_group_id,
dcp_world_size,
pcp_world_size,
+ max_model_len=max_model_len,
)
sink_len = kv_cache_spec.sink_len
assert sink_len is not None and sink_len > 0 and sink_len % self.block_size == 0
@@ -1142,9 +1297,9 @@ def __init__(
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
TQFullAttentionSpec: FullAttentionManager,
- MLAAttentionSpec: FullAttentionManager,
+ MLAAttentionSpec: MLAAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
- SlidingWindowMLASpec: SlidingWindowManager,
+ SlidingWindowMLASpec: SlidingWindowMLAManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
CrossAttentionSpec: CrossAttentionManager,
@@ -1159,6 +1314,7 @@ def get_manager_for_kv_cache_spec(
**kwargs,
) -> SingleTypeKVCacheManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
+ kwargs["max_model_len"] = max_model_len
# SlidingWindow / ChunkedLocalAttention managers recycle blocks across
# chunks; the runtime admission cap must match the recycling-aware bound
# the startup pool sizer uses (single source of truth: the spec method).
diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py
index 1541b24deaaf..5d3fe0dbd166 100644
--- a/vllm/v1/executor/ray_utils.py
+++ b/vllm/v1/executor/ray_utils.py
@@ -273,6 +273,22 @@ def assert_ray_available():
)
+def _warn_if_insufficient_cluster_devices(
+ parallel_config: ParallelConfig, device_str: str
+) -> None:
+ num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
+ if parallel_config.world_size > num_devices_in_cluster:
+ logger.warning(
+ "The requested distributed world size (%d) exceeds the total "
+ "number of available %ss (%d) in the Ray cluster. This may result "
+ "in Ray placement group allocation failures. Check `ray status` "
+ "and `ray list nodes`, or reduce tensor/pipeline parallel size.",
+ parallel_config.world_size,
+ device_str,
+ num_devices_in_cluster,
+ )
+
+
def _verify_bundles(
placement_group: "PlacementGroup",
parallel_config: ParallelConfig,
@@ -549,21 +565,6 @@ def initialize_ray_cluster(
if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
- # Prevalidate GPU requirements before Ray processing
- if current_platform.is_cuda() and parallel_config.world_size > 1:
- available_gpus = current_platform.device_count()
- if parallel_config.world_size > available_gpus:
- logger.warning(
- "Tensor parallel size (%d) exceeds available GPUs (%d). "
- "This may result in Ray placement group allocation failures. "
- "Consider reducing tensor_parallel_size to %d or less, "
- "or ensure your Ray cluster has %d GPUs available.",
- parallel_config.world_size,
- available_gpus,
- available_gpus,
- parallel_config.world_size,
- )
-
if ray.is_initialized():
logger.info("Ray is already initialized. Skipping Ray initialization.")
elif current_platform.is_rocm() or current_platform.is_xpu():
@@ -589,6 +590,9 @@ def initialize_ray_cluster(
f"current platform {current_platform.device_name} does not support ray."
)
+ if parallel_config.world_size > 1:
+ _warn_if_insufficient_cluster_devices(parallel_config, device_str)
+
# Create or get the placement group for worker processes
if parallel_config.placement_group:
current_placement_group = parallel_config.placement_group
@@ -619,17 +623,6 @@ def initialize_ray_cluster(
)
else:
logger.info("No current placement group found. Creating a new placement group.")
- num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
- # Log a warning message and delay resource allocation failure response.
- # Avoid immediate rejection to allow user-initiated placement group
- # created and wait cluster to be ready
- if parallel_config.world_size > num_devices_in_cluster:
- logger.warning(
- "The number of required %ss exceeds the total "
- "number of available %ss in the placement group.",
- device_str,
- device_str,
- )
# Create a new placement group
placement_group_specs: list[dict[str, float]] = [
{device_str: 1.0} for _ in range(parallel_config.world_size)