diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py index 0cea46d6284f..01f030836527 100644 --- a/tests/kernels/attention/test_deepgemm_attention.py +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -10,10 +10,12 @@ _ceil_to_ue8m0, calc_diff, fp8_fp4_mqa_logits, - fp8_fp4_paged_mqa_logits, get_num_sms, get_paged_mqa_logits_metadata, ) +from vllm.utils.deep_gemm import ( + fp8_fp4_paged_mqa_logits as fp8_paged_mqa_logits, +) from vllm.utils.import_utils import has_deep_gemm from vllm.utils.math_utils import cdiv @@ -90,10 +92,64 @@ def _ref_fp8_mqa_logits( return logits +def _supports_deepgemm_optimized_mqa_logits() -> bool: + return current_platform.is_cuda() and ( + current_platform.is_device_capability(90) + or current_platform.is_device_capability_family(100) + ) + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_mqa_logits_torch_path(): + torch.manual_seed(0) + + seq_len, seq_len_kv, num_heads, head_dim = 9, 17, 32, 32 + q = torch.randn( + seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = (torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3) + cu_seqlen_ke = torch.minimum( + torch.arange(seq_len, device="cuda", dtype=torch.int32) + 4, + torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32), + ) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4) + kv_scale = (kv_amax / 448.0).squeeze(1).contiguous() + kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn) + + logits = fp8_fp4_mqa_logits( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + + kv_dequant = kv_fp8.float() * kv_scale[:, None] + score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant) + ref_logits = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0) + offsets = torch.arange(seq_len_kv, device="cuda") + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + ref_logits = ref_logits.masked_fill(~valid, float("-inf")) + + assert torch.equal(torch.isneginf(logits), torch.isneginf(ref_logits)) + finite = torch.isfinite(ref_logits) + assert (logits[finite] - ref_logits[finite]).abs().max() < 1e-4 + + @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") @pytest.mark.skipif( - not current_platform.has_device_capability(90), reason="SM90 and SM100 only" + not _supports_deepgemm_optimized_mqa_logits(), reason="SM90 and SM100 only" ) @pytest.mark.parametrize("clean_logits", [True, False]) def test_deepgemm_fp8_mqa_logits(clean_logits: bool): @@ -150,7 +206,7 @@ def test_deepgemm_fp8_mqa_logits(clean_logits: bool): assert diff < 1e-3, f"{diff=}" -def _ref_fp8_fp4_paged_mqa_logits( +def _ref_fp8_paged_mqa_logits( q: torch.Tensor, kv_cache: torch.Tensor, weights: torch.Tensor, @@ -203,12 +259,10 @@ def _ref_fp8_fp4_paged_mqa_logits( @pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") @pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") @pytest.mark.skipif( - not 
current_platform.has_device_capability(90), reason="SM90 and SM100 only" + not _supports_deepgemm_optimized_mqa_logits(), reason="SM90 and SM100 only" ) -def test_deepgemm_fp8_fp4_paged_mqa_logits(): - # NOTE: clean_logits=True is incompatible with the 2D context_lens - # required by csrc/apis/attention.hpp; only the False path is exercised. - clean_logits = False +@pytest.mark.parametrize("clean_logits", [True, False]) +def test_deepgemm_fp8_paged_mqa_logits(clean_logits: bool): torch.manual_seed(0) random.seed(0) @@ -260,29 +314,24 @@ def test_deepgemm_fp8_fp4_paged_mqa_logits(): q_fp8 = q.to(torch.float8_e4m3fn) kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) - # deep_gemm paged MQA logits requires 2D context_lens of - # shape (B, next_n) (csrc/apis/attention.hpp:332-335); - # see indexer.py:607-608. For each batch/next_n token, the - # effective context length is context_lens[b] - next_n + j + 1. - next_n_arange = torch.arange(next_n, device="cuda", dtype=torch.int32) - context_lens_2d = ( - context_lens.unsqueeze(-1) - next_n + 1 + next_n_arange - ).contiguous() + deepgemm_context_lens = ( + context_lens[:, None].expand(-1, next_n).contiguous() + ) schedule_metadata = get_paged_mqa_logits_metadata( - context_lens_2d, blocksize, get_num_sms() + deepgemm_context_lens, blocksize, get_num_sms() ) - logits = fp8_fp4_paged_mqa_logits( + logits = fp8_paged_mqa_logits( (q_fp8, None), kv_cache_fp8, weights, - context_lens_2d, + deepgemm_context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=clean_logits, ) - ref_logits = _ref_fp8_fp4_paged_mqa_logits( + ref_logits = _ref_fp8_paged_mqa_logits( q, kv_cache, weights, diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index ebc3256b548f..a5d2c33b47eb 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -29,10 +29,15 @@ ) from vllm.model_executor.layers.fused_moe.config import ( FUSED_MOE_UNQUANTIZED_CONFIG, + FusedMoEConfig, + FusedMoEParallelConfig, + RoutingMethodType, int4_w4a16_moe_quant_config, int8_w8a16_moe_quant_config, + mxfp4_w4a16_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + MarlinExperts, batched_fused_marlin_moe, fused_marlin_moe, ) @@ -1007,6 +1012,61 @@ def test_fused_marlin_moe( torch.testing.assert_close(marlin_output, torch_output, atol=4e-2, rtol=0) +def test_marlin_experts_apply_forwards_gemm1_clamp_limit(monkeypatch): + captured: dict[str, float | None] = {} + + def fake_fused_marlin_moe(**kwargs): + captured["clamp_limit"] = kwargs.get("clamp_limit") + kwargs["output"].zero_() + + monkeypatch.setattr( + "vllm.model_executor.layers.fused_moe.fused_marlin_moe." 
+ "fused_marlin_moe", + fake_fused_marlin_moe, + ) + + clamp_limit = 10.0 + moe_config = FusedMoEConfig( + num_experts=2, + experts_per_token=1, + hidden_dim=16, + intermediate_size_per_partition=8, + num_local_experts=2, + num_logical_experts=2, + activation=MoEActivation.SILU, + device="cpu", + routing_method=RoutingMethodType.Default, + moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(), + in_dtype=torch.bfloat16, + ) + quant_config = mxfp4_w4a16_moe_quant_config( + w1_scale=torch.empty(0), + w2_scale=torch.empty(0), + gemm1_clamp_limit=clamp_limit, + ) + experts = MarlinExperts(moe_config=moe_config, quant_config=quant_config) + + experts.apply( + output=torch.empty((1, 16), dtype=torch.bfloat16), + hidden_states=torch.empty((1, 16), dtype=torch.bfloat16), + w1=torch.empty((2, 1, 1), dtype=torch.int32), + w2=torch.empty((2, 1, 1), dtype=torch.int32), + topk_weights=torch.ones((1, 1), dtype=torch.float32), + topk_ids=torch.zeros((1, 1), dtype=torch.int32), + activation=MoEActivation.SILU, + global_num_experts=2, + expert_map=None, + a1q_scale=None, + a2_scale=None, + workspace13=torch.empty(0), + workspace2=torch.empty(0), + expert_tokens_meta=None, + apply_router_weight_on_input=False, + ) + + assert captured["clamp_limit"] == clamp_limit + + @pytest.mark.flaky(reruns=2) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.parametrize("m", [1, 256]) diff --git a/tests/models/test_deepseek_v4_mega_moe.py b/tests/models/test_deepseek_v4_mega_moe.py index 304f044868a3..7da906f58446 100644 --- a/tests/models/test_deepseek_v4_mega_moe.py +++ b/tests/models/test_deepseek_v4_mega_moe.py @@ -9,6 +9,7 @@ from vllm.model_executor.models.deepseek_v4 import ( DeepseekV4MegaMoEExperts, _stage_deepseek_v4_mega_moe_inputs, + _use_deepseek_v4_mega_moe, make_deepseek_v4_expert_params_mapping, ) from vllm.platforms import current_platform @@ -19,6 +20,52 @@ ) +def _make_mega_moe_config( + *, + enable_expert_parallel: bool = True, + moe_backend: str = "auto", +): + return SimpleNamespace( + parallel_config=SimpleNamespace( + enable_expert_parallel=enable_expert_parallel + ), + kernel_config=SimpleNamespace(moe_backend=moe_backend), + ) + + +def test_deepseek_v4_mega_moe_selection_preserves_kernel_config(monkeypatch): + from vllm import envs + + monkeypatch.delenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", raising=False) + envs.disable_envs_cache() + + assert _use_deepseek_v4_mega_moe( + _make_mega_moe_config(moe_backend="deep_gemm_mega_moe") + ) + assert not _use_deepseek_v4_mega_moe(_make_mega_moe_config()) + with pytest.raises(NotImplementedError, match="requires expert parallel"): + _use_deepseek_v4_mega_moe( + _make_mega_moe_config( + enable_expert_parallel=False, + moe_backend="deep_gemm_mega_moe", + ) + ) + + +def test_deepseek_v4_mega_moe_selection_env_override(monkeypatch): + from vllm import envs + + monkeypatch.setenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "1") + envs.disable_envs_cache() + assert _use_deepseek_v4_mega_moe(_make_mega_moe_config()) + + monkeypatch.setenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "0") + envs.disable_envs_cache() + assert not _use_deepseek_v4_mega_moe( + _make_mega_moe_config(moe_backend="deep_gemm_mega_moe") + ) + + def test_deepseek_v4_mega_moe_expert_mapping(): mapping = make_deepseek_v4_expert_params_mapping(2) @@ -46,7 +93,8 @@ def test_deepseek_v4_mega_moe_ue8m0_uint8_to_float(): def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership(): vllm_config = SimpleNamespace( - scheduler_config=SimpleNamespace(max_num_batched_tokens=4) 
+ scheduler_config=SimpleNamespace(max_num_batched_tokens=4), + compilation_config=SimpleNamespace(static_forward_context={}), ) experts = DeepseekV4MegaMoEExperts( vllm_config, @@ -111,7 +159,10 @@ def test_deepseek_v4_mega_moe_weight_loader_uses_ep_expert_ownership(): reason="DeepSeek V4 MegaMoE fused input staging requires CUDA.", ) def test_deepseek_v4_mega_moe_fused_input_staging_is_bitwise_exact(): - from vllm.third_party.deep_gemm.utils import per_token_cast_to_fp8 + per_token_cast_to_fp8 = pytest.importorskip( + "deep_gemm.utils", + reason="DeepGEMM helper package is required for FP8 staging parity.", + ).per_token_cast_to_fp8 device = torch.device("cuda") num_tokens = 7 diff --git a/tests/models/test_deepseek_v4_pp.py b/tests/models/test_deepseek_v4_pp.py new file mode 100644 index 000000000000..7c0ae5dfd725 --- /dev/null +++ b/tests/models/test_deepseek_v4_pp.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.model_executor.models.deepseek_v4 import DeepseekV4ForCausalLM +from vllm.model_executor.models.interfaces import supports_pp + + +def test_deepseek_v4_declares_pipeline_parallel_support(): + assert supports_pp(DeepseekV4ForCausalLM) diff --git a/tests/quantization/test_fp8_scale_parameter.py b/tests/quantization/test_fp8_scale_parameter.py new file mode 100644 index 000000000000..c95cdbb98be2 --- /dev/null +++ b/tests/quantization/test_fp8_scale_parameter.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm.model_executor.parameter as parameter +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + create_fp8_scale_parameter, +) +from vllm.model_executor.parameter import BlockQuantScaleParameter + + +@pytest.mark.skipif( + not hasattr(torch, "float8_e8m0fnu"), + reason="torch does not expose float8_e8m0fnu", +) +def test_create_fp8_scale_parameter_initializes_e8m0(monkeypatch): + monkeypatch.setattr(parameter, "get_tensor_model_parallel_rank", lambda: 0) + monkeypatch.setattr(parameter, "get_tensor_model_parallel_world_size", lambda: 1) + + scale = create_fp8_scale_parameter( + BlockQuantScaleParameter, + output_partition_sizes=[128], + input_size_per_partition=128, + block_size=[128, 128], + weight_loader=None, + scale_dtype=torch.float8_e8m0fnu, + ) + + assert scale.dtype == torch.float8_e8m0fnu + raw_scale = scale.data.view(torch.uint8) + assert torch.equal(raw_scale, torch.zeros_like(raw_scale)) diff --git a/tests/quantization/test_mxfp4.py b/tests/quantization/test_mxfp4.py new file mode 100644 index 000000000000..6e6792e60cae --- /dev/null +++ b/tests/quantization/test_mxfp4.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +def test_mxfp4_e8m0_scale_loading_preserves_raw_bytes(): + from types import SimpleNamespace + + import pytest + import torch + + from vllm.model_executor.layers.fused_moe.layer import FusedMoE + + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if e8m0_dtype is None: + pytest.skip("torch does not expose float8_e8m0fnu") + + layer = object.__new__(FusedMoE) + layer.moe_config = SimpleNamespace(is_act_and_mul=True) + + expert_data = torch.zeros((4, 2), dtype=torch.uint8) + loaded_scale = torch.tensor( + [[0.0078125, 0.015625], [0.5, 1.0]], + dtype=e8m0_dtype, + ) + + layer._load_w13( + expert_data=expert_data, + shard_dim=0, + 
shard_id="w1", + loaded_weight=loaded_scale, + tp_rank=0, + ) + + torch.testing.assert_close( + expert_data[:2], + loaded_scale.view(torch.uint8), + rtol=0, + atol=0, + ) diff --git a/tests/reasoning/test_deepseekv3_reasoning_parser.py b/tests/reasoning/test_deepseekv3_reasoning_parser.py index f5b37194f927..a013cf1a8775 100644 --- a/tests/reasoning/test_deepseekv3_reasoning_parser.py +++ b/tests/reasoning/test_deepseekv3_reasoning_parser.py @@ -4,16 +4,35 @@ import pytest from transformers import AutoTokenizer +from vllm.config.reasoning import ReasoningConfig from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.reasoning import ReasoningParserManager from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser -from vllm.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser +from vllm.reasoning.deepseek_v3_reasoning_parser import ( + DeepSeekV3ReasoningParser, + DeepSeekV3ReasoningWithThinkingParser, +) from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-V3.1" +class FakeReasoningTokenizer: + + def get_vocab(self) -> dict[str, int]: + return {"": 100, "": 101} + + def encode( + self, + text: str, + add_special_tokens: bool = False, + **kwargs, + ) -> list[int]: + assert add_special_tokens is False + return [self.get_vocab()[text]] + + @pytest.fixture(scope="module") def tokenizer(): return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME) @@ -37,7 +56,22 @@ def test_parser_selection(tokenizer, thinking, expected_parser_type): def test_deepseek_v4_reasoning_parser_alias(): parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v4") - assert parser_cls is DeepSeekV3ReasoningParser + assert parser_cls is DeepSeekV3ReasoningWithThinkingParser + + +def test_deepseek_v4_auto_reasoning_config_initializes_budget_tokens(monkeypatch): + monkeypatch.setattr( + "vllm.config.reasoning.cached_tokenizer_from_config", + lambda model_config: FakeReasoningTokenizer(), + ) + + config = ReasoningConfig(reasoning_parser="deepseek_v4") + + config.initialize_token_ids(model_config=None) + + assert config.enabled is True + assert config.reasoning_start_token_ids == [100] + assert config.reasoning_end_token_ids == [101] def test_identity_reasoning_parser_basic(tokenizer): diff --git a/tests/tokenizers_/test_deepseek_v4.py b/tests/tokenizers_/test_deepseek_v4.py index 358732eabf40..0a3253957add 100644 --- a/tests/tokenizers_/test_deepseek_v4.py +++ b/tests/tokenizers_/test_deepseek_v4.py @@ -8,8 +8,14 @@ import pytest from vllm.entrypoints.chat_utils import parse_chat_messages +from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ChatMessage, +) +from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.renderers.registry import RENDERER_REGISTRY from vllm.tokenizers.deepseek_v4 import get_deepseek_v4_tokenizer +from vllm.tokenizers.deepseek_v4_encoding import encode_arguments_to_dsml from vllm.tokenizers.registry import TokenizerRegistry FIXTURES_DIR = Path(__file__).parent / "fixtures" / "deepseek_v4" @@ -96,6 +102,130 @@ def test_deepseek_v4_enables_thinking_with_compatible_kwargs(kwargs): assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>") +def test_deepseek_v4_honors_official_thinking_request_field(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": 
[{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + prompt = _tokenizer().apply_chat_template( + request.messages, + tokenize=False, + **chat_kwargs, + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + assert prompt == ("<|begin▁of▁sentence|><|User|>Hello<|Assistant|>") + + +def test_deepseek_v4_defaults_to_official_thinking_for_openai_request(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + + +def test_deepseek_v4_preserves_official_reasoning_content_alias(): + messages = [ + {"role": "user", "content": "Q1"}, + {"role": "assistant", "reasoning_content": "because", "content": "A1"}, + {"role": "user", "content": "Q2"}, + ] + + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + + assert conversation[1]["reasoning"] == "because" + assert conversation[1]["reasoning_content"] == "because" + + +def test_deepseek_v4_response_messages_expose_reasoning_content_alias(): + message = ChatMessage(role="assistant", reasoning="because", content="answer") + delta = DeltaMessage(reasoning="because") + + assert message.reasoning_content == "because" + assert delta.reasoning_content == "because" + assert ( + ChatMessage( + role="assistant", + reasoning_content="because", + content="answer", + ).reasoning + == "because" + ) + + +def test_deepseek_v4_preserves_official_prefix_assistant_message(): + messages = [ + {"role": "user", "content": "Please write quick sort code"}, + {"role": "assistant", "content": "```python\n", "prefix": True}, + ] + + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + prompt = _tokenizer().apply_chat_template( + conversation=conversation, + messages=messages, + tokenize=False, + ) + + assert conversation[1]["prefix"] is True + assert conversation[1]["wo_eos"] is True + assert prompt.endswith("<|Assistant|>```python\n") + assert not prompt.endswith("<|end▁of▁sentence|>") + + +def test_deepseek_v4_thinking_ignores_sampling_controls(): + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert sampling_params.temperature == 1.0 + assert sampling_params.top_p == 1.0 + assert sampling_params.top_k == 0 + assert sampling_params.presence_penalty == 0.0 + assert sampling_params.frequency_penalty == 0.0 + + def test_deepseek_v4_uses_v4_tool_prompt_from_request_tools(): tools = [ { @@ -183,6 +313,66 @@ def test_deepseek_v4_renders_parsed_history_tool_arguments(): assert 'parameter name="arguments"' not in prompt +@pytest.mark.parametrize( + ("tool_call", "expected_parameter"), + [ + ({"name": 
"refresh", "arguments": None}, None), + ({"name": "refresh"}, None), + ({"name": "refresh", "arguments": ""}, None), + ( + {"name": "refresh", "arguments": '{"target": "cache"}'}, + '<|DSML|parameter name="target" string="true">cache', + ), + ( + {"name": "refresh", "arguments": {"target": "cache"}}, + '<|DSML|parameter name="target" string="true">cache', + ), + ], +) +def test_deepseek_v4_encodes_empty_history_tool_arguments( + tool_call, expected_parameter +): + prompt = encode_arguments_to_dsml(tool_call) + + if expected_parameter is None: + assert prompt == "" + else: + assert expected_parameter in prompt + + +def test_deepseek_v4_renders_openai_history_tool_call_with_null_arguments(): + messages = [ + {"role": "user", "content": "Refresh state"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "refresh", + "arguments": None, + }, + } + ], + }, + ] + conversation, _, _ = parse_chat_messages( + messages, + _model_config(), + content_format="string", + ) + + prompt = _tokenizer().apply_chat_template( + conversation=conversation, + messages=messages, + tokenize=False, + ) + + assert '<|DSML|invoke name="refresh">' in prompt + assert "<|DSML|parameter" not in prompt + + @pytest.mark.parametrize("reasoning_effort", ["minimal", "low", "medium", "high"]) def test_deepseek_v4_accepts_openai_reasoning_effort_values(reasoning_effort): prompt = _tokenizer().apply_chat_template( @@ -288,3 +478,136 @@ def test_deepseek_v4_matches_reference_golden_fixtures(case_id, kwargs): expected = (FIXTURES_DIR / f"test_output_{case_id}.txt").read_text() assert prompt == expected + + +@pytest.mark.parametrize( + "model", + [ + "deepseek-ai/DeepSeek-V4-Flash", + "deepseek-ai/DeepSeek-V4-Pro", + ], +) +def test_deepseek_v4_official_api_defaults_to_thinking_for_v4_family(model): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": model, + "messages": [{"role": "user", "content": "Hello"}], + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + + +def test_deepseek_v4_official_api_uses_model_config_for_family_detection(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "local-ds4-alias", + "messages": [{"role": "user", "content": "Hello"}], + "temperature": 0.2, + } + ) + model_config = SimpleNamespace( + hf_config=SimpleNamespace(model_type="deepseek_v4", architectures=[]), + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs, + model_config=model_config, + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + model_config=model_config, + ) + + assert chat_kwargs["thinking"] is True + assert chat_kwargs["enable_thinking"] is True + assert sampling_params.temperature == 1.0 + + +def test_deepseek_v4_official_api_sampling_override_can_be_disabled(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-V4-Flash", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + 
"deepseek_v4_sampling_override": False, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "min_p": 0.05, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert sampling_params.temperature == 0.2 + assert sampling_params.top_p == 0.3 + assert sampling_params.top_k == 4 + assert sampling_params.min_p == 0.05 + assert sampling_params.presence_penalty == 1.5 + assert sampling_params.frequency_penalty == 1.25 + + +def test_deepseek_v4_official_api_sampling_override_is_v4_only(): + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) + + request = ChatCompletionRequest.model_validate( + { + "model": "deepseek-ai/DeepSeek-R1", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": {"type": "enabled"}, + "temperature": 0.2, + "top_p": 0.3, + "top_k": 4, + "min_p": 0.05, + "presence_penalty": 1.5, + "frequency_penalty": 1.25, + } + ) + chat_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params(None, "auto").chat_template_kwargs + ) + + sampling_params = request.to_sampling_params( + 16, + {}, + chat_template_kwargs=chat_kwargs, + ) + + assert "thinking" not in chat_kwargs + assert "enable_thinking" not in chat_kwargs + assert sampling_params.temperature == 0.2 + assert sampling_params.top_p == 0.3 + assert sampling_params.top_k == 4 + assert sampling_params.min_p == 0.05 + assert sampling_params.presence_penalty == 1.5 + assert sampling_params.frequency_penalty == 1.25 diff --git a/tests/tools/test_compare_vllm_http_logprobs_oracle.py b/tests/tools/test_compare_vllm_http_logprobs_oracle.py new file mode 100644 index 000000000000..729c76659d53 --- /dev/null +++ b/tests/tools/test_compare_vllm_http_logprobs_oracle.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import importlib.util +from pathlib import Path + +SCRIPT_PATH = ( + Path(__file__).parents[2] / "tools" / "compare_vllm_http_logprobs_oracle.py" +) +spec = importlib.util.spec_from_file_location( + "compare_vllm_http_logprobs_oracle", SCRIPT_PATH +) +assert spec is not None +oracle_compare = importlib.util.module_from_spec(spec) +assert spec.loader is not None +spec.loader.exec_module(oracle_compare) + + +def _response(tokens, top_logprobs): + token_ids = [ + int(token.split(":", 1)[1]) + for token in tokens + if isinstance(token, str) and token.startswith("token_id:") + ] + return { + "choices": [ + { + "text": "", + "logprobs": { + "tokens": tokens, + "token_logprobs": [-0.1 * (i + 1) for i in range(len(tokens))], + "top_logprobs": top_logprobs, + }, + "token_ids": token_ids, + "prompt_token_ids": [1, 2, 3], + } + ], + "usage": {"prompt_tokens": 3, "completion_tokens": len(tokens)}, + } + + +def test_compare_response_accepts_identical_top_logprobs(): + top_logprobs = [ + {"token_id:10": -0.1, "token_id:20": -1.0}, + {"token_id:11": -0.2, "token_id:21": -1.2}, + ] + report = oracle_compare.compare_response( + "case0", + _response(["token_id:10", "token_id:11"], top_logprobs), + _response(["token_id:10", "token_id:11"], top_logprobs), + top_n=2, + ) + + assert report["tokens_match"] is True + assert report["first_token_mismatch"] is None + assert report["top1_matches"] == 2 + assert report["topk_overlap_mean"] == 1.0 + assert 
report["max_common_logprob_abs_error"] == 0.0 + + +def test_compare_response_reports_first_generated_token_divergence(): + oracle = _response( + ["token_id:10", "token_id:11"], + [ + {"token_id:10": -0.1, "token_id:20": -1.0}, + {"token_id:11": -0.2, "token_id:21": -1.2}, + ], + ) + actual = _response( + ["token_id:10", "token_id:99"], + [ + {"token_id:10": -0.1, "token_id:20": -1.0}, + {"token_id:99": -0.2, "token_id:21": -1.2}, + ], + ) + + report = oracle_compare.compare_response("case0", oracle, actual, top_n=2) + + assert report["tokens_match"] is False + assert report["first_token_mismatch"] == { + "step": 1, + "oracle": "token_id:11", + "actual": "token_id:99", + } + assert report["matching_prefix_tokens"] == 1 + assert report["top1_matches"] == 1 + + +def test_compare_response_can_decode_oracle_token_id_keys(): + normalizer = oracle_compare.TokenNormalizer( + lambda token_id: {10: '","', 11: "title", 20: " What"}[token_id] + ) + oracle = _response( + ["token_id:10", "token_id:11"], + [ + {"token_id:10": -0.1, "token_id:20": -1.0}, + {"token_id:11": -0.2, "token_id:20": -1.2}, + ], + ) + actual = _response( + ['","', "title"], + [ + {'","': -0.11, " What": -1.1}, + {"title": -0.19, " What": -1.3}, + ], + ) + + report = oracle_compare.compare_response( + "case0", oracle, actual, top_n=2, normalizer=normalizer + ) + + assert report["tokens_match"] is True + assert report["top1_matches"] == 2 + assert report["topk_overlap_mean"] == 1.0 + assert report["max_common_logprob_abs_error"] == 0.10000000000000009 diff --git a/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py new file mode 100644 index 000000000000..dbdc6f4e689c --- /dev/null +++ b/tests/v1/attention/test_deepseek_v4_sparse_mla_reference.py @@ -0,0 +1,3162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Correctness tests for the DeepSeek V4 Triton sparse MLA path and reference oracle.""" + +from types import SimpleNamespace + +import pytest +import torch + +from vllm.config.compilation import ( + CompilationConfig, + CompilationMode, + CUDAGraphMode, +) +from vllm.model_executor.layers import ( + deepseek_v4_attention as deepseek_v4_attention_module, +) +from vllm.model_executor.layers.deepseek_v4_attention import ( + _allocate_deepseek_v4_wo_a_output, + _deepseek_v4_fp8_einsum_config, + _sparse_mla_prefill_workspace_bounds, + deepseek_v4_fp8_einsum, +) +from vllm.utils.deep_gemm import fp8_einsum +from vllm.v1.attention.backend import AttentionCGSupport +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + FlashMLASparseMetadataBuilder, +) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + disable_triton_sparse_mla_cudagraphs_if_enabled, + triton_sparse_mla_topk_chunk_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk, + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_fp8ds_paged_sparse_mla_attention_chunk, + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead, + accumulate_gathered_sparse_mla_attention_chunk, + accumulate_indexed_sparse_mla_attention_chunk, + build_combined_sparse_mla_decode_valid_mask, + finish_gathered_sparse_mla_attention, + finish_materialized_sparse_mla_scores_with_sink, + finish_sparse_mla_attention_with_sink, + finish_two_sparse_mla_attention_states_with_sink, + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead, 
+ fp8ds_paged_sparse_mla_attention_with_sink_multihead, + matmul_sparse_mla_attention_with_sink, + merge_sparse_mla_subset_with_sink, + merge_two_sparse_mla_subsets_with_sink, + sparse_mla_decode_head_block_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_reference import ( + accumulate_reference_attention_chunk, + finish_reference_attention_no_sink, + merge_reference_attention_with_sink, + new_reference_attention_state, + reference_attention_no_sink, + reference_sparse_mla_prefill, + sink_aware_reference_attention, +) +from vllm.v1.attention.backends.mla.sparse_swa import DeepseekSparseSWAMetadataBuilder +from vllm.v1.attention.ops.deepseek_v4_ops import ( + dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, + dequantize_global_slots_k_cache, +) +from vllm.v1.attention.ops.deepseek_v4_ops import fp8_einsum as fp8_einsum_module +from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import ( + deepseek_v4_sm12x_fp8_einsum, +) +from vllm.v1.kv_cache_interface import MLAAttentionSpec, SlidingWindowMLASpec + +_FP8_DIM = 448 +_ROPE_DIM = 64 +_SCALE_DIM = 8 +_TOKEN_DATA_SIZE = _FP8_DIM + _ROPE_DIM * 2 + + +class _FakeWorkspaceManager: + def get_simultaneous(self, *specs): + return tuple(torch.empty(shape, dtype=dtype) for shape, dtype in specs) + + +def _assert_fp8_einsum_close(actual: torch.Tensor, expected: torch.Tensor) -> None: + # The Triton path and DeepGEMM reference both accumulate in FP32, but + # their reduction orders are not bit-identical before the final BF16 store. + torch.testing.assert_close(actual.float(), expected.float(), rtol=5e-2, atol=3e-4) + + +def test_deepseek_v4_fp8_einsum_is_piecewise_split_op() -> None: + assert "vllm::deepseek_v4_fp8_einsum" in CompilationConfig._attention_ops + + +def test_wo_a_output_allocation_uses_empty_during_compile(monkeypatch) -> None: + class FailingWorkspaceManager: + def get_simultaneous(self, *args, **kwargs): + raise AssertionError("compiled allocation must not grow workspace") + + monkeypatch.setattr( + deepseek_v4_attention_module, + "current_workspace_manager", + lambda: FailingWorkspaceManager(), + ) + monkeypatch.setattr(torch.compiler, "is_compiling", lambda: True) + + output = _allocate_deepseek_v4_wo_a_output( + 2, + 3, + 5, + torch.bfloat16, + torch.device("cpu"), + ) + + assert output.shape == (2, 3, 5) + assert output.dtype == torch.bfloat16 + + +def test_wo_a_output_allocation_uses_workspace_outside_compile(monkeypatch) -> None: + captured = {} + + class FakeWorkspaceManager: + def get_simultaneous(self, *shapes_and_dtypes): + captured["request"] = shapes_and_dtypes + return [ + torch.empty(shape, dtype=dtype) for shape, dtype in shapes_and_dtypes + ] + + monkeypatch.setattr( + deepseek_v4_attention_module, + "current_workspace_manager", + lambda: FakeWorkspaceManager(), + ) + monkeypatch.setattr(torch.compiler, "is_compiling", lambda: False) + + output = _allocate_deepseek_v4_wo_a_output( + 2, + 3, + 5, + torch.bfloat16, + torch.device("cpu"), + ) + + assert captured["request"] == (((2, 3, 5), torch.bfloat16),) + assert output.shape == (2, 3, 5) + assert output.dtype == torch.bfloat16 + + +def test_dummy_attention_impl_reserves_prefill_workspace(monkeypatch) -> None: + class FakeMLAAttn: + def __init__(self) -> None: + self.reserved = False + + def _reserve_prefill_workspace(self) -> None: + self.reserved = True + + def __call__(self, *args, **kwargs) -> None: + raise AssertionError("dummy run must not execute real attention") + + mla_attn = FakeMLAAttn() + layer = object.__new__( + 
deepseek_v4_attention_module.DeepseekV4MultiHeadLatentAttentionWrapper + ) + layer.q_lora_rank = 2 + layer.head_dim = 4 + layer.n_local_heads = 2 + layer.padded_heads = 64 + layer.indexer = None + layer.compressor = None + layer.wq_b = lambda qr: torch.ones(qr.shape[0], 8) + layer.q_norm = SimpleNamespace(weight=SimpleNamespace(data=torch.empty(0))) + layer.kv_norm = SimpleNamespace(weight=SimpleNamespace(data=torch.empty(0))) + layer.eps = 1e-6 + layer.mla_attn = mla_attn + layer.attn_gemm_parallel_execute = lambda hidden_states: ( + torch.zeros(hidden_states.shape[0], 6), + None, + None, + None, + ) + layer._fused_qnorm_rope_kv_insert = lambda *args, **kwargs: None + + monkeypatch.setattr( + deepseek_v4_attention_module, + "get_forward_context", + lambda: SimpleNamespace(attn_metadata=None), + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "fused_q_kv_rmsnorm", + lambda qr, kv, *args, **kwargs: (qr, kv), + ) + + out = torch.ones((3, 64, 4)) + layer.attention_impl( + hidden_states=torch.zeros((3, 6)), + positions=torch.arange(3), + out=out, + ) + + assert mla_attn.reserved is True + assert torch.count_nonzero(out) == 0 + + +def test_prefill_workspace_reservation_specs_match_forward_prefill_bounds( + monkeypatch, +) -> None: + attn = SimpleNamespace( + max_model_len=16_384, + max_num_batched_tokens=8192, + compress_ratio=4, + window_size=128, + head_dim=512, + num_heads=64, + topk_indices_buffer=torch.empty((8192, 2048), dtype=torch.int32), + indexer=None, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "is_triton_sparse_mla_enabled_for_platform", + lambda: True, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "triton_sparse_mla_query_chunk_size", + lambda: 256, + ) + + attention_cls = deepseek_v4_attention_module.DeepseekV4MLAAttention + specs = attention_cls._prefill_workspace_reservation_specs(attn) + + assert specs == ( + ((4, 12_415, 512), torch.bfloat16), + ((8192, 2176), torch.int32), + ((8192,), torch.int32), + ((256, 64), torch.float32), + ((256, 64), torch.float32), + ((256, 64, 512), torch.float32), + ) + + +def test_triton_sparse_mla_default_topk_chunk_size(monkeypatch) -> None: + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", raising=False) + + assert triton_sparse_mla_topk_chunk_size() == 512 + + +def test_sparse_mla_prefill_workspace_bounds_use_active_prefill_lengths() -> None: + seq_lens_cpu = torch.tensor([15_000, 2_048], dtype=torch.int32) + gather_lens_cpu = torch.tensor([15_000, 2_048], dtype=torch.int32) + + compressed_region_size, row_stride = _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu=seq_lens_cpu, + gather_lens_cpu=gather_lens_cpu, + compress_ratio=4, + swa_only=False, + ) + + assert compressed_region_size == 3_750 + assert row_stride == 18_750 + + +def test_sparse_mla_prefill_workspace_bounds_for_swa_only() -> None: + seq_lens_cpu = torch.tensor([15_000], dtype=torch.int32) + gather_lens_cpu = torch.tensor([15_000], dtype=torch.int32) + + compressed_region_size, row_stride = _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu=seq_lens_cpu, + gather_lens_cpu=gather_lens_cpu, + compress_ratio=1, + swa_only=True, + ) + + assert compressed_region_size == 0 + assert row_stride == 15_000 + + +@pytest.mark.parametrize( + ("num_decode_tokens", "expected_head_block_size"), + [ + (0, 1), + (1, 1), + (4, 1), + (5, 2), + (8, 2), + (15, 2), + (16, 4), + (32, 4), + ], +) +def test_triton_sparse_mla_decode_head_block_size( + num_decode_tokens: int, + expected_head_block_size: int, + monkeypatch, +) -> None: + 
monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", raising=False) + + assert ( + sparse_mla_decode_head_block_size(num_decode_tokens) == expected_head_block_size + ) + + +@pytest.mark.parametrize("configured_head_block_size", ["1", "2", "4"]) +def test_triton_sparse_mla_decode_head_block_size_env_override( + configured_head_block_size: str, + monkeypatch, +) -> None: + monkeypatch.setenv( + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", + configured_head_block_size, + ) + + assert sparse_mla_decode_head_block_size(1) == int(configured_head_block_size) + assert sparse_mla_decode_head_block_size(32) == int(configured_head_block_size) + + +@pytest.mark.parametrize("configured_head_block_size", ["0", "3", "invalid"]) +def test_triton_sparse_mla_decode_head_block_size_ignores_invalid_env_override( + configured_head_block_size: str, + monkeypatch, +) -> None: + monkeypatch.setenv( + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", + configured_head_block_size, + ) + + assert sparse_mla_decode_head_block_size(8) == 2 + + +def test_swa_mtp_decode_triton_uses_global_swa_slots(monkeypatch) -> None: + captured: dict[str, torch.Tensor] = {} + + def fail_paged_attention_with_sink_multihead(**kwargs) -> None: + raise AssertionError("MTP SWA decode must use explicit SWA indices") + + def fake_accumulate_global_slots(**kwargs) -> None: + captured["slot_ids"] = kwargs["slot_ids"] + captured["lens"] = kwargs["lens"] + + def fake_finish_with_sink(*args, **kwargs) -> None: + kwargs["output"].zero_() + + monkeypatch.setattr( + deepseek_v4_attention_module, + "current_workspace_manager", + lambda: _FakeWorkspaceManager(), + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "fp8ds_paged_sparse_mla_attention_with_sink_multihead", + fail_paged_attention_with_sink_multihead, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead", + fake_accumulate_global_slots, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "finish_sparse_mla_attention_with_sink", + fake_finish_with_sink, + ) + + attention = SimpleNamespace( + num_heads=2, + scale=0.1, + attn_sink=torch.zeros(2, dtype=torch.float32), + ) + swa_indices = torch.arange(48, dtype=torch.int32).reshape(6, 1, 8) + swa_lens = torch.tensor([2, 3, 4, 2, 3, 4], dtype=torch.int32) + metadata = SimpleNamespace( + num_decodes=2, + num_decode_tokens=6, + decode_swa_lens=swa_lens, + decode_swa_indices=swa_indices, + seq_lens=torch.tensor([11, 22], dtype=torch.int32), + block_table=torch.empty((2, 4), dtype=torch.int32), + block_size=256, + token_to_req_indices=torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32), + ) + + deepseek_v4_attention_module.DeepseekV4MLAAttention._forward_sparse_mla_swa_decode_triton( + attention, + q=torch.empty((6, 1, 2, 512), dtype=torch.bfloat16), + swa_k_cache=torch.empty((1, 256, 584), dtype=torch.uint8), + swa_metadata=metadata, + output=torch.empty((6, 2, 512), dtype=torch.bfloat16), + ) + + torch.testing.assert_close(captured["slot_ids"], swa_indices) + torch.testing.assert_close(captured["lens"], swa_lens) + + +def test_compressed_mtp_decode_triton_uses_global_swa_slots(monkeypatch) -> None: + captured: list[torch.Tensor] = [] + + def fail_matmul_decode(**kwargs) -> None: + raise AssertionError("MTP compressed decode must not stage paged SWA") + + def fail_direct_global_paged(**kwargs) -> None: + raise AssertionError("MTP compressed decode must not use paged SWA window") + + def fake_accumulate_global_slots(**kwargs) -> None: + 
captured.append(kwargs["slot_ids"]) + + def fake_finish_two_states(*args, **kwargs) -> None: + kwargs["output"].zero_() + + monkeypatch.setattr( + deepseek_v4_attention_module, + "current_workspace_manager", + lambda: _FakeWorkspaceManager(), + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "dequantize_combined_sparse_mla_decode_kv", + fail_matmul_decode, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "fp8ds_global_paged_sparse_mla_attention_with_sink_multihead", + fail_direct_global_paged, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead", + fake_accumulate_global_slots, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "finish_two_sparse_mla_attention_states_with_sink", + fake_finish_two_states, + ) + + attention = SimpleNamespace( + num_heads=2, + scale=0.1, + attn_sink=torch.zeros(2, dtype=torch.float32), + compress_ratio=4, + ) + swa_indices = torch.arange(48, dtype=torch.int32).reshape(6, 1, 8) + topk_slot_ids = torch.arange(24, dtype=torch.int32).reshape(6, 1, 4) + swa_metadata = SimpleNamespace( + num_decodes=2, + num_decode_tokens=6, + decode_swa_lens=torch.full((6,), 3, dtype=torch.int32), + decode_swa_indices=swa_indices, + seq_lens=torch.tensor([11, 22], dtype=torch.int32), + block_table=torch.empty((2, 4), dtype=torch.int32), + block_size=256, + token_to_req_indices=torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32), + ) + + deepseek_v4_attention_module.DeepseekV4MLAAttention._forward_sparse_mla_compressed_decode_triton( + attention, + q=torch.empty((6, 1, 2, 512), dtype=torch.bfloat16), + compressed_k_cache=torch.empty((1, 64, 584), dtype=torch.uint8), + swa_k_cache=torch.empty((1, 256, 584), dtype=torch.uint8), + topk_indices=topk_slot_ids, + topk_lens=torch.full((6,), 4, dtype=torch.int32), + swa_metadata=swa_metadata, + attn_metadata=SimpleNamespace(block_size=256), + output=torch.empty((6, 2, 512), dtype=torch.bfloat16), + ) + + assert len(captured) == 2 + torch.testing.assert_close(captured[0], topk_slot_ids[:, 0]) + torch.testing.assert_close(captured[1], swa_indices) + + +@pytest.mark.parametrize( + ("capability_major", "expected_recipe", "expected_tma_aligned"), + [ + (9, (1, 128, 128), False), + (10, (1, 1, 128), True), + (12, (1, 128, 128), False), + ], +) +def test_deepseek_v4_fp8_einsum_config_for_sm12x( + capability_major: int, + expected_recipe: tuple[int, int, int], + expected_tma_aligned: bool, +) -> None: + assert _deepseek_v4_fp8_einsum_config(capability_major) == ( + expected_recipe, + expected_tma_aligned, + ) + + +def test_deepseek_v4_fp8_einsum_uses_sm12x_names() -> None: + assert hasattr(fp8_einsum_module, "_deepseek_v4_sm12x_fp8_einsum_kernel") + assert hasattr(fp8_einsum_module, "deepseek_v4_sm12x_fp8_einsum") + assert not hasattr(fp8_einsum_module, "_deepseek_v4_sm12_fp8_einsum_kernel") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("use_e8m0_scale", [False, True]) +def test_deepseek_v4_sm12x_triton_fp8_einsum_matches_deepgemm_reference( + use_e8m0_scale: bool, +) -> None: + if use_e8m0_scale and not hasattr(torch, "float8_e8m0fnu"): + pytest.skip("torch does not expose float8_e8m0fnu") + torch.manual_seed(0) + num_tokens = 17 + num_groups = 4 + hidden_size = 4096 + out_rank = 1024 + recipe = (1, 128, 128) + + a_backing = torch.randn( + (num_groups, num_tokens, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + a = a_backing.transpose(0, 1) + 
a_scale_backing = torch.empty( + (num_groups, num_tokens, hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + a_scale = a_scale_backing.transpose(0, 1) + b_flat = torch.randn( + (num_groups * out_rank, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + b = b_flat.view(num_groups, out_rank, hidden_size) + if use_e8m0_scale: + scale_choices = torch.tensor( + [0.00390625, 0.0078125, 0.015625, 0.03125], + device="cuda", + dtype=torch.float32, + ) + scale_indices = torch.randint( + 0, + len(scale_choices), + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + ) + b_scale_flat = scale_choices[scale_indices].to(torch.float8_e8m0fnu) + b_scale_ref_flat = b_scale_flat.to(torch.float32) + else: + b_scale_flat = torch.empty( + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + b_scale_ref_flat = b_scale_flat + b_scale_ref = b_scale_ref_flat.view(num_groups, out_rank // 128, hidden_size // 128) + expected = torch.empty( + (num_tokens, num_groups, out_rank), + device="cuda", + dtype=torch.bfloat16, + ) + actual = torch.empty_like(expected) + + fp8_einsum( + "bhr,hdr->bhd", + (a, a_scale), + (b, b_scale_ref), + expected, + recipe=recipe, + ) + deepseek_v4_fp8_einsum( + a, + a_scale, + b_flat, + b_scale_flat, + actual, + "bhr,hdr->bhd", + list(recipe), + ) + + _assert_fp8_einsum_close(actual, expected) + + +def test_deepseek_v4_fp8_einsum_slices_full_group_weight_for_tp( + monkeypatch, +) -> None: + captured: dict[str, torch.Tensor] = {} + num_tokens = 2 + local_groups = 4 + full_groups = 8 + out_rank = 512 + hidden_size = 4096 + recipe = [1, 128, 128] + + def fake_sm12_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + ) -> None: + captured["b"] = b + captured["b_scale"] = b_scale + + monkeypatch.setattr( + deepseek_v4_attention_module.current_platform, + "get_device_capability", + lambda: SimpleNamespace(major=12), + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "get_tensor_model_parallel_rank", + lambda: 1, + raising=False, + ) + monkeypatch.setattr( + deepseek_v4_attention_module, + "deepseek_v4_sm12x_fp8_einsum", + fake_sm12_fp8_einsum, + ) + + a = torch.empty( + (num_tokens, local_groups, hidden_size), + dtype=torch.float8_e4m3fn, + ) + a_scale = torch.empty( + (num_tokens, local_groups, hidden_size // 128), + dtype=torch.float32, + ) + b = torch.arange( + full_groups * out_rank * hidden_size, + dtype=torch.uint8, + ).view(torch.float8_e4m3fn) + b = b.view(full_groups * out_rank, hidden_size) + b_scale = torch.arange( + full_groups * (out_rank // 128) * (hidden_size // 128), + dtype=torch.float32, + ).view(full_groups * (out_rank // 128), hidden_size // 128) + out = torch.empty((num_tokens, local_groups, out_rank), dtype=torch.bfloat16) + + deepseek_v4_fp8_einsum( + a, + a_scale, + b, + b_scale, + out, + "bhr,hdr->bhd", + recipe, + ) + + expected_b = b.view(full_groups, out_rank, hidden_size)[local_groups:] + expected_b_scale = b_scale.view(full_groups, out_rank // 128, hidden_size // 128)[ + local_groups: + ] + assert torch.equal(captured["b"].view(torch.uint8), expected_b.view(torch.uint8)) + assert torch.equal(captured["b_scale"], expected_b_scale) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_deepseek_v4_sm12x_triton_fp8_einsum_primitive_matches_reference() -> None: + torch.manual_seed(0) + num_tokens = 17 + 
num_groups = 4 + hidden_size = 4096 + out_rank = 1024 + recipe = (1, 128, 128) + + a_backing = torch.randn( + (num_groups, num_tokens, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + a = a_backing.transpose(0, 1) + a_scale_backing = torch.empty( + (num_groups, num_tokens, hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + a_scale = a_scale_backing.transpose(0, 1) + b_flat = torch.randn( + (num_groups * out_rank, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + b = b_flat.view(num_groups, out_rank, hidden_size) + b_scale_flat = torch.empty( + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + b_scale = b_scale_flat.view(num_groups, out_rank // 128, hidden_size // 128) + expected = torch.empty( + (num_tokens, num_groups, out_rank), + device="cuda", + dtype=torch.bfloat16, + ) + actual = torch.empty_like(expected) + + fp8_einsum("bhr,hdr->bhd", (a, a_scale), (b, b_scale), expected, recipe=recipe) + deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, actual) + + _assert_fp8_einsum_close(actual, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("num_groups", [1, 2, 4]) +def test_deepseek_v4_sm12x_triton_fp8_einsum_supports_tp_local_group_counts( + num_groups: int, +) -> None: + torch.manual_seed(18 + num_groups) + num_tokens = 5 + hidden_size = 4096 + out_rank = 1024 + recipe = (1, 128, 128) + + a_backing = torch.randn( + (num_groups, num_tokens, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + a = a_backing.transpose(0, 1) + a_scale_backing = torch.empty( + (num_groups, num_tokens, hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + a_scale = a_scale_backing.transpose(0, 1) + b_flat = torch.randn( + (num_groups * out_rank, hidden_size), + device="cuda", + dtype=torch.bfloat16, + ).to(torch.float8_e4m3fn) + b = b_flat.view(num_groups, out_rank, hidden_size) + b_scale_flat = torch.empty( + (num_groups * (out_rank // 128), hidden_size // 128), + device="cuda", + dtype=torch.float32, + ).uniform_(0.01, 0.02) + b_scale = b_scale_flat.view(num_groups, out_rank // 128, hidden_size // 128) + expected = torch.empty( + (num_tokens, num_groups, out_rank), + device="cuda", + dtype=torch.bfloat16, + ) + actual = torch.empty_like(expected) + + fp8_einsum("bhr,hdr->bhd", (a, a_scale), (b, b_scale), expected, recipe=recipe) + deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, actual) + + _assert_fp8_einsum_close(actual, expected) + + +def _masked_scores( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, +) -> torch.Tensor: + q_bhd = q[:, 0].float() if q.dim() == 4 else q.float() + scores = torch.einsum("bhd,btd->bht", q_bhd, kv.float()) * scale + return scores.masked_fill(~valid_tokens[:, None, :], float("-inf")) + + +def _golden_no_sink_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, +) -> tuple[torch.Tensor, torch.Tensor]: + scores = _masked_scores(q, kv, valid_tokens, scale) + lse = torch.logsumexp(scores, dim=-1) + weights = torch.exp(scores - lse[:, :, None]) + weights = torch.where( + valid_tokens[:, None, :], + weights, + torch.zeros((), dtype=weights.dtype, device=weights.device), + ) + weights = torch.nan_to_num(weights) + output = torch.einsum("bht,btd->bhd", weights, kv.float()) + valid = 
valid_tokens.any(dim=-1) + output = torch.where( + valid[:, None, None], + output, + torch.zeros((), dtype=output.dtype, device=output.device), + ) + return output, lse + + +def _golden_sink_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, +) -> torch.Tensor: + scores = _masked_scores(q, kv, valid_tokens, scale) + sink = attn_sink[None, :].float() + score_max = scores.amax(dim=-1) + merge_max = torch.maximum(score_max, sink) + + weights = torch.exp(scores - merge_max[:, :, None]) + weights = torch.where( + valid_tokens[:, None, :], + weights, + torch.zeros((), dtype=weights.dtype, device=weights.device), + ) + weights = torch.nan_to_num(weights) + + sink_weight = torch.exp(sink - merge_max) + sink_weight = torch.nan_to_num(sink_weight) + denom = weights.sum(dim=-1) + sink_weight + numerator = torch.einsum("bht,btd->bhd", weights, kv.float()) + return numerator / denom[:, :, None] + + +def _chunked_no_sink_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + chunk_size: int, +) -> tuple[torch.Tensor, torch.Tensor]: + q_bhd, max_score, denom, acc = new_reference_attention_state(q) + for chunk_start in range(0, kv.shape[1], chunk_size): + chunk_end = min(chunk_start + chunk_size, kv.shape[1]) + max_score, denom, acc = accumulate_reference_attention_chunk( + q_bhd=q_bhd, + kv=kv[:, chunk_start:chunk_end], + valid_tokens=valid_tokens[:, chunk_start:chunk_end], + max_score=max_score, + denom=denom, + acc=acc, + scale=scale, + ) + return finish_reference_attention_no_sink(max_score, denom, acc) + + +def _write_fp8_ds_mla_token( + k_cache: torch.Tensor, + slot: int, + block_size: int, +) -> torch.Tensor: + block_idx = slot // block_size + block_offset = slot % block_size + + values = ( + (torch.arange(_FP8_DIM, device=k_cache.device, dtype=torch.float32) % 17) - 8 + ) / 16.0 + values = values + float(slot) / 32.0 + scale_exponents = torch.tensor( + [-2, -1, 0, 1, 2, -2, 1], + device=k_cache.device, + dtype=torch.float32, + ) + scales = torch.exp2(scale_exponents) + scale_per_dim = scales.repeat_interleave(64) + + fp8_values = (values / scale_per_dim).to(torch.float8_e4m3fn) + expected_nope = fp8_values.float() * scale_per_dim + rope = ( + torch.linspace(-1.0, 1.0, _ROPE_DIM, device=k_cache.device) + float(slot) / 16.0 + ).to(torch.bfloat16) + + flat_block = k_cache[block_idx].view(-1) + token_data_start = block_offset * _TOKEN_DATA_SIZE + token_scale_start = block_size * _TOKEN_DATA_SIZE + block_offset * _SCALE_DIM + flat_block[token_data_start : token_data_start + _FP8_DIM] = fp8_values.view( + torch.uint8 + ) + flat_block[token_data_start + _FP8_DIM : token_data_start + _TOKEN_DATA_SIZE] = ( + rope.view(torch.uint8) + ) + + encoded_scales = (scale_exponents.to(torch.int32) + 127).to(torch.uint8) + flat_block[token_scale_start : token_scale_start + encoded_scales.numel()] = ( + encoded_scales + ) + flat_block[ + token_scale_start + encoded_scales.numel() : token_scale_start + _SCALE_DIM + ] = 127 + + return torch.cat([expected_nope, rope.float()]).to(torch.bfloat16) + + +def _materialize_global_fp8_ds_mla_slots( + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + block_size: int, +) -> torch.Tensor: + gathered = torch.zeros( + *slot_ids.shape, + 512, + dtype=torch.bfloat16, + device=k_cache.device, + ) + for token_idx, row in enumerate(slot_ids.detach().cpu().tolist()): + for candidate_idx, slot in enumerate(row): + if slot >= 0: + gathered[token_idx, candidate_idx] = 
_write_fp8_ds_mla_token( + k_cache, + slot, + block_size, + ) + return gathered + + +def _materialize_paged_fp8_ds_mla_window( + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, +) -> torch.Tensor: + seq_lens_cpu = seq_lens.detach().cpu().tolist() + gather_lens_cpu = gather_lens.detach().cpu().tolist() + block_table_cpu = block_table.detach().cpu().tolist() + max_gather_len = max(gather_lens_cpu) + gathered = torch.zeros( + len(seq_lens_cpu), + max_gather_len, + 512, + dtype=torch.bfloat16, + device=k_cache.device, + ) + for token_idx, (seq_len, gather_len) in enumerate( + zip(seq_lens_cpu, gather_lens_cpu) + ): + start_pos = seq_len - gather_len + for gather_idx in range(gather_len): + logical_pos = start_pos + gather_idx + logical_block = logical_pos // block_size + block_offset = logical_pos % block_size + physical_block = block_table_cpu[token_idx][logical_block] + physical_slot = physical_block * block_size + block_offset + gathered[token_idx, gather_idx] = _write_fp8_ds_mla_token( + k_cache, + physical_slot, + block_size, + ) + return gathered + + +def test_reference_attention_no_sink_matches_logsumexp() -> None: + torch.manual_seed(0) + scale = 0.25 + q = torch.randn(3, 4, 5) + kv = torch.randn(3, 6, 5) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, False], + [False, False, False, False, False, False], + [True, False, True, True, True, False], + ], + dtype=torch.bool, + ) + output, lse = reference_attention_no_sink(q, kv, valid_tokens, scale) + expected_output, expected_lse = _golden_no_sink_attention( + q, + kv, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=1e-6, atol=1e-6) + torch.testing.assert_close(lse, expected_lse, rtol=1e-6, atol=1e-6) + + +def test_reference_attention_ignores_nan_kv_for_invalid_tokens() -> None: + torch.manual_seed(24) + q = torch.randn(2, 1, 3, 8) + kv = torch.randn(2, 4, 8) + kv[:, 2:] = float("nan") + valid_tokens = torch.tensor( + [[True, True, False, False], [True, False, False, False]], + dtype=torch.bool, + ) + + output, lse = reference_attention_no_sink( + q=q, + kv=kv, + valid_tokens=valid_tokens, + scale=0.125, + ) + + assert torch.isfinite(output).all() + assert torch.isfinite(lse).all() + + +def test_sink_aware_reference_attention_matches_dense_golden() -> None: + torch.manual_seed(1) + scale = 0.125 + q = torch.randn(3, 1, 4, 5) + kv = torch.randn(3, 6, 5) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, False], + [False, False, False, False, False, False], + [False, True, True, False, True, True], + ], + dtype=torch.bool, + ) + sink = torch.tensor([-1.0, 0.25, 1.5, -0.5]) + output = torch.empty(3, 4, 5) + sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, output) + expected = _golden_sink_attention(q, kv, valid_tokens, scale, sink) + + torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6) + + +def test_lse_merge_with_sink_matches_concatenated_attention() -> None: + torch.manual_seed(2) + scale = 0.2 + q = torch.randn(4, 3, 7) + compressed_kv = torch.randn(4, 5, 7) + swa_kv = torch.randn(4, 3, 7) + compressed_kv[:, 1] = compressed_kv[:, 0] + swa_kv[:, 2] = compressed_kv[:, 0] + compressed_valid = torch.tensor( + [ + [True, True, False, True, False], + [False, False, False, False, False], + [True, False, True, True, False], + [False, False, False, False, False], + ], + dtype=torch.bool, + ) + swa_valid = torch.tensor( + [ + [True, False, True], + [True, 
True, False], + [False, False, False], + [False, False, False], + ], + dtype=torch.bool, + ) + sink = torch.tensor([-0.25, 0.75, 1.25]) + output = torch.empty(4, 3, 7) + comp_output, comp_lse = reference_attention_no_sink( + q, + compressed_kv, + compressed_valid, + scale, + ) + swa_output, swa_lse = reference_attention_no_sink(q, swa_kv, swa_valid, scale) + merge_reference_attention_with_sink( + subset_outputs=[comp_output, swa_output], + subset_lses=[comp_lse, swa_lse], + attn_sink=sink, + output=output, + ) + + expected = _golden_sink_attention( + q, + torch.cat([compressed_kv, swa_kv], dim=1), + torch.cat([compressed_valid, swa_valid], dim=1), + scale, + sink, + ) + torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6) + assert torch.equal(output[3], torch.zeros_like(output[3])) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_lse_merge_with_sink_matches_reference() -> None: + torch.manual_seed(5) + comp_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32) + swa_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32) + comp_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32) + swa_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32) + comp_lse[1, 2] = float("-inf") + swa_lse[2, 1] = float("-inf") + sink = torch.tensor([-0.5, 0.25, 1.0, -1.5], device="cuda") + + output = torch.empty(3, 4, 9, device="cuda", dtype=torch.bfloat16) + expected = torch.empty_like(output) + merge_two_sparse_mla_subsets_with_sink( + subset0_output=comp_output, + subset0_lse=comp_lse, + subset1_output=swa_output, + subset1_lse=swa_lse, + attn_sink=sink, + output=output, + ) + merge_reference_attention_with_sink( + subset_outputs=[comp_output, swa_output], + subset_lses=[comp_lse, swa_lse], + attn_sink=sink, + output=expected, + ) + + torch.testing.assert_close(output.float(), expected.float(), rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_single_lse_merge_with_sink_matches_reference() -> None: + torch.manual_seed(14) + subset_output = torch.randn(3, 4, 9, device="cuda", dtype=torch.float32) + subset_lse = torch.randn(3, 4, device="cuda", dtype=torch.float32) + subset_lse[1, 2] = float("-inf") + sink = torch.tensor([-0.5, 0.25, 1.0, -1.5], device="cuda") + + output = torch.empty(3, 4, 9, device="cuda", dtype=torch.bfloat16) + expected = torch.empty_like(output) + merge_sparse_mla_subset_with_sink( + subset_output=subset_output, + subset_lse=subset_lse, + attn_sink=sink, + output=output, + ) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=sink, + output=expected, + ) + + torch.testing.assert_close(output.float(), expected.float(), rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_finish_with_sink_matches_finish_then_merge_reference() -> None: + torch.manual_seed(18) + max_score = torch.randn(4, 3, device="cuda", dtype=torch.float32) + denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1 + denom[1, 2] = 0.0 + max_score[1, 2] = float("-inf") + acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32) + sink = torch.tensor( + [-0.5, 0.25, 1.0, -float("inf"), -float("inf")], + device="cuda", + dtype=torch.float32, + ) + + output = torch.full((4, 5, 17), -7.0, device="cuda", dtype=torch.bfloat16) + finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, output) + + subset_output = 
torch.empty_like(acc)
+ subset_lse = torch.empty_like(max_score)
+ finish_gathered_sparse_mla_attention(
+ max_score=max_score,
+ denom=denom,
+ acc=acc,
+ output=subset_output,
+ lse=subset_lse,
+ )
+ expected = torch.empty(4, 3, 17, device="cuda", dtype=torch.bfloat16)
+ merge_reference_attention_with_sink(
+ subset_outputs=[subset_output],
+ subset_lses=[subset_lse],
+ attn_sink=sink[:3],
+ output=expected,
+ )
+
+ torch.testing.assert_close(
+ output[:, :3].float(), expected.float(), rtol=1e-2, atol=1e-2
+ )
+ torch.testing.assert_close(
+ output[:, 3:].float(),
+ torch.full_like(output[:, 3:].float(), -7.0),
+ rtol=0,
+ atol=0,
+ )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_finish_with_sink_returns_zero_when_no_tokens_or_sink() -> None:
+ max_score = torch.full((2, 3), float("-inf"), device="cuda")
+ denom = torch.zeros((2, 3), device="cuda")
+ acc = torch.full((2, 3, 17), float("nan"), device="cuda")
+ sink = torch.full((3,), float("-inf"), device="cuda")
+
+ single_output = torch.full((2, 3, 17), 7.0, device="cuda", dtype=torch.bfloat16)
+ finish_sparse_mla_attention_with_sink(
+ max_score,
+ denom,
+ acc,
+ sink,
+ output=single_output,
+ )
+ torch.testing.assert_close(
+ single_output.float(),
+ torch.zeros_like(single_output.float()),
+ rtol=0,
+ atol=0,
+ )
+
+ two_output = torch.full((2, 3, 17), 7.0, device="cuda", dtype=torch.bfloat16)
+ finish_two_sparse_mla_attention_states_with_sink(
+ max_score,
+ denom,
+ acc,
+ max_score,
+ denom,
+ acc,
+ sink,
+ output=two_output,
+ )
+ torch.testing.assert_close(
+ two_output.float(),
+ torch.zeros_like(two_output.float()),
+ rtol=0,
+ atol=0,
+ )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only")
+def test_triton_finish_two_states_with_sink_matches_finish_then_merge() -> None:
+ torch.manual_seed(22)
+ comp_max = torch.randn(4, 3, device="cuda", dtype=torch.float32)
+ comp_denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1
+ comp_acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32)
+ swa_max = torch.randn(4, 3, device="cuda", dtype=torch.float32)
+ swa_denom = torch.rand(4, 3, device="cuda", dtype=torch.float32) + 0.1
+ swa_acc = torch.randn(4, 3, 17, device="cuda", dtype=torch.float32)
+ sink = torch.tensor(
+ [-0.5, 0.25, 1.0, -float("inf"), -float("inf")],
+ device="cuda",
+ dtype=torch.float32,
+ )
+
+ comp_denom[0, 1] = 0.0
+ comp_max[0, 1] = float("-inf")
+ swa_denom[2, 0] = 0.0
+ swa_max[2, 0] = float("-inf")
+ comp_denom[3, 2] = 0.0
+ comp_max[3, 2] = float("-inf")
+ swa_denom[3, 2] = 0.0
+ swa_max[3, 2] = float("-inf")
+
+ output = torch.full((4, 5, 17), -7.0, device="cuda", dtype=torch.bfloat16)
+ finish_two_sparse_mla_attention_states_with_sink(
+ comp_max,
+ comp_denom,
+ comp_acc,
+ swa_max,
+ swa_denom,
+ swa_acc,
+ sink,
+ output,
+ )
+
+ comp_output = torch.empty_like(comp_acc)
+ comp_lse = torch.empty_like(comp_max)
+ swa_output = torch.empty_like(swa_acc)
+ swa_lse = torch.empty_like(swa_max)
+ finish_gathered_sparse_mla_attention(
+ comp_max,
+ comp_denom,
+ comp_acc,
+ comp_output,
+ comp_lse,
+ )
+ finish_gathered_sparse_mla_attention(
+ swa_max,
+ swa_denom,
+ swa_acc,
+ swa_output,
+ swa_lse,
+ )
+ expected = torch.empty(4, 3, 17, device="cuda", dtype=torch.bfloat16)
+ merge_two_sparse_mla_subsets_with_sink(
+ subset0_output=comp_output,
+ subset0_lse=comp_lse,
+ subset1_output=swa_output,
+ subset1_lse=swa_lse,
+ attn_sink=sink[:3],
+ output=expected,
+ )
+
+ torch.testing.assert_close(
+ output[:, :3].float(), expected.float(), rtol=1e-2, atol=1e-2
+ )
+ 
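# Clarifying comment (added, not in the original test): head rows beyond the
+ # merged subset must keep the -7.0 sentinel fill, showing the fused finish
+ # never writes outside the active head range.
+ 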
torch.testing.assert_close( + output[:, 3:].float(), + torch.full_like(output[:, 3:].float(), -7.0), + rtol=0, + atol=0, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_dim", [16, 512]) +def test_triton_gathered_attention_chunk_matches_reference(head_dim: int) -> None: + torch.manual_seed(6) + scale = 0.125 + q = torch.randn(2, 1, 5, head_dim, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :, :3] + kv = torch.randn(2, 5, head_dim, device="cuda", dtype=torch.bfloat16) + slot_ids = torch.tensor( + [ + [0, 1, -1, 3, 4], + [5, -1, 7, 8, -1], + ], + dtype=torch.int32, + device="cuda", + ) + lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + max_score = torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.zeros((2, 3, head_dim), device="cuda") + + accumulate_gathered_sparse_mla_attention_chunk( + q=q, + kv=kv[:, :2], + slot_ids=slot_ids[:, :2], + lens=lens, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + accumulate_gathered_sparse_mla_attention_chunk( + q=q, + kv=kv[:, 2:], + slot_ids=slot_ids[:, 2:], + lens=lens, + candidate_offset=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + offsets = torch.arange(slot_ids.shape[1], device="cuda") + valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0) + expected_output, expected_lse = reference_attention_no_sink( + q_active, + kv, + valid_tokens, + scale, + ) + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_gathered_attention_chunk_matches_reference_without_slot_ids() -> None: + torch.manual_seed(8) + scale = 0.2 + q = torch.randn(3, 1, 2, 32, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(3, 6, 32, device="cuda", dtype=torch.bfloat16) + lens = torch.tensor([6, 3, 0], dtype=torch.int32, device="cuda") + max_score = torch.full((3, 2), float("-inf"), device="cuda") + denom = torch.zeros((3, 2), device="cuda") + acc = torch.zeros((3, 2, 32), device="cuda") + + accumulate_gathered_sparse_mla_attention_chunk( + q=q, + kv=kv, + slot_ids=None, + lens=lens, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + offsets = torch.arange(kv.shape[1], device="cuda") + valid_tokens = offsets[None, :] < lens[:, None] + expected_output, expected_lse = reference_attention_no_sink( + q, + kv, + valid_tokens, + scale, + ) + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_dequantize_global_slots_k_cache_fp8_ds_mla_layout() -> None: + block_size = 4 + num_blocks = 2 + k_cache = torch.zeros( + num_blocks, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + expected_by_slot = { + slot: _write_fp8_ds_mla_token(k_cache, slot, 
block_size) for slot in (0, 3, 4) + } + slot_ids = torch.tensor( + [ + [0, 3, -1, 4], + [4, 0, 3, -1], + ], + dtype=torch.int32, + device="cuda", + ) + + output = torch.empty(2, 4, 512, dtype=torch.bfloat16, device="cuda") + dequantize_global_slots_k_cache(output, k_cache, slot_ids, block_size) + + expected = torch.zeros_like(output) + for token_idx in range(slot_ids.shape[0]): + for topk_idx in range(slot_ids.shape[1]): + slot = int(slot_ids[token_idx, topk_idx].item()) + if slot >= 0: + expected[token_idx, topk_idx] = expected_by_slot[slot] + + torch.testing.assert_close(output.float(), expected.float(), rtol=0, atol=0) + + output_from_3d_indices = torch.empty_like(output) + dequantize_global_slots_k_cache( + output_from_3d_indices, + k_cache, + slot_ids.unsqueeze(1), + block_size, + ) + torch.testing.assert_close( + output_from_3d_indices.float(), + expected.float(), + rtol=0, + atol=0, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_dequantize_combined_sparse_mla_decode_kv_writes_direct_views() -> None: + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 2, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 3, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + for slot in (0, 3, 4): + _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size) + for slot in (0, 1, 2, 3, 4): + _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size) + + compressed_slot_ids = torch.tensor( + [[0, 3, -1], [4, 0, 3]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([5, 7], dtype=torch.int32, device="cuda") + swa_lens = torch.tensor([2, 3], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [[0, 1, 2], [2, 0, 1]], + dtype=torch.int32, + device="cuda", + ) + + combined = torch.full( + (2, 6, 512), + -7, + dtype=torch.bfloat16, + device="cuda", + ) + dequantize_combined_sparse_mla_decode_kv( + combined, + compressed_cache, + compressed_slot_ids, + compressed_block_size, + swa_cache, + seq_lens, + swa_lens, + block_table, + swa_block_size, + ) + + expected_comp = torch.empty(2, 3, 512, dtype=torch.bfloat16, device="cuda") + expected_swa = torch.full( + (2, 3, 512), + -7, + dtype=torch.bfloat16, + device="cuda", + ) + dequantize_global_slots_k_cache( + expected_comp, + compressed_cache, + compressed_slot_ids, + compressed_block_size, + ) + dequantize_and_gather_k_cache( + expected_swa, + swa_cache, + seq_lens=seq_lens, + gather_lens=swa_lens, + block_table=block_table, + block_size=swa_block_size, + offset=0, + ) + expected = torch.full_like(combined, -7) + expected[:, :3].copy_(expected_comp) + expected[:, 3:].copy_(expected_swa) + + torch.testing.assert_close(combined.float(), expected.float(), rtol=0, atol=0) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_fp8ds_global_slots_attention_chunk_matches_reference() -> None: + torch.manual_seed(10) + block_size = 4 + num_blocks = 3 + k_cache = torch.zeros( + num_blocks, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + expected_by_slot = { + slot: _write_fp8_ds_mla_token(k_cache, slot, block_size) + for slot in (0, 1, 3, 4, 7, 8) + } + slot_ids = torch.tensor( + [ + [0, 3, -1, 8, 1], + [7, -1, 4, 0, 8], + ], + dtype=torch.int32, + device="cuda", + ) + lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 3, 
512, device="cuda", dtype=torch.bfloat16) + scale = 0.0625 + + max_score = torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.zeros((2, 3, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, :2], + lens=lens, + block_size=block_size, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, 2:], + lens=lens, + block_size=block_size, + candidate_offset=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16) + for token_idx in range(slot_ids.shape[0]): + for topk_idx in range(slot_ids.shape[1]): + slot = int(slot_ids[token_idx, topk_idx].item()) + if slot >= 0: + gathered[token_idx, topk_idx] = expected_by_slot[slot] + offsets = torch.arange(slot_ids.shape[1], device="cuda") + valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0) + expected_output, expected_lse = reference_attention_no_sink( + q, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_global_slots_multihead_attention_matches_reference( + head_block_size: int, +) -> None: + torch.manual_seed(19) + block_size = 4 + num_blocks = 3 + k_cache = torch.zeros( + num_blocks, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + expected_by_slot = { + slot: _write_fp8_ds_mla_token(k_cache, slot, block_size) + for slot in (0, 1, 3, 4, 7, 8) + } + slot_ids = torch.tensor( + [ + [0, 3, -1, 8, 1], + [7, -1, 4, 0, 8], + ], + dtype=torch.int32, + device="cuda", + ) + lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :, :5] + scale = 0.0625 + + max_score = torch.full((2, 5), float("-inf"), device="cuda") + denom = torch.zeros((2, 5), device="cuda") + acc = torch.zeros((2, 5, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, :2], + lens=lens, + block_size=block_size, + candidate_offset=0, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=head_block_size, + ) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + slot_ids=slot_ids[:, 2:], + lens=lens, + block_size=block_size, + candidate_offset=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=head_block_size, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16) + for token_idx in range(slot_ids.shape[0]): + for topk_idx in range(slot_ids.shape[1]): + slot = int(slot_ids[token_idx, topk_idx].item()) + if 
slot >= 0: + gathered[token_idx, topk_idx] = expected_by_slot[slot] + offsets = torch.arange(slot_ids.shape[1], device="cuda") + valid_tokens = (offsets[None, :] < lens[:, None]) & (slot_ids >= 0) + expected_output, expected_lse = reference_attention_no_sink( + q_active, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_fp8ds_paged_attention_chunk_matches_reference() -> None: + torch.manual_seed(12) + block_size = 4 + k_cache = torch.zeros( + 3, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [ + [1, 0, 2], + [2, 1, 0], + ], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([6, 9], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 4], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 3, 512, device="cuda", dtype=torch.bfloat16) + scale = 0.0625 + + gathered = torch.zeros(2, 4, 512, device="cuda", dtype=torch.bfloat16) + expected_by_slot: dict[int, torch.Tensor] = {} + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + block_idx = pos // block_size + block_offset = pos % block_size + physical_block = int(block_table[token_idx, block_idx].item()) + slot = physical_block * block_size + block_offset + expected_by_slot.setdefault( + slot, + _write_fp8_ds_mla_token(k_cache, slot, block_size), + ) + gathered[token_idx, gather_idx] = expected_by_slot[slot] + + max_score = torch.full((2, 3), float("-inf"), device="cuda") + denom = torch.zeros((2, 3), device="cuda") + acc = torch.zeros((2, 3, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=2, + num_candidates=2, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + + output = torch.empty_like(acc) + lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=output, + lse=lse, + ) + + offsets = torch.arange(gathered.shape[1], device="cuda") + valid_tokens = offsets[None, :] < gather_lens[:, None] + expected_output, expected_lse = reference_attention_no_sink( + q, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_paged_multihead_attention_matches_singlehead_and_reference( + head_block_size: int, +) -> None: + torch.manual_seed(23) + block_size = 4 + k_cache = torch.zeros( + 4, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [ + [1, 0, 2, 3], + [2, 
3, 1, 0], + ], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :, :5] + scale = 0.0625 + + gathered = torch.zeros(2, 5, 512, device="cuda", dtype=torch.bfloat16) + expected_by_slot: dict[int, torch.Tensor] = {} + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + block_idx = pos // block_size + block_offset = pos % block_size + physical_block = int(block_table[token_idx, block_idx].item()) + slot = physical_block * block_size + block_offset + expected_by_slot.setdefault( + slot, + _write_fp8_ds_mla_token(k_cache, slot, block_size), + ) + gathered[token_idx, gather_idx] = expected_by_slot[slot] + + single_max = torch.full((2, 5), float("-inf"), device="cuda") + single_denom = torch.zeros((2, 5), device="cuda") + single_acc = torch.zeros((2, 5, 512), device="cuda") + multi_max = torch.full_like(single_max, float("-inf")) + multi_denom = torch.zeros_like(single_denom) + multi_acc = torch.zeros_like(single_acc) + + for candidate_offset, num_candidates in ((0, 2), (2, 3)): + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=candidate_offset, + num_candidates=num_candidates, + scale=scale, + max_score=single_max, + denom=single_denom, + acc=single_acc, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=candidate_offset, + num_candidates=num_candidates, + scale=scale, + max_score=multi_max, + denom=multi_denom, + acc=multi_acc, + head_block_size=head_block_size, + ) + + torch.testing.assert_close(multi_max, single_max, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(multi_denom, single_denom, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(multi_acc, single_acc, rtol=2e-2, atol=2e-2) + + output = torch.empty_like(multi_acc) + lse = torch.empty_like(multi_max) + finish_gathered_sparse_mla_attention( + max_score=multi_max, + denom=multi_denom, + acc=multi_acc, + output=output, + lse=lse, + ) + offsets = torch.arange(gathered.shape[1], device="cuda") + valid_tokens = offsets[None, :] < gather_lens[:, None] + expected_output, expected_lse = reference_attention_no_sink( + q_active, + gathered, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=2e-2, atol=2e-2) + torch.testing.assert_close(lse, expected_lse, rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_fp8ds_paged_attention_with_sink_matches_reference() -> None: + torch.manual_seed(15) + block_size = 4 + k_cache = torch.zeros( + 3, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor([[1, 0, 2]], dtype=torch.int32, device="cuda") + seq_lens = torch.tensor([7], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([4], dtype=torch.int32, device="cuda") + q = torch.randn(1, 1, 3, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.tensor([-0.25, 0.5, 1.25], device="cuda") + scale = 
0.0625 + + gathered = torch.zeros(1, 4, 512, device="cuda", dtype=torch.bfloat16) + expected_by_slot: dict[int, torch.Tensor] = {} + start_pos = int(seq_lens[0].item() - gather_lens[0].item()) + for gather_idx in range(int(gather_lens[0].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[0, pos // block_size].item()) + slot = physical_block * block_size + pos % block_size + expected_by_slot.setdefault( + slot, + _write_fp8_ds_mla_token(k_cache, slot, block_size), + ) + gathered[0, gather_idx] = expected_by_slot[slot] + + max_score = torch.full((1, 3), float("-inf"), device="cuda") + denom = torch.zeros((1, 3), device="cuda") + acc = torch.zeros((1, 3, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=4, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + subset_output = torch.empty_like(acc) + subset_lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=subset_output, + lse=subset_lse, + ) + + output = torch.empty(1, 3, 512, device="cuda", dtype=torch.bfloat16) + merge_sparse_mla_subset_with_sink( + subset_output=subset_output, + subset_lse=subset_lse, + attn_sink=sink, + output=output, + ) + valid_tokens = torch.ones(1, 4, device="cuda", dtype=torch.bool) + expected = _golden_sink_attention(q, gathered, valid_tokens, scale, sink) + + torch.testing.assert_close(output.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_paged_attention_with_sink_direct_matches_state_path( + head_block_size: int, +) -> None: + torch.manual_seed(29) + block_size = 4 + k_cache = torch.zeros( + 4, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-0.5, 0.5, 5, device="cuda") + scale = 0.0625 + + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // block_size].item()) + slot = physical_block * block_size + pos % block_size + _write_fp8_ds_mla_token(k_cache, slot, block_size) + + max_score = torch.full((2, 5), float("-inf"), device="cuda") + denom = torch.zeros((2, 5), device="cuda") + acc = torch.zeros((2, 5, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=1, + ) + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, expected) + + actual = torch.empty_like(expected) + 
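# Clarifying comment (added, not in the original test): the single fused
+ # launch should reproduce the accumulate-then-finish state path above,
+ # up to bf16 rounding (hence the 2e-2 tolerances below).
+ 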
fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=head_block_size, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("head_block_size", [1, 2, 4]) +def test_triton_fp8ds_global_paged_attention_with_sink_direct_matches_state_path( + head_block_size: int, +) -> None: + torch.manual_seed(31) + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 4, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 4, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + slot_ids = torch.tensor( + [[0, 3, -1, 8, 1], [7, -1, 4, 0, 8]], + dtype=torch.int32, + device="cuda", + ) + topk_lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-1.0, 1.0, 5, device="cuda") + scale = 0.0625 + + for slot in (0, 1, 3, 4, 7, 8): + _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size) + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // swa_block_size].item()) + slot = physical_block * swa_block_size + pos % swa_block_size + _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size) + + comp_max = torch.full((2, 5), float("-inf"), device="cuda") + comp_denom = torch.zeros((2, 5), device="cuda") + comp_acc = torch.zeros((2, 5, 512), device="cuda") + swa_max = torch.full((2, 5), float("-inf"), device="cuda") + swa_denom = torch.zeros((2, 5), device="cuda") + swa_acc = torch.zeros((2, 5, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_cache, + slot_ids=slot_ids, + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=0, + scale=scale, + max_score=comp_max, + denom=comp_denom, + acc=comp_acc, + head_block_size=1, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=swa_block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=swa_max, + denom=swa_denom, + acc=swa_acc, + head_block_size=1, + ) + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + finish_two_sparse_mla_attention_states_with_sink( + comp_max, + comp_denom, + comp_acc, + swa_max, + swa_denom, + swa_acc, + sink, + expected, + ) + + actual = torch.empty_like(expected) + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_cache, + slot_ids=slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_cache, + seq_lens=seq_lens, + 
gather_lens=gather_lens, + block_table=block_table, + swa_block_size=swa_block_size, + num_compressed_candidates=5, + num_swa_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=head_block_size, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_global_paged_decode_matches_dense_golden_long_offsets() -> None: + torch.manual_seed(71) + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 8, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 8, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + compressed_slot_ids = torch.tensor( + [ + [0, 5, -1, 11, 17, 2], + [23, -1, 8, 0, 19, 31], + [4, -1, -1, 6, 7, 12], + ], + dtype=torch.int32, + device="cuda", + ) + topk_lens = torch.tensor([5, 6, 0], dtype=torch.int32, device="cuda") + seq_lens = torch.tensor([17, 30, 7], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([5, 6, 4], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [ + [4, 0, 6, 1, 3, 5, 2, 7], + [7, 2, 4, 0, 6, 1, 5, 3], + [1, 3, 0, 2, 4, 5, 6, 7], + ], + dtype=torch.int32, + device="cuda", + ) + q = torch.randn(3, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + active_heads = 5 + sink = torch.linspace(-0.75, 0.75, active_heads, device="cuda") + scale = 0.0625 + + compressed_kv = _materialize_global_fp8_ds_mla_slots( + compressed_cache, + compressed_slot_ids, + compressed_block_size, + ) + swa_kv = _materialize_paged_fp8_ds_mla_window( + swa_cache, + seq_lens, + gather_lens, + block_table, + swa_block_size, + ) + compressed_offsets = torch.arange( + compressed_slot_ids.shape[1], + device="cuda", + dtype=torch.int32, + ) + swa_offsets = torch.arange(swa_kv.shape[1], device="cuda", dtype=torch.int32) + compressed_valid = (compressed_offsets[None, :] < topk_lens[:, None]) & ( + compressed_slot_ids >= 0 + ) + swa_valid = swa_offsets[None, :] < gather_lens[:, None] + expected = _golden_sink_attention( + q[:, 0, :active_heads], + torch.cat([compressed_kv, swa_kv], dim=1), + torch.cat([compressed_valid, swa_valid], dim=1), + scale, + sink, + ).to(torch.bfloat16) + + actual = torch.empty(3, active_heads, 512, device="cuda", dtype=torch.bfloat16) + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_cache, + slot_ids=compressed_slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + swa_block_size=swa_block_size, + num_compressed_candidates=compressed_slot_ids.shape[1], + num_swa_candidates=swa_kv.shape[1], + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=2, + num_heads=active_heads, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_matmul_sparse_mla_attention_with_sink_matches_reference() -> None: + torch.manual_seed(41) + q = torch.randn(2, 1, 5, 512, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(2, 7, 512, device="cuda", dtype=torch.bfloat16) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, True, True], + [False, True, True, False, True, False, False], + ], + dtype=torch.bool, + device="cuda", + ) + sink = 
torch.linspace(-0.25, 0.25, 5, device="cuda") + scale = 0.0625 + + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + sink_aware_reference_attention( + q, + kv, + valid_tokens, + scale, + sink, + expected, + ) + + actual = torch.empty_like(expected) + matmul_sparse_mla_attention_with_sink( + q, + kv, + valid_tokens, + scale, + sink, + actual, + num_heads=5, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_mtp_matmul_global_slots_decode_matches_dense_golden_long_offsets() -> None: + torch.manual_seed(73) + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 8, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 8, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + compressed_slot_ids = torch.tensor( + [ + [0, 5, -1, 11, 17, 2], + [23, -1, 8, 0, 19, 31], + [4, -1, -1, 6, 7, 12], + ], + dtype=torch.int32, + device="cuda", + ) + topk_lens = torch.tensor([5, 6, 0], dtype=torch.int32, device="cuda") + swa_slot_ids = torch.tensor( + [ + [3, 4, 13, 14, 15], + [28, 29, 30, 31, -1], + [7, 6, 5, -1, -1], + ], + dtype=torch.int32, + device="cuda", + ) + swa_lens = torch.tensor([5, 4, 3], dtype=torch.int32, device="cuda") + q = torch.randn(3, 1, 8, 512, device="cuda", dtype=torch.bfloat16) + active_heads = 5 + sink = torch.linspace(-0.5, 0.5, active_heads, device="cuda") + scale = 0.0625 + + compressed_kv = _materialize_global_fp8_ds_mla_slots( + compressed_cache, + compressed_slot_ids, + compressed_block_size, + ) + swa_kv = _materialize_global_fp8_ds_mla_slots( + swa_cache, + swa_slot_ids, + swa_block_size, + ) + combined_kv = torch.empty( + 3, + compressed_slot_ids.shape[1] + swa_slot_ids.shape[1], + 512, + dtype=torch.bfloat16, + device="cuda", + ) + dequantize_global_slots_k_cache( + combined_kv[:, : compressed_slot_ids.shape[1]], + compressed_cache, + compressed_slot_ids, + compressed_block_size, + ) + dequantize_global_slots_k_cache( + combined_kv[:, compressed_slot_ids.shape[1] :], + swa_cache, + swa_slot_ids, + swa_block_size, + ) + valid_tokens = torch.empty( + combined_kv.shape[:2], + dtype=torch.bool, + device="cuda", + ) + build_combined_sparse_mla_decode_valid_mask( + valid_tokens, + compressed_slot_ids, + topk_lens, + swa_lens, + ) + + compressed_offsets = torch.arange( + compressed_slot_ids.shape[1], + device="cuda", + dtype=torch.int32, + ) + swa_offsets = torch.arange(swa_slot_ids.shape[1], device="cuda", dtype=torch.int32) + compressed_valid = (compressed_offsets[None, :] < topk_lens[:, None]) & ( + compressed_slot_ids >= 0 + ) + swa_valid = swa_offsets[None, :] < swa_lens[:, None] + expected_kv = torch.cat([compressed_kv, swa_kv], dim=1) + expected_valid = torch.cat([compressed_valid, swa_valid], dim=1) + expected = _golden_sink_attention( + q[:, 0, :active_heads], + expected_kv, + expected_valid, + scale, + sink, + ).to(torch.bfloat16) + + torch.testing.assert_close(combined_kv.float(), expected_kv.float(), rtol=0, atol=0) + torch.testing.assert_close(valid_tokens, expected_valid) + + actual = torch.empty_like(expected) + matmul_sparse_mla_attention_with_sink( + q, + combined_kv, + valid_tokens, + scale, + sink, + actual, + num_heads=active_heads, + value_block_size=512, + candidate_block_size=128, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, 
atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_matmul_sparse_mla_attention_accepts_bf16_score_buffer() -> None: + torch.manual_seed(67) + q = torch.randn(2, 1, 5, 512, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(2, 7, 512, device="cuda", dtype=torch.bfloat16) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, True, True], + [False, True, True, False, True, False, False], + ], + dtype=torch.bool, + device="cuda", + ) + sink = torch.linspace(-0.25, 0.25, 5, device="cuda") + scale = 0.0625 + + expected = torch.empty(2, 5, 512, device="cuda", dtype=torch.bfloat16) + sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, expected) + + actual = torch.empty_like(expected) + score_buffer = torch.empty(2, 5, 7, device="cuda", dtype=torch.bfloat16) + matmul_sparse_mla_attention_with_sink( + q, + kv, + valid_tokens, + scale, + sink, + actual, + num_heads=5, + score_buffer=score_buffer, + value_block_size=512, + candidate_block_size=128, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize( + ("candidate_block_size", "value_block_size"), + [(32, 128), (64, 128), (64, 256), (128, 512)], +) +def test_finish_materialized_scores_candidate_block_matches_reference( + candidate_block_size: int, + value_block_size: int, +) -> None: + torch.manual_seed(61) + q = torch.randn(3, 1, 7, 512, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(3, 13, 512, device="cuda", dtype=torch.bfloat16) + valid_tokens = torch.tensor( + [ + [ + True, + True, + False, + True, + False, + True, + True, + False, + True, + True, + False, + True, + False, + ], + [ + False, + True, + True, + False, + True, + False, + False, + True, + False, + True, + True, + False, + True, + ], + [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + ], + dtype=torch.bool, + device="cuda", + ) + sink = torch.linspace(-0.25, 0.25, 7, device="cuda") + scale = 0.0625 + + expected = torch.empty(3, 7, 512, device="cuda", dtype=torch.bfloat16) + sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, expected) + + scores = torch.bmm(q[:, 0].float(), kv.float().transpose(1, 2)) + scores.mul_(scale) + actual = torch.empty_like(expected) + finish_materialized_sparse_mla_scores_with_sink( + scores, + kv, + valid_tokens, + sink, + actual, + num_heads=7, + value_block_size=value_block_size, + candidate_block_size=candidate_block_size, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("value_block_size", [128, 256]) +def test_finish_materialized_scores_value_block_matches_reference( + value_block_size: int, +) -> None: + torch.manual_seed(53) + q = torch.randn(3, 1, 7, 512, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(3, 11, 512, device="cuda", dtype=torch.bfloat16) + valid_tokens = torch.tensor( + [ + [True, True, False, True, False, True, True, False, True, True, False], + [False, True, True, False, True, False, False, True, False, True, True], + [ + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + False, + ], + ], + dtype=torch.bool, + device="cuda", + ) + sink = torch.linspace(-0.25, 0.25, 7, device="cuda") + scale = 0.0625 + + expected = torch.empty(3, 7, 
512, device="cuda", dtype=torch.bfloat16) + sink_aware_reference_attention(q, kv, valid_tokens, scale, sink, expected) + + scores = torch.bmm( + q[:, 0].float(), + kv.float().transpose(1, 2), + ) + scores.mul_(scale) + actual = torch.empty_like(expected) + finish_materialized_sparse_mla_scores_with_sink( + scores, + kv, + valid_tokens, + sink, + actual, + num_heads=7, + value_block_size=value_block_size, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_build_combined_sparse_mla_decode_valid_mask_matches_torch() -> None: + compressed_slot_ids = torch.tensor( + [ + [7, 4, -1, 9, 11], + [2, -1, 3, 8, 10], + [-1, -1, -1, -1, -1], + ], + device="cuda", + dtype=torch.int32, + ) + topk_lens = torch.tensor([4, 3, 0], device="cuda", dtype=torch.int32) + swa_lens = torch.tensor([3, 1, 0], device="cuda", dtype=torch.int32) + valid_tokens = torch.empty(3, 9, device="cuda", dtype=torch.bool) + + build_combined_sparse_mla_decode_valid_mask( + valid_tokens, + compressed_slot_ids, + topk_lens, + swa_lens, + ) + + comp_offsets = torch.arange(5, device="cuda", dtype=torch.int32) + swa_offsets = torch.arange(4, device="cuda", dtype=torch.int32) + expected = torch.empty_like(valid_tokens) + expected[:, :5] = (comp_offsets[None, :] < topk_lens[:, None]) & ( + compressed_slot_ids >= 0 + ) + expected[:, 5:] = swa_offsets[None, :] < swa_lens[:, None] + + torch.testing.assert_close(valid_tokens, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("num_heads", [8, 16, 32, 64]) +def test_triton_fp8ds_paged_attention_with_sink_supports_tp_local_heads( + num_heads: int, +) -> None: + torch.manual_seed(37 + num_heads) + block_size = 4 + k_cache = torch.zeros( + 4, + block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, num_heads, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-0.5, 0.5, num_heads, device="cuda") + scale = 0.0625 + + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // block_size].item()) + slot = physical_block * block_size + pos % block_size + _write_fp8_ds_mla_token(k_cache, slot, block_size) + + max_score = torch.full((2, num_heads), float("-inf"), device="cuda") + denom = torch.zeros((2, num_heads), device="cuda") + acc = torch.zeros((2, num_heads, 512), device="cuda") + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + head_block_size=1, + ) + expected = torch.empty(2, num_heads, 512, device="cuda", dtype=torch.bfloat16) + finish_sparse_mla_attention_with_sink(max_score, denom, acc, sink, expected) + + actual = torch.empty_like(expected) + fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + k_cache=k_cache, + seq_lens=seq_lens, + 
gather_lens=gather_lens, + block_table=block_table, + block_size=block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=4, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +@pytest.mark.parametrize("num_heads", [8, 16, 32, 64]) +def test_triton_fp8ds_global_paged_attention_with_sink_supports_tp_local_heads( + num_heads: int, +) -> None: + torch.manual_seed(41 + num_heads) + compressed_block_size = 4 + swa_block_size = 4 + compressed_cache = torch.zeros( + 4, + compressed_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + swa_cache = torch.zeros( + 4, + swa_block_size, + _TOKEN_DATA_SIZE + _SCALE_DIM, + dtype=torch.uint8, + device="cuda", + ) + slot_ids = torch.tensor( + [[0, 3, -1, 8, 1], [7, -1, 4, 0, 8]], + dtype=torch.int32, + device="cuda", + ) + topk_lens = torch.tensor([4, 5], dtype=torch.int32, device="cuda") + block_table = torch.tensor( + [[1, 0, 2, 3], [2, 3, 1, 0]], + dtype=torch.int32, + device="cuda", + ) + seq_lens = torch.tensor([7, 11], dtype=torch.int32, device="cuda") + gather_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda") + q = torch.randn(2, 1, num_heads, 512, device="cuda", dtype=torch.bfloat16) + sink = torch.linspace(-1.0, 1.0, num_heads, device="cuda") + scale = 0.0625 + + for slot in (0, 1, 3, 4, 7, 8): + _write_fp8_ds_mla_token(compressed_cache, slot, compressed_block_size) + for token_idx in range(seq_lens.shape[0]): + start_pos = int(seq_lens[token_idx].item() - gather_lens[token_idx].item()) + for gather_idx in range(int(gather_lens[token_idx].item())): + pos = start_pos + gather_idx + physical_block = int(block_table[token_idx, pos // swa_block_size].item()) + slot = physical_block * swa_block_size + pos % swa_block_size + _write_fp8_ds_mla_token(swa_cache, slot, swa_block_size) + + comp_max = torch.full((2, num_heads), float("-inf"), device="cuda") + comp_denom = torch.zeros((2, num_heads), device="cuda") + comp_acc = torch.zeros((2, num_heads, 512), device="cuda") + swa_max = torch.full((2, num_heads), float("-inf"), device="cuda") + swa_denom = torch.zeros((2, num_heads), device="cuda") + swa_acc = torch.zeros((2, num_heads, 512), device="cuda") + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_cache, + slot_ids=slot_ids, + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=0, + scale=scale, + max_score=comp_max, + denom=comp_denom, + acc=comp_acc, + head_block_size=1, + ) + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, + block_size=swa_block_size, + candidate_offset=0, + num_candidates=5, + scale=scale, + max_score=swa_max, + denom=swa_denom, + acc=swa_acc, + head_block_size=1, + ) + expected = torch.empty(2, num_heads, 512, device="cuda", dtype=torch.bfloat16) + finish_two_sparse_mla_attention_states_with_sink( + comp_max, + comp_denom, + comp_acc, + swa_max, + swa_denom, + swa_acc, + sink, + expected, + ) + + actual = torch.empty_like(expected) + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_cache, + slot_ids=slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_cache, + seq_lens=seq_lens, + gather_lens=gather_lens, + block_table=block_table, 
+ swa_block_size=swa_block_size, + num_compressed_candidates=5, + num_swa_candidates=5, + scale=scale, + attn_sink=sink, + output=actual, + head_block_size=4, + ) + + torch.testing.assert_close(actual.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA only") +def test_triton_indexed_bf16_prefill_chunks_match_reference() -> None: + torch.manual_seed(17) + q = torch.randn(5, 5, 16, device="cuda", dtype=torch.bfloat16) + q_active = q[:, :3] + kv = torch.randn(2, 7, 16, device="cuda", dtype=torch.bfloat16) + kv_flat = kv.reshape(-1, q.shape[-1]) + combined_indices = torch.tensor( + [ + [0, 3, -1, 5, 3, 1], + [4, -1, 2, 2, 1, 8], + [-1, -1, -1, -1, -1, -1], + [8, 0, 9, -1, 7, 4], + [13, 12, 0, 12, -1, 3], + ], + dtype=torch.int64, + device="cuda", + ) + combined_lens = torch.tensor([5, 4, 0, 6, 5], dtype=torch.int32, device="cuda") + sink = torch.tensor([-0.5, 1.0, 0.25], dtype=torch.float32, device="cuda") + scale = 0.375 + output = torch.empty_like(q_active) + + for token_start in (0, 2, 4): + token_end = min(token_start + 2, q.shape[0]) + q_chunk = q[token_start:token_end] + indices_chunk = combined_indices[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + max_score = torch.full( + (q_chunk.shape[0], q_active.shape[1]), + float("-inf"), + device="cuda", + ) + denom = torch.zeros_like(max_score) + acc = torch.zeros( + q_chunk.shape[0], + q_active.shape[1], + q_chunk.shape[-1], + device="cuda", + dtype=torch.float32, + ) + for index_start in (0, 3): + index_end = min(index_start + 3, combined_indices.shape[-1]) + accumulate_indexed_sparse_mla_attention_chunk( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk[:, index_start:index_end], + lens=lens_chunk, + candidate_offset=index_start, + scale=scale, + max_score=max_score, + denom=denom, + acc=acc, + ) + subset_output = torch.empty_like(acc) + subset_lse = torch.empty_like(max_score) + finish_gathered_sparse_mla_attention( + max_score=max_score, + denom=denom, + acc=acc, + output=subset_output, + lse=subset_lse, + ) + merge_sparse_mla_subset_with_sink( + subset_output=subset_output, + subset_lse=subset_lse, + attn_sink=sink, + output=output[token_start:token_end], + ) + + expected = torch.empty_like(q_active) + reference_sparse_mla_prefill( + q=q_active, + kv=kv, + combined_indices=combined_indices, + combined_lens=combined_lens, + scale=scale, + attn_sink=sink, + output=expected, + topk_chunk_size=3, + query_chunk_size=2, + ) + + torch.testing.assert_close(output.float(), expected.float(), rtol=2e-2, atol=2e-2) + + +@pytest.mark.parametrize( + ("topk_chunk_size", "query_chunk_size"), + [(1, 1), (2, 3), (5, 2)], +) +def test_reference_sparse_mla_prefill_matches_dense_golden( + topk_chunk_size: int, + query_chunk_size: int, +) -> None: + torch.manual_seed(4) + scale = 0.375 + q = torch.randn(4, 2, 3) + kv = torch.randn(2, 5, 3) + combined_indices = torch.tensor( + [ + [0, 3, -1, 5, 3], + [4, -1, 2, 2, 1], + [-1, -1, -1, -1, -1], + [8, 0, 9, -1, 7], + ], + dtype=torch.int64, + ) + combined_lens = torch.tensor([4, 3, 0, 5], dtype=torch.int32) + sink = torch.tensor([-0.5, 1.0]) + output = torch.empty_like(q) + + reference_sparse_mla_prefill( + q=q, + kv=kv, + combined_indices=combined_indices, + combined_lens=combined_lens, + scale=scale, + attn_sink=sink, + output=output, + topk_chunk_size=topk_chunk_size, + query_chunk_size=query_chunk_size, + ) + + kv_flat = kv.reshape(-1, q.shape[-1]) + offsets = torch.arange(combined_indices.shape[-1]) + 
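# Clarifying comment (added, not in the original test): a candidate column
+ # counts only while it lies inside the row's combined length and carries a
+ # non-negative index; -1 entries mark padding.
+ 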
valid_tokens = (offsets[None, :] < combined_lens[:, None]) & (combined_indices >= 0) + safe_indices = torch.where( + valid_tokens, + combined_indices, + torch.zeros((), dtype=combined_indices.dtype), + ).long() + gathered_kv = kv_flat[safe_indices] + expected = _golden_sink_attention(q, gathered_kv, valid_tokens, scale, sink) + + torch.testing.assert_close(output, expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("chunk_size", [1, 2, 5]) +def test_chunked_reference_accumulation_matches_one_shot(chunk_size: int) -> None: + torch.manual_seed(3) + scale = 0.3 + q = torch.randn(3, 2, 4) + kv = torch.randn(3, 9, 4) + valid_tokens = torch.tensor( + [ + [True, False, True, True, False, False, True, False, True], + [False, False, False, False, False, False, False, False, False], + [True, True, True, False, True, False, True, True, False], + ], + dtype=torch.bool, + ) + output, lse = _chunked_no_sink_attention( + q, + kv, + valid_tokens, + scale, + chunk_size, + ) + expected_output, expected_lse = _golden_no_sink_attention( + q, + kv, + valid_tokens, + scale, + ) + + torch.testing.assert_close(output, expected_output, rtol=1e-6, atol=1e-6) + torch.testing.assert_close(lse, expected_lse, rtol=1e-6, atol=1e-6) + + +def test_triton_sparse_mla_path_allows_cudagraph_support_by_default( + monkeypatch, +) -> None: + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1") + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", raising=False) + + mla_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + alignment=576, + compress_ratio=4, + model_version="deepseek_v4", + ) + swa_spec = SlidingWindowMLASpec( + block_size=64, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + sliding_window=128, + cache_dtype_str="fp8_ds_mla", + alignment=576, + model_version="deepseek_v4", + ) + + assert FlashMLASparseMetadataBuilder.get_cudagraph_support(None, mla_spec) is ( + AttentionCGSupport.UNIFORM_BATCH + ) + assert DeepseekSparseSWAMetadataBuilder.get_cudagraph_support(None, swa_spec) is ( + AttentionCGSupport.UNIFORM_BATCH + ) + + vllm_config = SimpleNamespace( + compilation_config=SimpleNamespace( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], + compile_ranges_endpoints=[8192], + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[1, 2, 4], + max_cudagraph_capture_size=4, + ) + ) + disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) + + assert vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE + assert vllm_config.compilation_config.compile_sizes == [1, 2] + assert vllm_config.compilation_config.compile_ranges_endpoints == [8192] + assert ( + vllm_config.compilation_config.cudagraph_mode + == CUDAGraphMode.FULL_AND_PIECEWISE + ) + assert vllm_config.compilation_config.cudagraph_capture_sizes == [1, 2, 4] + assert vllm_config.compilation_config.max_cudagraph_capture_size == 4 + + +def test_triton_sparse_mla_path_can_disable_cudagraphs(monkeypatch) -> None: + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1") + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", "0") + + mla_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + alignment=576, + compress_ratio=4, + model_version="deepseek_v4", + ) + swa_spec = SlidingWindowMLASpec( + block_size=64, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + sliding_window=128, + cache_dtype_str="fp8_ds_mla", + alignment=576, 
+ model_version="deepseek_v4", + ) + + assert FlashMLASparseMetadataBuilder.get_cudagraph_support(None, mla_spec) is ( + AttentionCGSupport.NEVER + ) + assert DeepseekSparseSWAMetadataBuilder.get_cudagraph_support(None, swa_spec) is ( + AttentionCGSupport.NEVER + ) + + vllm_config = SimpleNamespace( + compilation_config=SimpleNamespace( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], + compile_ranges_endpoints=[8192], + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[1, 2, 4], + max_cudagraph_capture_size=4, + ) + ) + disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) + + assert vllm_config.compilation_config.mode == CompilationMode.NONE + assert vllm_config.compilation_config.compile_sizes == [] + assert vllm_config.compilation_config.compile_ranges_endpoints == [] + assert vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + assert vllm_config.compilation_config.cudagraph_capture_sizes == [] + assert vllm_config.compilation_config.max_cudagraph_capture_size == 0 + + +def test_triton_sparse_mla_path_disables_cudagraphs_for_mtp( + monkeypatch, +) -> None: + monkeypatch.setenv("VLLM_TRITON_MLA_SPARSE", "1") + monkeypatch.delenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", raising=False) + + mla_spec = MLAAttentionSpec( + block_size=256, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + alignment=576, + compress_ratio=4, + model_version="deepseek_v4", + ) + swa_spec = SlidingWindowMLASpec( + block_size=64, + num_kv_heads=1, + head_size=512, + dtype=torch.uint8, + sliding_window=128, + cache_dtype_str="fp8_ds_mla", + alignment=576, + model_version="deepseek_v4", + ) + vllm_config = SimpleNamespace( + speculative_config=SimpleNamespace( + method="mtp", + num_speculative_tokens=2, + ), + compilation_config=SimpleNamespace( + mode=CompilationMode.VLLM_COMPILE, + compile_sizes=[1, 2], + compile_ranges_endpoints=[8192], + cudagraph_mode=CUDAGraphMode.FULL_AND_PIECEWISE, + cudagraph_capture_sizes=[1, 2, 4], + max_cudagraph_capture_size=4, + ), + ) + + assert ( + FlashMLASparseMetadataBuilder.get_cudagraph_support( + vllm_config, + mla_spec, + ) + is AttentionCGSupport.NEVER + ) + assert ( + DeepseekSparseSWAMetadataBuilder.get_cudagraph_support( + vllm_config, + swa_spec, + ) + is AttentionCGSupport.NEVER + ) + + disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) + + assert vllm_config.compilation_config.mode == CompilationMode.NONE + assert vllm_config.compilation_config.compile_sizes == [] + assert vllm_config.compilation_config.compile_ranges_endpoints == [] + assert vllm_config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE + assert vllm_config.compilation_config.cudagraph_capture_sizes == [] + assert vllm_config.compilation_config.max_cudagraph_capture_size == 0 diff --git a/tests/v1/attention/test_sm120_deepgemm_fallbacks.py b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py new file mode 100644 index 000000000000..3fae48146f2d --- /dev/null +++ b/tests/v1/attention/test_sm120_deepgemm_fallbacks.py @@ -0,0 +1,245 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +import vllm.utils.deep_gemm as deep_gemm_utils +from vllm.model_executor.layers.sparse_attn_indexer import ( + _decode_logits_width, + _decode_topk_logits_width, + _sparse_indexer_requires_deep_gemm, +) +from vllm.platforms import current_platform +from vllm.utils.math_utils import cdiv +from 
vllm.v1.attention.backends.mla import indexer as mla_indexer + + +def test_decode_logits_width_uses_active_context_bound(): + assert _decode_logits_width(262144, 1024) == 1024 + assert _decode_logits_width(4096, 8192) == 4096 + assert _decode_logits_width(4096, 0) == 4096 + assert _decode_logits_width(0, 1024) == 0 + + +def test_decode_topk_logits_width_keeps_topk_kernel_width(): + assert _decode_topk_logits_width(262144, 1024, 512) == 1024 + assert _decode_topk_logits_width(262144, 128, 512) == 512 + assert _decode_topk_logits_width(300, 128, 512) == 300 + assert _decode_topk_logits_width(0, 128, 512) == 0 + + +def test_sm120_sparse_indexer_does_not_require_deep_gemm(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + + assert _sparse_indexer_requires_deep_gemm() is False + + +def test_non_sm120_cuda_sparse_indexer_still_requires_deep_gemm(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: False, + ) + + assert _sparse_indexer_requires_deep_gemm() is True + + +def test_sm120_paged_mqa_metadata_uses_backend_impl(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + lazy_init_calls = [] + monkeypatch.setattr( + deep_gemm_utils, "_lazy_init", lambda: lazy_init_calls.append(1) + ) + expected = torch.tensor([[1, 2]], dtype=torch.int32) + + def fake_deep_gemm_metadata(context_lens, block_size, num_sms): + assert context_lens.shape == (2, 1) + assert block_size == 256 + assert num_sms == 4 + return expected + + monkeypatch.setattr( + deep_gemm_utils, + "_get_paged_mqa_logits_metadata_impl", + fake_deep_gemm_metadata, + ) + context_lens = torch.tensor([[1], [3]], dtype=torch.int32) + + metadata = deep_gemm_utils.get_paged_mqa_logits_metadata( + context_lens, block_size=256, num_sms=4 + ) + + assert metadata is expected + assert lazy_init_calls == [1] + + +def test_sm120_mla_indexer_skips_deep_gemm_scheduler_metadata(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + monkeypatch.setattr(mla_indexer, "has_deep_gemm", lambda: True) + + assert not mla_indexer._uses_deep_gemm_scheduler_metadata() + + +def test_cuda_mla_indexer_uses_deep_gemm_scheduler_metadata_off_sm12x(monkeypatch): + monkeypatch.setattr(current_platform, "is_cuda", lambda: True) + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: False, + ) + monkeypatch.setattr(mla_indexer, "has_deep_gemm", lambda: True) + + assert mla_indexer._uses_deep_gemm_scheduler_metadata() + + +def test_sm120_fp8_mqa_fallbacks_do_not_initialize_deep_gemm(monkeypatch): + monkeypatch.setattr( + current_platform, + "is_device_capability_family", + lambda capability: capability == 120, + ) + + def fail_lazy_init(): + raise AssertionError("SM120 FP8 MQA should not initialize DeepGEMM") + + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", fail_lazy_init) + + mqa_result = torch.empty(1) + paged_result = torch.empty(1) + calls = [] + + def fake_mqa_fallback(*args, **kwargs): + calls.append("mqa") + return mqa_result + + def fake_paged_fallback(*args, **kwargs): + calls.append("paged") + return paged_result + + 
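+    # Route both public wrappers to the recording stubs so the assertions
+    # below can verify that SM12x dispatch never initializes DeepGEMM.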
monkeypatch.setattr(deep_gemm_utils, "_fp8_mqa_logits_sm12x", fake_mqa_fallback) + monkeypatch.setattr( + deep_gemm_utils, "_fp8_paged_mqa_logits_sm12x", fake_paged_fallback + ) + + assert ( + deep_gemm_utils.fp8_fp4_mqa_logits( + (torch.empty(1, 1, 1), None), + (torch.empty(1, 1), torch.empty(1)), + torch.empty(1, 1), + torch.empty(1, dtype=torch.int32), + torch.empty(1, dtype=torch.int32), + clean_logits=False, + ) + is mqa_result + ) + assert ( + deep_gemm_utils.fp8_fp4_paged_mqa_logits( + (torch.empty(1, 1, 1, 1), None), + torch.empty(1, 1, 1, 5, dtype=torch.uint8), + torch.empty(1, 1), + torch.empty(1, 1, dtype=torch.int32), + torch.empty(1, 1, dtype=torch.int32), + torch.empty(1, dtype=torch.int32), + max_model_len=1, + clean_logits=False, + ) + is paged_result + ) + assert calls == ["mqa", "paged"] + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_paged_mqa_direct_topk_matches_truncated_decode_width( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(7) + batch_size, next_n, num_heads, head_dim = 2, 2, 8, 32 + block_size, max_model_len, num_blocks = 4, 64, 16 + active_max_len = 13 + topk_tokens = 6 + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr(deep_gemm_utils, "_SM120_PAGED_MQA_TOPK_CHUNK_SIZE", 7) + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn( + num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = torch.empty( + num_blocks, + block_size, + 1, + head_dim + 4, + device="cuda", + dtype=torch.uint8, + ) + fused_kv[..., :head_dim] = kv_fp8.view(torch.uint8) + fused_kv[..., head_dim:] = kv_scale.contiguous().view(torch.uint8) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor( + [[5, active_max_len], [9, 12]], device="cuda", dtype=torch.int32 + ) + block_tables = ( + torch.arange( + batch_size * cdiv(max_model_len, block_size), + device="cuda", + dtype=torch.int32, + ).reshape(batch_size, -1) + % num_blocks + ) + + full_width_topk = torch.empty( + batch_size * next_n, topk_tokens, device="cuda", dtype=torch.int32 + ) + truncated_width_topk = torch.empty_like(full_width_topk) + + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8, None), + fused_kv, + weights, + context_lens, + block_tables, + max_model_len, + full_width_topk, + ) + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8, None), + fused_kv, + weights, + context_lens, + block_tables, + active_max_len, + truncated_width_topk, + ) + + torch.testing.assert_close(truncated_width_topk, full_width_topk, rtol=0, atol=0) diff --git a/tests/v1/attention/test_sparse_attn_indexer.py b/tests/v1/attention/test_sparse_attn_indexer.py new file mode 100644 index 000000000000..eb09cf058d8f --- /dev/null +++ b/tests/v1/attention/test_sparse_attn_indexer.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.model_executor.layers.sparse_attn_indexer import ( + SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH, + SM120_SHORT_ROW_TOPK_MAX_WIDTH, + _should_use_sm120_short_row_topk_decode, +) + + +@pytest.mark.parametrize( + 
("topk_tokens", "logits_width", "num_rows", "is_cuda_sm120", "expected"), + [ + (512, SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH, 32, True, True), + (512, 8192, 16, True, True), + (512, 8192, 32, True, True), + (512, 12288, 32, True, False), + (512, SM120_SHORT_ROW_TOPK_MAX_WIDTH, 1, True, False), + (512, 4096, 1, False, False), + (2048, 4096, 1, True, False), + ], +) +def test_sm120_short_row_topk_decode_selector( + topk_tokens: int, + logits_width: int, + num_rows: int, + is_cuda_sm120: bool, + expected: bool, +) -> None: + assert ( + _should_use_sm120_short_row_topk_decode( + topk_tokens, + logits_width, + num_rows, + is_cuda_sm120, + ) + is expected + ) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py index 22acc748d24b..8becd568b680 100644 --- a/tests/v1/attention/test_sparse_mla_backends.py +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -8,6 +8,7 @@ import pytest import torch +import vllm.utils.deep_gemm as deep_gemm_utils from tests.v1.attention.test_mla_backends import ( BATCH_SPECS, BatchSpec, @@ -42,9 +43,16 @@ FlashMLASparseBackend, triton_convert_req_index_to_global_index, ) -from vllm.v1.attention.backends.mla.indexer import split_indexer_prefill_chunks +from vllm.v1.attention.backends.mla.indexer import ( + sparse_indexer_max_logits_bytes, + split_indexer_prefill_chunks, +) from vllm.v1.attention.backends.utils import split_prefill_chunks from vllm.v1.attention.ops import flashmla +from vllm.v1.attention.ops.deepseek_v4_ops import ( + combine_topk_swa_indices, + compute_global_topk_indices_and_lens, +) SPARSE_BACKEND_BATCH_SPECS = { name: BATCH_SPECS[name] @@ -67,6 +75,487 @@ DEVICE_TYPE = current_platform.device_type +def _make_packed_fp8_indexer_cache( + kv_fp8: torch.Tensor, + kv_scale: torch.Tensor, +) -> torch.Tensor: + num_blocks, block_size, num_kv_heads, head_dim = kv_fp8.shape + assert num_kv_heads == 1 + kv_scale_bytes = kv_scale.contiguous().view(torch.uint8).reshape( + num_blocks, block_size, num_kv_heads, -1 + ) + scale_bytes = kv_scale_bytes.shape[-1] + fused_kv = torch.empty( + num_blocks, + block_size, + head_dim + scale_bytes, + device=kv_fp8.device, + dtype=torch.uint8, + ) + fused_kv_blocks = fused_kv.view(num_blocks, -1) + value_end = block_size * head_dim + scale_end = value_end + block_size * scale_bytes + fused_kv_blocks[:, :value_end] = kv_fp8.view(torch.uint8).reshape( + num_blocks, -1 + ) + fused_kv_blocks[:, value_end:scale_end] = kv_scale_bytes.reshape( + num_blocks, -1 + ) + return fused_kv + + +def test_sm120_fp8_mqa_logits_chunk_sizes_cap_large_scores(): + assert deep_gemm_utils._fp8_mqa_logits_head_chunk_size(128, 128, 32) == 8 + assert deep_gemm_utils._fp8_mqa_logits_head_chunk_size(8192, 8192, 32) == 1 + assert deep_gemm_utils._fp8_mqa_logits_k_chunk_size(128, 128, 8) == 128 + assert deep_gemm_utils._fp8_mqa_logits_k_chunk_size(8192, 8192, 1) == 2048 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_tf32_hc_prenorm_gemm_fallback_matches_split_abi( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(0) + num_tokens, out_features, hidden_size = 7, 12, 64 + x = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) + fn = torch.randn(out_features, hidden_size, device="cuda", dtype=torch.float32) + + out = torch.empty(num_tokens, out_features, device="cuda", dtype=torch.float32) + sqrsum = torch.empty(num_tokens, device="cuda", dtype=torch.float32) + 
deep_gemm_utils._tf32_hc_prenorm_gemm_torch(x, fn, out, sqrsum, num_split=1) + + expected_out = x.float() @ fn.T + expected_sqrsum = x.float().square().sum(dim=-1) + torch.testing.assert_close(out, expected_out, rtol=0, atol=0) + torch.testing.assert_close(sqrsum, expected_sqrsum, rtol=0, atol=0) + + split_out = torch.empty(3, num_tokens, out_features, device="cuda") + split_sqrsum = torch.empty(3, num_tokens, device="cuda") + deep_gemm_utils._tf32_hc_prenorm_gemm_torch( + x, fn, split_out, split_sqrsum, num_split=3 + ) + torch.testing.assert_close(split_out.sum(dim=0), expected_out, rtol=0, atol=0) + torch.testing.assert_close(split_sqrsum.sum(dim=0), expected_sqrsum, rtol=0, atol=0) + + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr(deep_gemm_utils, "_tf32_hc_prenorm_gemm_impl", None) + wrapper_out = torch.empty_like(split_out) + wrapper_sqrsum = torch.empty_like(split_sqrsum) + deep_gemm_utils.tf32_hc_prenorm_gemm( + x, fn, wrapper_out, wrapper_sqrsum, num_split=3 + ) + torch.testing.assert_close( + wrapper_out.sum(dim=0), expected_out, rtol=2e-2, atol=2e-2 + ) + torch.testing.assert_close( + wrapper_sqrsum.sum(dim=0), expected_sqrsum, rtol=1e-4, atol=1e-4 + ) + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_paged_mqa_logits_fallback_matches_reference( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(1) + batch_size, next_n, num_heads, head_dim = 2, 2, 4, 32 + block_size, max_model_len, num_blocks = 4, 12, 4 + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn) + kv = torch.randn( + num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = _make_packed_fp8_indexer_cache(kv_fp8, kv_scale) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor([[3, 6], [7, 11]], device="cuda", dtype=torch.int32) + block_tables = torch.tensor( + [[0, 1, 2], [1, 2, 3]], device="cuda", dtype=torch.int32 + ) + expected = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device="cuda", + dtype=torch.float32, + ) + kv_dequant = kv_fp8.float() * kv_scale + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + for token_idx in range(int(context_lens[batch_idx, next_idx].item())): + block = int(block_tables[batch_idx, token_idx // block_size].item()) + offset = token_idx % block_size + score = ( + q_fp8[batch_idx, next_idx].float() * kv_dequant[block, offset, 0] + ).sum(dim=1) + expected[row, token_idx] = (score.relu() * weights[row]).sum() + + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr(deep_gemm_utils, "_fp8_fp4_paged_mqa_logits_impl", None) + + def fail_torch_path(*args, **kwargs): + raise AssertionError("torch paged fallback should not be used") + + monkeypatch.setattr(deep_gemm_utils, "_fp8_paged_mqa_logits_torch", fail_torch_path) + actual = deep_gemm_utils.fp8_fp4_paged_mqa_logits( + (q_fp8.contiguous(), None), + fused_kv, + weights, + context_lens, + block_tables, + schedule_metadata=torch.empty(0, device="cuda", dtype=torch.int32), + max_model_len=max_model_len, + clean_logits=False, + ) + torch.testing.assert_close(actual, 
expected, rtol=0, atol=1e-5) + + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + fp8_paged_mqa_logits_triton, + ) + + triton_actual = fp8_paged_mqa_logits_triton( + q_fp8.contiguous(), fused_kv, weights, context_lens, block_tables, max_model_len + ) + assert torch.equal(torch.isneginf(triton_actual), torch.isneginf(expected)) + finite = torch.isfinite(expected) + assert (triton_actual[finite] - expected[finite]).abs().max() < 2e-2 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_paged_mqa_rowwise_logits_matches_reference(): + torch.manual_seed(11) + batch_size, next_n, num_heads, head_dim = 2, 1, 8, 64 + block_size, max_model_len, num_blocks = 4, 18, 8 + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn).contiguous() + kv = torch.randn( + num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = _make_packed_fp8_indexer_cache(kv_fp8, kv_scale) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor([[7], [17]], device="cuda", dtype=torch.int32) + block_tables = ( + torch.arange( + batch_size * cdiv(max_model_len, block_size), + device="cuda", + dtype=torch.int32, + ).reshape(batch_size, -1) + % num_blocks + ) + + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + fp8_paged_mqa_logits_rowwise_triton, + ) + + actual = fp8_paged_mqa_logits_rowwise_triton( + q_fp8, fused_kv, weights, context_lens, block_tables, max_model_len + ) + expected = deep_gemm_utils._fp8_paged_mqa_logits_torch( + (q_fp8, None), fused_kv, weights, context_lens, block_tables, max_model_len + ) + + assert torch.equal(torch.isneginf(actual), torch.isneginf(expected)) + finite = torch.isfinite(expected) + assert (actual[finite] - expected[finite]).abs().max() < 2e-2 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_paged_mqa_topk_indices_streams_chunks( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(3) + batch_size, next_n, num_heads, head_dim = 2, 2, 8, 32 + block_size, max_model_len, num_blocks = 4, 20, 8 + topk_tokens = 5 + monkeypatch.setattr( + deep_gemm_utils, + "_SM120_PAGED_MQA_TOPK_CHUNK_SIZE", + 7, + ) + monkeypatch.setattr( + torch, + "cat", + lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("paged MQA top-k should reuse candidate buffers") + ), + ) + + q = torch.randn( + batch_size, + next_n, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + q_fp8 = q.to(torch.float8_e4m3fn) + kv = torch.randn( + num_blocks, block_size, 1, head_dim, device="cuda", dtype=torch.bfloat16 + ) + kv_scale = kv.abs().float().amax(dim=-1, keepdim=True).clamp(1e-4) / 448.0 + kv_fp8 = (kv * kv_scale.reciprocal()).to(torch.float8_e4m3fn) + fused_kv = _make_packed_fp8_indexer_cache(kv_fp8, kv_scale) + + weights = torch.randn( + batch_size * next_n, num_heads, device="cuda", dtype=torch.float32 + ) + context_lens = torch.tensor([[3, 11], [17, 20]], device="cuda", dtype=torch.int32) + block_tables = ( + torch.arange( + batch_size * cdiv(max_model_len, block_size), + device="cuda", + dtype=torch.int32, + ).reshape(batch_size, -1) + % num_blocks + ) + topk_indices = 
torch.empty( + batch_size * next_n, topk_tokens, device="cuda", dtype=torch.int32 + ) + + assert deep_gemm_utils.fp8_fp4_paged_mqa_topk_indices( + (q_fp8.contiguous(), None), + fused_kv, + weights, + context_lens, + block_tables, + max_model_len, + topk_indices, + ) + + logits = deep_gemm_utils._fp8_paged_mqa_logits_torch( + (q_fp8.contiguous(), None), + fused_kv, + weights, + context_lens, + block_tables, + max_model_len, + ) + expected = torch.full_like(topk_indices, -1) + flat_context_lens = context_lens.reshape(-1) + for row in range(batch_size * next_n): + valid_count = int(flat_context_lens[row].item()) + row_topk = min(topk_tokens, valid_count) + if row_topk > 0: + expected[row, :row_topk] = ( + logits[row].topk(row_topk).indices.to(torch.int32) + ) + + for row in range(batch_size * next_n): + row_topk = min(topk_tokens, int(flat_context_lens[row].item())) + assert set(topk_indices[row, :row_topk].tolist()) == set( + expected[row, :row_topk].tolist() + ) + assert torch.all(topk_indices[row, row_topk:] == -1) + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_mqa_logits_torch_path_streams_head_chunks( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(0) + seq_len, seq_len_kv, num_heads, head_dim = 9, 17, 32, 32 + monkeypatch.setattr( + deep_gemm_utils, + "_SM120_MQA_LOGITS_MAX_SCORE_BYTES", + seq_len * 5 * 4, + ) + + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3 + cu_seqlen_ke = torch.minimum( + torch.arange(seq_len, device="cuda", dtype=torch.int32) + 4, + torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32), + ) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4) + kv_scale = (kv_amax / 448.0).squeeze(1).contiguous() + kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn) + + logits = deep_gemm_utils._fp8_mqa_logits_torch( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + + kv_dequant = kv_fp8.float() * kv_scale[:, None] + score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant) + ref_logits = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0) + offsets = torch.arange(seq_len_kv, device="cuda") + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + ref_logits = ref_logits.masked_fill(~valid, float("-inf")) + + assert torch.equal(torch.isneginf(logits), torch.isneginf(ref_logits)) + finite = torch.isfinite(ref_logits) + assert (logits[finite] - ref_logits[finite]).abs().max() < 1e-4 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_mqa_logits_wrapper_uses_triton_when_deepgemm_missing( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(2) + seq_len, seq_len_kv, num_heads, head_dim = 5, 13, 8, 32 + + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 3 + cu_seqlen_ke = torch.minimum( 
+ cu_seqlen_ks + 6, + torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32), + ) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4) + kv_scale = (kv_amax / 448.0).squeeze(1).contiguous() + kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn) + + kv_dequant = kv_fp8.float() * kv_scale[:, None] + score = torch.einsum("mhd,nd->hmn", q_fp8.float(), kv_dequant) + expected = (score.relu() * weights.transpose(0, 1).unsqueeze(-1)).sum(dim=0) + offsets = torch.arange(seq_len_kv, device="cuda") + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + expected = expected.masked_fill(~valid, float("-inf")) + + monkeypatch.setattr(deep_gemm_utils, "_lazy_init", lambda: None) + monkeypatch.setattr(deep_gemm_utils, "_fp8_fp4_mqa_logits_impl", None) + + def fail_torch_path(*args, **kwargs): + raise AssertionError("torch fallback should not be used") + + monkeypatch.setattr(deep_gemm_utils, "_fp8_mqa_logits_torch", fail_torch_path) + actual = deep_gemm_utils.fp8_fp4_mqa_logits( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + + assert torch.equal(torch.isneginf(actual), torch.isneginf(expected)) + finite = torch.isfinite(expected) + assert (actual[finite] - expected[finite]).abs().max() < 2e-2 + + +@pytest.mark.skipif( + not current_platform.is_device_capability_family(120), reason="SM120 only" +) +def test_sm120_fp8_mqa_logits_topk_streams_k_chunks( + monkeypatch: pytest.MonkeyPatch, +): + torch.manual_seed(1) + seq_len, seq_len_kv, num_heads, head_dim = 11, 23, 16, 32 + topk_tokens = 5 + monkeypatch.setattr( + deep_gemm_utils, + "_SM120_MQA_LOGITS_MAX_SCORE_BYTES", + seq_len * 5 * 4, + ) + monkeypatch.setattr( + torch, + "cat", + lambda *args, **kwargs: (_ for _ in ()).throw( + AssertionError("MQA top-k should reuse candidate buffers") + ), + ) + + q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device="cuda", dtype=torch.float32) + cu_seqlen_ks = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 4 + valid_lens = torch.arange(seq_len, device="cuda", dtype=torch.int32) % 7 + cu_seqlen_ke = torch.minimum( + cu_seqlen_ks + valid_lens, + torch.full((seq_len,), seq_len_kv, device="cuda", dtype=torch.int32), + ) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_amax = kv.abs().float().amax(dim=1, keepdim=True).clamp(1e-4) + kv_scale = (kv_amax / 448.0).squeeze(1).contiguous() + kv_fp8 = (kv * (1.0 / kv_scale[:, None])).to(torch.float8_e4m3fn) + + topk_indices = deep_gemm_utils._fp8_mqa_logits_topk_torch( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_tokens, + ) + + logits = deep_gemm_utils._fp8_mqa_logits_torch( + (q_fp8, None), + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True, + ) + expected = torch.full_like(topk_indices, -1) + for row in range(seq_len): + valid_count = int((cu_seqlen_ke[row] - cu_seqlen_ks[row]).item()) + row_topk = min(topk_tokens, valid_count) + if row_topk > 0: + expected[row, :row_topk] = ( + logits[row].topk(row_topk).indices.to(torch.int32) + ) + + for row in range(seq_len): + valid_count = int((cu_seqlen_ke[row] - cu_seqlen_ks[row]).item()) + row_topk = min(topk_tokens, valid_count) + assert set(topk_indices[row, :row_topk].tolist()) == set( + expected[row, 
:row_topk].tolist() + ) + assert torch.all(topk_indices[row, row_topk:] == -1) + + def _float_to_e8m0_truncate(f: float) -> float: """Simulate SM100's float -> e8m0 -> bf16 scale conversion. e8m0 format only stores the exponent (power of 2). @@ -218,8 +707,14 @@ def test_sparse_backend_decode_correctness( if not ok: pytest.skip(reason) elif backend_cls == FlashInferMLASparseBackend: - if not current_platform.has_device_capability(100): - pytest.skip("FlashInferMLASparseBackend requires SM 10.0 or higher") + capability = current_platform.get_device_capability() + if capability is None or not backend_cls.supports_compute_capability( + capability + ): + pytest.skip( + "FlashInferMLASparseBackend does not support " + f"{capability} on this platform" + ) batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name] use_fp8_ds_mla_quantization = kv_cache_dtype == "fp8_ds_mla" @@ -781,6 +1276,119 @@ def test_split_indexer_prefill_chunks( assert out == expected +def test_sparse_indexer_max_logits_bytes_uses_sm12x_safe_default(monkeypatch): + monkeypatch.delenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", raising=False) + + assert sparse_indexer_max_logits_bytes(is_sm12x=True) == 256 * 1024 * 1024 + assert sparse_indexer_max_logits_bytes(is_sm12x=False) == 512 * 1024 * 1024 + + +def test_sparse_indexer_max_logits_bytes_honors_env_override(monkeypatch): + monkeypatch.setenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB", "384") + + assert sparse_indexer_max_logits_bytes(is_sm12x=True) == 384 * 1024 * 1024 + assert sparse_indexer_max_logits_bytes(is_sm12x=False) == 384 * 1024 * 1024 + + +def test_compute_global_topk_indices_supports_in_place_output(): + device = torch.device(DEVICE_TYPE) + block_size = 4 + topk_indices = torch.tensor( + [[0, 3, 4, -1], [2, 5, -1, -1], [1, 7, -1, -1]], + dtype=torch.int32, + device=device, + ) + token_to_req = torch.tensor([0, 1, 1], dtype=torch.int32, device=device) + block_table = torch.tensor( + [[10, 11, 12], [20, 21, 22]], dtype=torch.int32, device=device + ) + is_valid = torch.tensor([True, True, False], device=device) + + expected_indices = torch.tensor( + [ + [40, 43, 44, -1], + [82, 85, -1, -1], + [-1, -1, -1, -1], + ], + dtype=torch.int32, + device=device, + ) + expected_lens = torch.tensor([3, 2, 0], dtype=torch.int32, device=device) + + out, lens = compute_global_topk_indices_and_lens( + topk_indices, + token_to_req, + block_table, + block_size, + is_valid, + ) + torch.testing.assert_close(out, expected_indices, rtol=0, atol=0) + torch.testing.assert_close(lens, expected_lens, rtol=0, atol=0) + + in_place = topk_indices.clone() + provided_lens = torch.empty(3, dtype=torch.int32, device=device) + out, lens = compute_global_topk_indices_and_lens( + in_place, + token_to_req, + block_table, + block_size, + is_valid, + global_topk_indices=in_place, + topk_lens=provided_lens, + ) + assert out is in_place + assert lens is provided_lens + torch.testing.assert_close(in_place, expected_indices, rtol=0, atol=0) + torch.testing.assert_close(provided_lens, expected_lens, rtol=0, atol=0) + + +def test_combine_topk_swa_indices_supports_workspace_outputs(): + device = torch.device(DEVICE_TYPE) + num_tokens = 6 + topk = 4 + window_size = 8 + topk_indices = ( + torch.arange(num_tokens * topk, dtype=torch.int32, device=device) + .reshape(num_tokens, topk) + .remainder(5) + ) + query_start_loc = torch.tensor([0, num_tokens], dtype=torch.int32, device=device) + seq_lens = torch.tensor([20], dtype=torch.int32, device=device) + gather_lens = torch.tensor([8], dtype=torch.int32, device=device) + + 
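+    # The first call lets the helper allocate fresh output tensors to act
+    # as the reference; the second call must reproduce them while writing
+    # into the caller-provided workspace buffers.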
expected_indices, expected_lens = combine_topk_swa_indices( + topk_indices, + query_start_loc, + seq_lens, + gather_lens, + window_size, + 4, + topk, + 16, + 12, + ) + workspace_indices = torch.empty_like(expected_indices) + workspace_lens = torch.empty_like(expected_lens) + actual_indices, actual_lens = combine_topk_swa_indices( + topk_indices, + query_start_loc, + seq_lens, + gather_lens, + window_size, + 4, + topk, + 16, + 12, + combined_indices=workspace_indices, + combined_lens=workspace_lens, + ) + + assert actual_indices.data_ptr() == workspace_indices.data_ptr() + assert actual_lens.data_ptr() == workspace_lens.data_ptr() + torch.testing.assert_close(actual_indices, expected_indices, rtol=0, atol=0) + torch.testing.assert_close(actual_lens, expected_lens, rtol=0, atol=0) + + def test_split_indexer_prefill_chunks_single_request_overflow(): """Test that single request exceeding budget is sub-chunked on query dim.""" seq_lens = torch.tensor([1000, 50]) diff --git a/tests/v1/attention/test_sparse_mla_env.py b/tests/v1/attention/test_sparse_mla_env.py new file mode 100644 index 000000000000..9745ea169cc2 --- /dev/null +++ b/tests/v1/attention/test_sparse_mla_env.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import os +from collections.abc import Iterator +from contextlib import contextmanager + +import torch + +from vllm.envs import environment_variables +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled, + triton_sparse_mla_cudagraphs_allowed, + triton_sparse_mla_head_block_size, + triton_sparse_mla_query_chunk_size, + triton_sparse_mla_topk_chunk_size, +) + +_SPARSE_MLA_ENV_NAMES = ( + "VLLM_TRITON_MLA_SPARSE", + "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", + "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE", + "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", + "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE", +) + + +@contextmanager +def _patched_sparse_mla_env(**updates: str) -> Iterator[None]: + previous = {name: os.environ.get(name) for name in _SPARSE_MLA_ENV_NAMES} + try: + for name in _SPARSE_MLA_ENV_NAMES: + os.environ.pop(name, None) + os.environ.update(updates) + yield + finally: + for name, value in previous.items(): + if value is None: + os.environ.pop(name, None) + else: + os.environ[name] = value + + +def test_triton_sparse_mla_env_uses_new_name() -> None: + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE="0"): + assert not is_triton_sparse_mla_enabled(torch.device("cpu")) + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE="1"): + assert is_triton_sparse_mla_enabled(torch.device("cpu")) + + +def test_sparse_mla_cudagraph_env_defaults_to_allowed() -> None: + with _patched_sparse_mla_env(): + assert triton_sparse_mla_cudagraphs_allowed() + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH="0"): + assert not triton_sparse_mla_cudagraphs_allowed() + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH="1"): + assert triton_sparse_mla_cudagraphs_allowed() + + +def test_sparse_mla_head_block_env_accepts_supported_values() -> None: + with _patched_sparse_mla_env(): + assert triton_sparse_mla_head_block_size() is None + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="1"): + assert triton_sparse_mla_head_block_size() == 1 + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="2"): + assert triton_sparse_mla_head_block_size() == 2 + + with 
_patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="4"): + assert triton_sparse_mla_head_block_size() == 4 + + +def test_sparse_mla_head_block_env_ignores_invalid_values() -> None: + for value in ("0", "3", "invalid"): + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE=value): + assert triton_sparse_mla_head_block_size() is None + + +def test_sparse_mla_head_block_env_is_registered_with_vllm_envs() -> None: + assert "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE" in environment_variables + + with _patched_sparse_mla_env(VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE="4"): + assert environment_variables["VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE"]() == 4 + + +def test_sparse_mla_chunk_env_defaults_invalid_values() -> None: + with _patched_sparse_mla_env( + VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE="invalid", + VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE="-7", + ): + assert triton_sparse_mla_topk_chunk_size() == 512 + assert triton_sparse_mla_query_chunk_size() == 1 diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index c35c38911a1a..c9f9fdd58fe2 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -36,6 +36,8 @@ KVCacheConfig, KVCacheGroupSpec, MambaSpec, + MLAAttentionSpec, + SlidingWindowMLASpec, SlidingWindowSpec, ) @@ -2573,6 +2575,215 @@ def test_can_fit_full_sequence_swa_cap_admits_long_prompt(): ) +def test_deepseek_v4_mla_keeps_hybrid_aligned_prompt_blocks_after_decode(): + hash_block_size = 2 + full_block_size = 8 + swa_block_size = 2 + prompt_tokens = 35 + chunk_tokens = 4 * full_block_size + expected_hit_tokens = ( + (prompt_tokens - 1) // full_block_size * full_block_size + ) + + config = KVCacheConfig( + num_blocks=70, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=full_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_0"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + sliding_window=2 * swa_block_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_1"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + sliding_window=2 * swa_block_size, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ), + KVCacheGroupSpec( + ["layer_swa_mla_compressor_state"], + SlidingWindowMLASpec( + block_size=swa_block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.float32, + sliding_window=2 * swa_block_size, + ), + ), + ], + ) + manager = KVCacheManager( + config, + max_model_len=128, + max_num_batched_tokens=chunk_tokens, + enable_caching=True, + hash_block_size=hash_block_size, + ) + + def run_request(request: Request, num_decode_tokens: int) -> int: + computed_blocks, num_computed_tokens = manager.get_computed_blocks(request) + computed_so_far = num_computed_tokens + remaining_prompt_tokens = request.num_prompt_tokens - num_computed_tokens + first_chunk = True + while remaining_prompt_tokens > 0: + num_new_tokens = min(chunk_tokens, remaining_prompt_tokens) + allocated = manager.allocate_slots( + request, + num_new_tokens, + num_computed_tokens if first_chunk else 0, + computed_blocks if first_chunk else None, + ) + assert allocated is not None + computed_so_far += num_new_tokens + request.num_computed_tokens = 
computed_so_far + remaining_prompt_tokens -= num_new_tokens + first_chunk = False + + for i in range(num_decode_tokens): + request.append_output_token_ids(10_000 + i) + allocated = manager.allocate_slots(request, 1) + assert allocated is not None + computed_so_far += 1 + request.num_computed_tokens = computed_so_far + return num_computed_tokens + + prompt_a = list(range(prompt_tokens)) + req_a = make_request("a", prompt_a, hash_block_size, sha256) + assert run_request(req_a, num_decode_tokens=0) == 0 + manager.free(req_a) + + warm_a = make_request("warm_a", prompt_a, hash_block_size, sha256) + assert run_request(warm_a, num_decode_tokens=8) == expected_hit_tokens + assert manager.get_num_common_prefix_blocks("warm_a")[0] >= ( + expected_hit_tokens // full_block_size + ) + manager.free(warm_a) + + pressure_blocks = manager.block_pool.get_new_blocks( + manager.block_pool.get_num_free_blocks() + ) + manager.block_pool.free_blocks(reversed(pressure_blocks)) + + req_a_again = make_request("a_again", prompt_a, hash_block_size, sha256) + _, num_computed_tokens = manager.get_computed_blocks(req_a_again) + assert num_computed_tokens == expected_hit_tokens + + +def test_deepseek_v4_mla_protected_prompt_blocks_do_not_block_admission(): + block_size = 8 + prompt_tokens = 4 * block_size + 3 + protected_blocks_per_prompt = (prompt_tokens - 1) // block_size + num_prompts = 10 + num_blocks = 80 + manager = KVCacheManager( + KVCacheConfig( + num_blocks=num_blocks, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + mla_manager = manager.coordinator.single_type_managers[0] + + for i in range(num_prompts): + prompt = list(range(i * 1000, i * 1000 + prompt_tokens)) + req = make_request(f"protected_{i}", prompt, block_size, sha256) + assert manager.allocate_slots(req, prompt_tokens) is not None + req.num_computed_tokens = prompt_tokens + manager.free(req) + + assert len(mla_manager._protected_prompt_block_ids) == ( + num_prompts * protected_blocks_per_prompt + ) + assert manager.block_pool.get_num_free_blocks() < 64 + + long_req = make_request( + "long", + list(range(100_000, 100_000 + 64 * block_size)), + block_size, + sha256, + ) + assert ( + manager.allocate_slots(long_req, block_size, full_sequence_must_fit=True) + is not None + ) + + +def test_reset_prefix_cache_releases_deepseek_v4_mla_protected_blocks(): + block_size = 8 + prompt_tokens = 4 * block_size + 3 + manager = KVCacheManager( + KVCacheConfig( + num_blocks=32, + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec( + ["layer_full"], + MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=1, + dtype=torch.uint8, + cache_dtype_str="fp8_ds_mla", + model_version="deepseek_v4", + ), + ) + ], + ), + max_model_len=512, + max_num_batched_tokens=128, + enable_caching=True, + hash_block_size=block_size, + ) + + req = make_request("protected", list(range(prompt_tokens)), block_size, sha256) + assert manager.allocate_slots(req, prompt_tokens) is not None + req.num_computed_tokens = prompt_tokens + manager.free(req) + + assert manager.coordinator.single_type_managers[0]._protected_prompt_block_ids + assert manager.reset_prefix_cache() + + def test_can_fit_full_sequence_full_attention_still_gates_oversized(): """The cap 
only loosens the SWA group; a prompt that exceeds the full-attention pool capacity must still be rejected.""" diff --git a/tests/v1/executor/test_ray_utils.py b/tests/v1/executor/test_ray_utils.py index 8da9d5459e73..a83b513baca8 100644 --- a/tests/v1/executor/test_ray_utils.py +++ b/tests/v1/executor/test_ray_utils.py @@ -1,8 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from types import SimpleNamespace + import numpy as np +from vllm.v1.executor import ray_utils from vllm.v1.executor.ray_utils import detach_zero_copy_from_model_runner_output from vllm.v1.outputs import LogprobsLists, LogprobsTensors, ModelRunnerOutput @@ -52,3 +55,46 @@ def test_detach_zero_copy_from_model_runner_output_copies_only_numpy_views(): assert detached_logprobs.sampled_token_ranks.flags.writeable assert detached_logprobs.cu_num_generated_tokens is cu_num_generated_tokens assert output.prompt_logprobs_dict["req-0"] is prompt_logprobs + + +def test_cluster_device_warning_uses_ray_cluster_resources(monkeypatch): + warnings = [] + + monkeypatch.setattr( + ray_utils, + "ray", + SimpleNamespace(cluster_resources=lambda: {"GPU": 2}), + ) + monkeypatch.setattr( + ray_utils.logger, + "warning", + lambda *args, **kwargs: warnings.append(args), + ) + + ray_utils._warn_if_insufficient_cluster_devices( + SimpleNamespace(world_size=2), "GPU" + ) + + assert warnings == [] + + +def test_cluster_device_warning_reports_cluster_shortage(monkeypatch): + warnings = [] + + monkeypatch.setattr( + ray_utils, + "ray", + SimpleNamespace(cluster_resources=lambda: {"GPU": 1}), + ) + monkeypatch.setattr( + ray_utils.logger, + "warning", + lambda *args, **kwargs: warnings.append(args), + ) + + ray_utils._warn_if_insufficient_cluster_devices( + SimpleNamespace(world_size=2), "GPU" + ) + + assert len(warnings) == 1 + assert "distributed world size" in warnings[0][0] diff --git a/tools/compare_vllm_http_logprobs_oracle.py b/tools/compare_vllm_http_logprobs_oracle.py new file mode 100755 index 000000000000..85a3d6aade55 --- /dev/null +++ b/tools/compare_vllm_http_logprobs_oracle.py @@ -0,0 +1,431 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Compare current vLLM HTTP logprobs with a captured oracle bundle. + +The expected oracle format is a directory with request_*.json and matching +response_*.json files produced from /v1/completions with token-id logprob keys. 
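+
+The script prints a JSON report and exits non-zero when any enabled
+threshold check fails. Example invocation (paths are placeholders):
+
+    python tools/compare_vllm_http_logprobs_oracle.py \
+        --oracle-dir /path/to/oracle_bundle \
+        --base-url http://127.0.0.1:8000 \
+        --strict-tokens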
+""" + +from __future__ import annotations + +import argparse +import json +import sys +import urllib.error +import urllib.request +from pathlib import Path +from typing import Any + +Json = dict[str, Any] + + +class TokenNormalizer: + """Normalize token-id placeholders to decoded token strings.""" + + def __init__(self, decode_token_id): + self._decode_token_id = decode_token_id + + def token_key(self, token: Any) -> str: + key = _token_key(token) + if not key.startswith("token_id:"): + return key + try: + token_id = int(key.removeprefix("token_id:")) + except ValueError: + return key + return self._decode_token_id(token_id) + + +def _token_key(token: Any) -> str: + if isinstance(token, str): + return token + if isinstance(token, int): + return f"token_id:{token}" + return str(token) + + +def _choice(response: Json) -> Json: + choices = response.get("choices") + if not isinstance(choices, list) or not choices: + raise ValueError("response has no choices[0]") + choice = choices[0] + if not isinstance(choice, dict): + raise ValueError("response choices[0] is not an object") + return choice + + +def _normalize_token(token: Any, normalizer: TokenNormalizer | None) -> str: + if normalizer is None: + return _token_key(token) + return normalizer.token_key(token) + + +def _generated_tokens( + response: Json, normalizer: TokenNormalizer | None = None +) -> list[str]: + choice = _choice(response) + logprobs = choice.get("logprobs") or {} + tokens = logprobs.get("tokens") + if isinstance(tokens, list) and tokens: + return [_normalize_token(token, normalizer) for token in tokens] + token_ids = choice.get("token_ids") + if isinstance(token_ids, list): + return [_normalize_token(token, normalizer) for token in token_ids] + return [] + + +def _prompt_token_ids(response: Json) -> list[int] | None: + token_ids = _choice(response).get("prompt_token_ids") + if not isinstance(token_ids, list): + return None + return [int(token) for token in token_ids] + + +def _token_logprobs(response: Json) -> list[float | None]: + logprobs = _choice(response).get("logprobs") or {} + values = logprobs.get("token_logprobs") + if not isinstance(values, list): + return [] + out: list[float | None] = [] + for value in values: + out.append(float(value) if value is not None else None) + return out + + +def _top_logprobs( + response: Json, normalizer: TokenNormalizer | None = None +) -> list[dict[str, float]]: + logprobs = _choice(response).get("logprobs") or {} + values = logprobs.get("top_logprobs") + if not isinstance(values, list): + return [] + out: list[dict[str, float]] = [] + for step in values: + if not isinstance(step, dict): + out.append({}) + continue + out.append( + { + _normalize_token(token, normalizer): float(logprob) + for token, logprob in step.items() + } + ) + return out + + +def _first_mismatch(oracle_tokens: list[str], actual_tokens: list[str]) -> Json | None: + for step, (oracle_token, actual_token) in enumerate( + zip(oracle_tokens, actual_tokens) + ): + if oracle_token != actual_token: + return {"step": step, "oracle": oracle_token, "actual": actual_token} + if len(oracle_tokens) != len(actual_tokens): + step = min(len(oracle_tokens), len(actual_tokens)) + return { + "step": step, + "oracle": oracle_tokens[step] if step < len(oracle_tokens) else None, + "actual": actual_tokens[step] if step < len(actual_tokens) else None, + } + return None + + +def _matching_prefix(oracle_tokens: list[str], actual_tokens: list[str]) -> int: + count = 0 + for oracle_token, actual_token in zip(oracle_tokens, actual_tokens): 
+ if oracle_token != actual_token: + break + count += 1 + return count + + +def _top_keys(top_logprobs: dict[str, float], top_n: int) -> list[str]: + return list(top_logprobs.keys())[:top_n] + + +def _mean(values: list[float]) -> float | None: + return sum(values) / len(values) if values else None + + +def _max(values: list[float]) -> float | None: + return max(values) if values else None + + +def compare_response( + case_name: str, + oracle_response: Json, + actual_response: Json, + *, + top_n: int = 50, + normalizer: TokenNormalizer | None = None, +) -> Json: + oracle_tokens = _generated_tokens(oracle_response, normalizer) + actual_tokens = _generated_tokens(actual_response, normalizer) + oracle_top = _top_logprobs(oracle_response, normalizer) + actual_top = _top_logprobs(actual_response, normalizer) + oracle_token_logprobs = _token_logprobs(oracle_response) + actual_token_logprobs = _token_logprobs(actual_response) + + steps = min( + len(oracle_tokens), len(actual_tokens), len(oracle_top), len(actual_top) + ) + top1_matches = 0 + topk_overlaps: list[float] = [] + common_logprob_errors: list[float] = [] + chosen_token_logprob_errors: list[float] = [] + + for step in range(steps): + oracle_keys = _top_keys(oracle_top[step], top_n) + actual_keys = _top_keys(actual_top[step], top_n) + if oracle_keys and actual_keys and oracle_keys[0] == actual_keys[0]: + top1_matches += 1 + + oracle_set = set(oracle_keys) + actual_set = set(actual_keys) + if oracle_set: + topk_overlaps.append(len(oracle_set & actual_set) / len(oracle_set)) + for token in oracle_set & actual_set: + common_logprob_errors.append( + abs(oracle_top[step][token] - actual_top[step][token]) + ) + + if ( + oracle_tokens[step] == actual_tokens[step] + and step < len(oracle_token_logprobs) + and step < len(actual_token_logprobs) + and oracle_token_logprobs[step] is not None + and actual_token_logprobs[step] is not None + ): + chosen_token_logprob_errors.append( + abs(oracle_token_logprobs[step] - actual_token_logprobs[step]) + ) + + oracle_prompt_ids = _prompt_token_ids(oracle_response) + actual_prompt_ids = _prompt_token_ids(actual_response) + prompt_ids_match = ( + None + if oracle_prompt_ids is None or actual_prompt_ids is None + else oracle_prompt_ids == actual_prompt_ids + ) + + first_mismatch = _first_mismatch(oracle_tokens, actual_tokens) + return { + "case": case_name, + "tokens_match": first_mismatch is None, + "prompt_token_ids_match": prompt_ids_match, + "first_token_mismatch": first_mismatch, + "matching_prefix_tokens": _matching_prefix(oracle_tokens, actual_tokens), + "oracle_token_count": len(oracle_tokens), + "actual_token_count": len(actual_tokens), + "compared_steps": steps, + "top1_matches": top1_matches, + "top1_match_rate": top1_matches / steps if steps else None, + "topk_overlap_mean": _mean(topk_overlaps), + "topk_overlap_min": min(topk_overlaps) if topk_overlaps else None, + "max_common_logprob_abs_error": _max(common_logprob_errors), + "mean_common_logprob_abs_error": _mean(common_logprob_errors), + "max_chosen_token_logprob_abs_error": _max(chosen_token_logprob_errors), + "mean_chosen_token_logprob_abs_error": _mean(chosen_token_logprob_errors), + } + + +def _load_json(path: Path) -> Json: + with path.open(encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict): + raise ValueError(f"{path} is not a JSON object") + return data + + +def load_oracle_cases(oracle_dir: Path) -> list[tuple[str, Json, Json]]: + cases: list[tuple[str, Json, Json]] = [] + request_paths = 
sorted(oracle_dir.glob("request_*.json")) + if not request_paths: + raise ValueError(f"{oracle_dir} has no request_*.json files") + for request_path in request_paths: + suffix = request_path.stem.removeprefix("request_") + response_path = oracle_dir / f"response_{suffix}.json" + if not response_path.exists(): + raise ValueError(f"missing {response_path.name} for {request_path.name}") + cases.append((suffix, _load_json(request_path), _load_json(response_path))) + return cases + + +def post_completion(base_url: str, payload: Json, timeout: float) -> Json: + url = f"{base_url.rstrip('/')}/v1/completions" + encoded = json.dumps(payload).encode("utf-8") + request = urllib.request.Request( + url, + data=encoded, + headers={"Content-Type": "application/json"}, + method="POST", + ) + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + body = response.read().decode("utf-8") + except urllib.error.HTTPError as exc: + body = exc.read().decode("utf-8", errors="replace") + raise RuntimeError(f"HTTP {exc.code} from {url}: {body}") from exc + data = json.loads(body) + if not isinstance(data, dict): + raise ValueError(f"{url} returned non-object JSON") + return data + + +def load_token_normalizer( + tokenizer: str, + *, + tokenizer_mode: str, + trust_remote_code: bool, +) -> TokenNormalizer: + from vllm.tokenizers import get_tokenizer + + hf_tokenizer = get_tokenizer( + tokenizer, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + ) + + def decode_token_id(token_id: int) -> str: + return hf_tokenizer.decode([token_id]) + + return TokenNormalizer(decode_token_id) + + +def summarize_reports(reports: list[Json]) -> Json: + return { + "case_count": len(reports), + "all_tokens_match": all(report["tokens_match"] for report in reports), + "all_prompt_token_ids_match": all( + report["prompt_token_ids_match"] is not False for report in reports + ), + "min_top1_match_rate": min( + ( + report["top1_match_rate"] + for report in reports + if report["top1_match_rate"] is not None + ), + default=None, + ), + "min_topk_overlap_mean": min( + ( + report["topk_overlap_mean"] + for report in reports + if report["topk_overlap_mean"] is not None + ), + default=None, + ), + "max_common_logprob_abs_error": max( + ( + report["max_common_logprob_abs_error"] + for report in reports + if report["max_common_logprob_abs_error"] is not None + ), + default=None, + ), + "max_chosen_token_logprob_abs_error": max( + ( + report["max_chosen_token_logprob_abs_error"] + for report in reports + if report["max_chosen_token_logprob_abs_error"] is not None + ), + default=None, + ), + } + + +def _fails_thresholds( + summary: Json, reports: list[Json], args: argparse.Namespace +) -> bool: + failed = False + if args.strict_tokens and not summary["all_tokens_match"]: + failed = True + if args.strict_prompt_token_ids and not summary["all_prompt_token_ids_match"]: + failed = True + if args.min_top1_match_rate is not None: + value = summary["min_top1_match_rate"] + failed = failed or value is None or value < args.min_top1_match_rate + if args.min_topk_overlap_mean is not None: + value = summary["min_topk_overlap_mean"] + failed = failed or value is None or value < args.min_topk_overlap_mean + if args.max_common_logprob_abs_error is not None: + value = summary["max_common_logprob_abs_error"] + failed = failed or value is None or value > args.max_common_logprob_abs_error + if args.max_chosen_token_logprob_abs_error is not None: + value = summary["max_chosen_token_logprob_abs_error"] + failed = ( + failed or 
value is None or value > args.max_chosen_token_logprob_abs_error + ) + if args.fail_on_first_mismatch: + failed = failed or any(report["first_token_mismatch"] for report in reports) + return failed + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--oracle-dir", required=True, type=Path) + parser.add_argument("--base-url", default="http://127.0.0.1:8000") + parser.add_argument("--timeout", type=float, default=240.0) + parser.add_argument("--top-n", type=int, default=50) + parser.add_argument("--output", type=Path) + parser.add_argument("--model", help="Override the model field from request_*.json") + parser.add_argument( + "--tokenizer", + help=( + "Tokenizer used to decode token_id: logprob keys before " + "comparing them with text token keys." + ), + ) + parser.add_argument("--tokenizer-mode", default="auto") + parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--strict-tokens", action="store_true") + parser.add_argument("--strict-prompt-token-ids", action="store_true") + parser.add_argument("--fail-on-first-mismatch", action="store_true") + parser.add_argument("--min-top1-match-rate", type=float) + parser.add_argument("--min-topk-overlap-mean", type=float) + parser.add_argument("--max-common-logprob-abs-error", type=float) + parser.add_argument("--max-chosen-token-logprob-abs-error", type=float) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + cases = load_oracle_cases(args.oracle_dir) + normalizer = None + if args.tokenizer: + normalizer = load_token_normalizer( + args.tokenizer, + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) + reports: list[Json] = [] + for suffix, request_payload, oracle_response in cases: + payload = dict(request_payload) + if args.model: + payload["model"] = args.model + actual_response = post_completion(args.base_url, payload, args.timeout) + reports.append( + compare_response( + f"request_{suffix}", + oracle_response, + actual_response, + top_n=args.top_n, + normalizer=normalizer, + ) + ) + + summary = summarize_reports(reports) + result: Json = {"summary": summary, "cases": reports} + text = json.dumps(result, indent=2, sort_keys=True) + if args.output: + args.output.write_text(text + "\n", encoding="utf-8") + print(text) + return 1 if _fails_thresholds(summary, reports, args) else 0 + + +if __name__ == "__main__": + try: + raise SystemExit(main()) + except (OSError, RuntimeError, ValueError, json.JSONDecodeError) as exc: + print(f"error: {exc}", file=sys.stderr) + raise SystemExit(2) from exc diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 0f02a92681c1..91e5a8913e88 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -750,6 +750,7 @@ class CompilationConfig: "vllm::sparse_attn_indexer", "vllm::rocm_aiter_sparse_attn_indexer", "vllm::deepseek_v4_attention", + "vllm::deepseek_v4_fp8_einsum", ] def compute_hash(self) -> str: diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index cfe0857b679e..78de5bf2ef1e 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -399,6 +399,12 @@ class ConversationMessage(TypedDict, total=False): reasoning_content: str | None """Deprecated: The reasoning content for interleaved thinking.""" + prefix: bool + """Whether this assistant message is a prefix for continuation.""" + + wo_eos: bool + """Whether this message should be rendered without an EOS 
marker.""" + tools: list[ChatCompletionFunctionToolParam] | None """The tools for developer role.""" @@ -1751,6 +1757,8 @@ def _parse_chat_message_content( role = message["role"] content = message.get("content") reasoning = message.get("reasoning") + if reasoning is None: + reasoning = message.get("reasoning_content") if content is None: content = [] @@ -1780,6 +1788,9 @@ def _parse_chat_message_content( result_msg["reasoning_content"] = cast( str, reasoning ) # keep compatibility + if parsed_msg.get("prefix"): + result_msg["prefix"] = True + result_msg["wo_eos"] = True elif role == "tool": parsed_msg = _ToolParser(message) if "tool_call_id" in parsed_msg: diff --git a/vllm/entrypoints/openai/chat_completion/batch_serving.py b/vllm/entrypoints/openai/chat_completion/batch_serving.py index cc49909b8361..df6aff4a45c3 100644 --- a/vllm/entrypoints/openai/chat_completion/batch_serving.py +++ b/vllm/entrypoints/openai/chat_completion/batch_serving.py @@ -159,8 +159,12 @@ async def create_batch_chat_completion( self.override_max_tokens, ) single_request = single_requests[i] + chat_template_kwargs = self._effective_chat_template_kwargs(single_request) sampling_params = single_request.to_sampling_params( - max_tokens, self.default_sampling_params + max_tokens, + self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, ) self._log_inputs( sub_request_id, diff --git a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py index c92cc13da01f..b0a73770da35 100644 --- a/vllm/entrypoints/openai/chat_completion/protocol.py +++ b/vllm/entrypoints/openai/chat_completion/protocol.py @@ -62,6 +62,15 @@ class ChatMessage(OpenAIBaseModel): # vLLM-specific fields that are not in OpenAI spec reasoning: str | None = None + reasoning_content: str | None = None + + @model_validator(mode="after") + def _populate_reasoning_content_alias(self) -> "ChatMessage": + if self.reasoning_content is None and self.reasoning is not None: + self.reasoning_content = self.reasoning + elif self.reasoning is None and self.reasoning_content is not None: + self.reasoning = self.reasoning_content + return self class ChatCompletionLogProb(OpenAIBaseModel): @@ -164,6 +173,10 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): type: Literal["function"] = "function" +class DeepSeekThinkingParam(OpenAIBaseModel): + type: Literal["enabled", "disabled"] = "enabled" + + class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create @@ -209,6 +222,15 @@ class ChatCompletionRequest(OpenAIBaseModel): "part of the standard OpenAI API specification." ), ) + thinking: DeepSeekThinkingParam | None = None + deepseek_v4_sampling_override: bool = Field( + default=True, + description=( + "Apply DeepSeek V4 official sampling defaults when thinking is " + "enabled. This only affects the DeepSeek V4 family and can be " + "disabled per request." 
+ ), + ) thinking_token_budget: int | None = None include_reasoning: bool = True parallel_tool_calls: bool | None = True @@ -435,21 +457,79 @@ def build_chat_params( default_template: str | None, default_template_content_format: ChatTemplateContentFormatOption, ) -> ChatParams: + chat_kwargs = merge_kwargs( + self.chat_template_kwargs, + dict( + add_generation_prompt=self.add_generation_prompt, + continue_final_message=self.continue_final_message, + documents=self.documents, + reasoning_effort=self.reasoning_effort, + ), + ) return ChatParams( chat_template=self.chat_template or default_template, chat_template_content_format=default_template_content_format, - chat_template_kwargs=merge_kwargs( - self.chat_template_kwargs, - dict( - add_generation_prompt=self.add_generation_prompt, - continue_final_message=self.continue_final_message, - documents=self.documents, - reasoning_effort=self.reasoning_effort, - ), - ), + chat_template_kwargs=chat_kwargs, media_io_kwargs=self.media_io_kwargs, ) + def _is_deepseek_v4_model(self, model_config: ModelConfig | None = None) -> bool: + hf_config = getattr(model_config, "hf_config", None) + if getattr(hf_config, "model_type", None) == "deepseek_v4": + return True + + architectures = getattr(hf_config, "architectures", None) or () + if any("deepseekv4" in str(arch).replace("_", "").lower() + for arch in architectures): + return True + + model = (self.model or "").lower().replace("_", "-") + return "deepseek-v4" in model + + def apply_chat_template_kwargs( + self, + chat_template_kwargs: dict[str, Any], + *, + model_config: ModelConfig | None = None, + ) -> dict[str, Any]: + """Apply request-level DeepSeek API compatibility knobs. + + DeepSeek's OpenAI-compatible API exposes ``thinking`` as a top-level + request field, while vLLM's DeepSeek tokenizer consumes it as a chat + template kwarg. Keep the translation at the protocol boundary so the + tokenizer and reasoning parser see the same effective state. 
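+
+        For example, ``thinking={"type": "disabled"}`` on a DeepSeek V4
+        model yields ``{"thinking": False, "enable_thinking": False}``,
+        while omitting ``thinking`` defaults both kwargs to ``True``
+        unless the caller already set either one.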
+ """ + chat_template_kwargs = dict(chat_template_kwargs) + if not self._is_deepseek_v4_model(model_config): + return chat_template_kwargs + + if self.thinking is not None: + enabled = self.thinking.type == "enabled" + chat_template_kwargs["thinking"] = enabled + chat_template_kwargs["enable_thinking"] = enabled + elif ( + "thinking" not in chat_template_kwargs + and "enable_thinking" not in chat_template_kwargs + ): + chat_template_kwargs["thinking"] = True + chat_template_kwargs["enable_thinking"] = True + + return chat_template_kwargs + + def _use_deepseek_v4_sampling_override(self) -> bool: + return self.deepseek_v4_sampling_override + + @staticmethod + def _is_thinking_enabled( + chat_template_kwargs: dict[str, Any] | None, + ) -> bool: + if chat_template_kwargs is None: + return False + return bool( + chat_template_kwargs.get("thinking") + or chat_template_kwargs.get("enable_thinking") + ) + def build_tok_params(self, model_config: ModelConfig) -> TokenizeParams: if self.max_completion_tokens is not None: max_output_tokens: int | None = self.max_completion_tokens @@ -499,6 +579,9 @@ def to_sampling_params( self, max_tokens: int, default_sampling_params: dict, + *, + chat_template_kwargs: dict[str, Any] | None = None, + model_config: ModelConfig | None = None, ) -> SamplingParams: # Default parameters if (repetition_penalty := self.repetition_penalty) is None: @@ -523,6 +606,21 @@ def to_sampling_params( "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] ) + if ( + self._is_deepseek_v4_model(model_config) + and self._use_deepseek_v4_sampling_override() + and self._is_thinking_enabled(chat_template_kwargs) + ): + temperature = self._DEFAULT_SAMPLING_PARAMS["temperature"] + top_p = self._DEFAULT_SAMPLING_PARAMS["top_p"] + top_k = self._DEFAULT_SAMPLING_PARAMS["top_k"] + min_p = self._DEFAULT_SAMPLING_PARAMS["min_p"] + presence_penalty = 0.0 + frequency_penalty = 0.0 + else: + presence_penalty = self.presence_penalty or 0.0 + frequency_penalty = self.frequency_penalty or 0.0 + prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: prompt_logprobs = self.top_logprobs @@ -565,8 +663,8 @@ def to_sampling_params( extra_args["kv_transfer_params"] = self.kv_transfer_params return SamplingParams.from_optional( n=self.n, - presence_penalty=self.presence_penalty, - frequency_penalty=self.frequency_penalty, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, repetition_penalty=repetition_penalty, temperature=temperature, top_p=top_p, diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 694ff80047c7..2dad7f948b5b 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -190,7 +190,7 @@ def warmup(self) -> None: def _effective_chat_template_kwargs( self, request: ChatCompletionRequest ) -> dict[str, Any]: - return ( + chat_template_kwargs = ( request.build_chat_params( self.chat_template, self.chat_template_content_format, @@ -198,6 +198,10 @@ def _effective_chat_template_kwargs( .with_defaults(self.default_chat_template_kwargs) .chat_template_kwargs ) + return request.apply_chat_template_kwargs( + chat_template_kwargs, + model_config=self.model_config, + ) async def render_chat_request( self, @@ -300,6 +304,8 @@ async def create_chat_completion( sampling_params = request.to_sampling_params( max_tokens, self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, ) 
self._log_inputs( diff --git a/vllm/entrypoints/openai/engine/protocol.py b/vllm/entrypoints/openai/engine/protocol.py index 890af0300efc..8c531c6d5a93 100644 --- a/vllm/entrypoints/openai/engine/protocol.py +++ b/vllm/entrypoints/openai/engine/protocol.py @@ -268,8 +268,17 @@ class DeltaMessage(OpenAIBaseModel): role: str | None = None content: str | None = None reasoning: str | None = None + reasoning_content: str | None = None tool_calls: list[DeltaToolCall] = Field(default_factory=list) + @model_validator(mode="after") + def _populate_reasoning_content_alias(self) -> "DeltaMessage": + if self.reasoning_content is None and self.reasoning is not None: + self.reasoning_content = self.reasoning + elif self.reasoning is None and self.reasoning_content is not None: + self.reasoning = self.reasoning_content + return self + class GenerationError(Exception): """raised when finish_reason indicates internal server error (500)""" diff --git a/vllm/entrypoints/serve/render/serving.py b/vllm/entrypoints/serve/render/serving.py index 967899229ada..d6da0dba6f4e 100644 --- a/vllm/entrypoints/serve/render/serving.py +++ b/vllm/entrypoints/serve/render/serving.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence +from dataclasses import replace from http import HTTPStatus from typing import Any, cast @@ -165,7 +166,21 @@ async def render_chat_request( self.default_sampling_params, self.override_max_tokens, ) - params = request.to_sampling_params(max_tokens, self.default_sampling_params) + chat_template_kwargs = request.apply_chat_template_kwargs( + request.build_chat_params( + self.chat_template, + self.chat_template_content_format, + ) + .with_defaults(self.default_chat_template_kwargs) + .chat_template_kwargs, + model_config=self.model_config, + ) + params = request.to_sampling_params( + max_tokens, + self.default_sampling_params, + chat_template_kwargs=chat_template_kwargs, + model_config=self.model_config, + ) request_id = f"chatcmpl-{random_uuid()}" @@ -556,6 +571,17 @@ async def preprocess_chat( default_media_io_kwargs=(mm_config.media_io_kwargs if mm_config else None), default_mm_processor_kwargs=getattr(request, "mm_processor_kwargs", None), ) + apply_chat_template_kwargs = getattr( + request, "apply_chat_template_kwargs", None + ) + if apply_chat_template_kwargs is not None: + chat_params = replace( + chat_params, + chat_template_kwargs=apply_chat_template_kwargs( + chat_params.chat_template_kwargs, + model_config=self.model_config, + ), + ) (conversation,), (engine_input,) = await renderer.render_chat_async( [messages], diff --git a/vllm/envs.py b/vllm/envs.py index ded474dc085a..07a3e55025de 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -166,6 +166,12 @@ VLLM_MOE_USE_DEEP_GEMM: bool = True VLLM_USE_DEEP_GEMM_E8M0: bool = True VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES: bool = True + VLLM_TRITON_MLA_SPARSE: bool | None = None + VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE: int = 512 + VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE: int = 256 + VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH: bool = True + VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE: int | None = None + VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE: bool | None = None VLLM_DEEP_GEMM_WARMUP: Literal[ "skip", "full", @@ -249,6 +255,7 @@ VLLM_MULTI_STREAM_GEMM_TOKEN_THRESHOLD: int = 1024 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False + VLLM_DEEPSEEK_V4_USE_MEGA_MOE: bool = False 
    VLLM_LOG_MODEL_INSPECTION: bool = False
    VLLM_DEBUG_MFU_METRICS: bool = False
    VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY: bool = False
@@ -1275,6 +1282,34 @@ def _get_or_set_default() -> str:
     "VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES": lambda: bool(
         int(os.getenv("VLLM_USE_DEEP_GEMM_TMA_ALIGNED_SCALES", "1"))
     ),
+    # Experimental sparse MLA fallback controls. When ``VLLM_TRITON_MLA_SPARSE``
+    # is unset, the fallback is auto-selected where FlashMLA sparse is
+    # unavailable; set it to 0/1 to force-disable/force-enable the fallback.
+    "VLLM_TRITON_MLA_SPARSE": lambda: (
+        None
+        if os.getenv("VLLM_TRITON_MLA_SPARSE") is None
+        else os.getenv("VLLM_TRITON_MLA_SPARSE", "").lower()
+        in ("1", "true", "yes", "on")
+    ),
+    "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE": lambda: maybe_convert_int(
+        os.getenv("VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE", "512")
+    ),
+    "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE": lambda: maybe_convert_int(
+        os.getenv("VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE", "256")
+    ),
+    "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH": lambda: (
+        os.getenv("VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH", "1").lower()
+        in ("1", "true", "yes", "on")
+    ),
+    "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE": lambda: maybe_convert_int(
+        os.getenv("VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE")
+    ),
+    "VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE": lambda: (
+        None
+        if os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE") is None
+        else os.getenv("VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE", "").lower()
+        in ("1", "true", "yes", "on")
+    ),
     # DeepGemm JITs the kernels on-demand. The warmup attempts to make DeepGemm
     # JIT all the required kernels before model execution so there is no
     # JIT'ing in the hot-path. However, this warmup increases the engine
@@ -1711,6 +1746,12 @@ def _get_or_set_default() -> str:
     "VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
         int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
     ),
+    # Optional override for the DeepGEMM MegaMoE fused expert kernel in
+    # DeepSeek V4. Defaults to 0, deferring to kernel_config.moe_backend;
+    # set it to 1 to force-enable this path during bring-up.
+    "VLLM_DEEPSEEK_V4_USE_MEGA_MOE": lambda: bool(
+        int(os.getenv("VLLM_DEEPSEEK_V4_USE_MEGA_MOE", "0"))
+    ),
     # Log model inspection after loading.
     # If enabled, logs a transformers-style hierarchical view of the model
     # with quantization methods and attention backends.
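A minimal sketch (not part of the diff) of how the tri-state
``VLLM_TRITON_MLA_SPARSE`` toggle added above resolves; only the stdlib is
assumed, and the truthy spellings mirror the env definitions:

    import os

    def resolve_triton_mla_sparse() -> bool | None:
        # None -> auto-select (use the Triton fallback only where FlashMLA
        # sparse is unavailable); True/False -> explicit force-enable/disable.
        raw = os.getenv("VLLM_TRITON_MLA_SPARSE")
        if raw is None:
            return None
        return raw.lower() in ("1", "true", "yes", "on")

Note that any other spelling (e.g. ``off``) parses as False, i.e.
force-disable, since only the four truthy strings enable the fallback.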
diff --git a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py index 618084029159..6baedd3bbcbc 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/cutlass.py @@ -7,6 +7,9 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + _upcast_e8m0_to_fp32, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) @@ -26,6 +29,20 @@ ) +def _is_sm12x_compute_capability(compute_capability) -> bool: + if compute_capability is None: + return current_platform.is_device_capability_family(120) + + if isinstance(compute_capability, tuple): + return compute_capability[0] == 12 + + to_int = getattr(compute_capability, "to_int", None) + if callable(to_int): + return to_int() // 10 == 12 + + return int(compute_capability) // 10 == 12 + + class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel): @classmethod def is_supported( @@ -196,6 +213,9 @@ def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None: @classmethod def is_supported(cls, compute_capability=None): + if _is_sm12x_compute_capability(compute_capability): + return False, "CUTLASS block-scaled FP8 GEMM is not supported on SM12x." + if not CUTLASS_BLOCK_FP8_SUPPORTED: return ( False, @@ -219,6 +239,31 @@ def can_implement(cls, config: FP8ScaledMMLinearLayerConfig): ) return True, None + def process_weights_after_loading(self, layer: torch.nn.Module): + super().process_weights_after_loading(layer) + params = self._get_layer_params(layer) + weight_scale = ( + params.weight_scale + if params.weight_scale_inv is None + else params.weight_scale_inv + ) + scale_attr_name = ( + params.WEIGHT_SCALE + if params.weight_scale_inv is None + else params.WEIGHT_SCALE_INV + ) + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and weight_scale is not None + and weight_scale.dtype == e8m0_dtype + ): + replace_parameter( + layer, + scale_attr_name, + _upcast_e8m0_to_fp32(weight_scale), + ) + def apply_block_scaled_mm( self, A: torch.Tensor, diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 494d61338084..ba520ce476a0 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -18,15 +18,21 @@ ReplicatedLinear, ) from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer +from vllm.platforms import current_platform from vllm.utils.deep_gemm import fp8_einsum from vllm.utils.torch_utils import direct_register_custom_op from vllm.v1.attention.ops.deepseek_v4_ops import ( combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, fused_indexer_q_rope_quant, fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, + sparse_prefill_combined_topk_size, +) +from vllm.v1.attention.ops.deepseek_v4_ops.fp8_einsum import ( + deepseek_v4_sm12x_fp8_einsum, ) from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( rocm_forward_decode_fallback, @@ -44,7 +50,10 @@ VllmConfig, get_current_vllm_config, ) -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import ( + get_tensor_model_parallel_rank, + 
get_tensor_model_parallel_world_size, +) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger from vllm.model_executor.custom_op import PluggableLayer @@ -58,7 +67,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, ) -from vllm.platforms import current_platform from vllm.utils.multi_stream_utils import ( execute_in_parallel, maybe_execute_in_parallel, @@ -73,6 +81,26 @@ DeepseekV4IndexerBackend, get_max_prefill_buffer_size, ) +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + disable_triton_sparse_mla_cudagraphs_if_enabled, + is_triton_sparse_mla_enabled, + is_triton_sparse_mla_enabled_for_platform, + triton_sparse_mla_matmul_decode_enabled, + triton_sparse_mla_query_chunk_size, + triton_sparse_mla_topk_chunk_size, +) +from vllm.v1.attention.backends.mla.sparse_mla_kernels import ( + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead, + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead, + accumulate_indexed_sparse_mla_attention_chunk, + build_combined_sparse_mla_decode_valid_mask, + finish_sparse_mla_attention_with_sink, + finish_two_sparse_mla_attention_states_with_sink, + fp8ds_global_paged_sparse_mla_attention_with_sink_multihead, + fp8ds_paged_sparse_mla_attention_with_sink_multihead, + matmul_sparse_mla_attention_with_sink, + sparse_mla_decode_head_block_size, +) from vllm.v1.attention.backends.mla.sparse_swa import DeepseekV4SWACache from vllm.v1.attention.ops.flashmla import ( flash_mla_sparse_fwd, @@ -83,10 +111,88 @@ logger = init_logger(__name__) + +def _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu: torch.Tensor, + gather_lens_cpu: torch.Tensor, + compress_ratio: int, + swa_only: bool, +) -> tuple[int, int]: + if seq_lens_cpu.numel() == 0: + return 0, 0 + + max_gather_len = int(gather_lens_cpu.max().item()) + if swa_only: + return 0, max_gather_len + + compressed_region_size = int((seq_lens_cpu // compress_ratio).max().item()) + return compressed_region_size, compressed_region_size + max_gather_len + + +def _sparse_mla_prefill_gather_len_upper_bound( + *, + max_model_len: int, + max_num_batched_tokens: int, + window_size: int, +) -> tuple[int, int]: + max_query_chunk_tokens = max(1, min(max_model_len, max_num_batched_tokens)) + max_prefix_len = max(max_model_len - max_query_chunk_tokens, 0) + max_gather_len = max_query_chunk_tokens + min( + max_prefix_len, + max(window_size - 1, 0), + ) + return max_query_chunk_tokens, max_gather_len + + +def _deepseek_v4_fp8_einsum_config( + capability_major: int, +) -> tuple[tuple[int, int, int], bool]: + if capability_major == 10: + return (1, 1, 128), True + return (1, 128, 128), False + + +def _use_deepseek_v4_sm12x_triton_fp8_einsum( + equation: str, + recipe: list[int], + b_scale: torch.Tensor, +) -> bool: + capability = current_platform.get_device_capability() + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + return ( + capability is not None + and capability.major == 12 + and equation == "bhr,hdr->bhd" + and tuple(recipe) == (1, 128, 128) + and b_scale.dtype in (torch.float32, e8m0_dtype) + ) + + +def _allocate_deepseek_v4_wo_a_output( + num_tokens: int, + num_groups: int, + output_rank: int, + dtype: torch.dtype, + device: torch.device, +) -> torch.Tensor: + shape = (num_tokens, num_groups, output_rank) + if torch.compiler.is_compiling(): + # Workspace growth can call torch.accelerator.empty_cache(), which + # Dynamo intentionally refuses to trace. 
During compilation this is a + # normal graph allocation, matching the o_padded allocation above. + return torch.empty(shape, dtype=dtype, device=device) + + (output,) = current_workspace_manager().get_simultaneous( + (shape, dtype), + ) + return output + + # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather # workspace allocated at _forward_prefill (and the matching profile-time # reservation in attention_impl's dummy-run branch). PREFILL_CHUNK_SIZE = 4 +_DEFAULT_SPARSE_MLA_TOPK_TOKENS = 2048 @dataclass @@ -172,6 +278,8 @@ def __init__( self.compress_ratio = compress_ratio if compress_ratio is not None else 1 self.prefix = prefix + disable_triton_sparse_mla_cudagraphs_if_enabled(mla_modules.vllm_config) + # Extract config from vllm_config config = mla_modules.vllm_config.model_config.hf_config tp_size = get_tensor_model_parallel_world_size() @@ -202,12 +310,13 @@ def __init__( self.wo_b = mla_modules.wo_b # Pick fp8_einsum recipe based on GPU arch: - # SM90: FP32 block scales stay [g, r/128, d/128] → sfb_gran_mn=128 - # SM100: INT32 packed scales become [g, r, ...] → sfb_gran_mn=1 + # SM90/SM120: FP32 block scales stay [g, r/128, d/128]. + # SM100: INT32 packed scales become [g, r, ...]. cap = current_platform.get_device_capability() assert cap is not None, "DeepseekV4 attention requires a CUDA device" - self._einsum_recipe = (1, 128, 128) if cap.major <= 9 else (1, 1, 128) - self._tma_aligned_scales = cap.major >= 10 + self._einsum_recipe, self._tma_aligned_scales = _deepseek_v4_fp8_einsum_config( + cap.major + ) self.rotary_emb = mla_modules.rotary_emb self.indexer_rotary_emb = mla_modules.indexer_rotary_emb @@ -336,10 +445,12 @@ def forward( wo_a_fp8 = self.wo_a.weight wo_a_scale = self.wo_a.weight_scale_inv - z = torch.empty( - (num_tokens, self.n_local_groups, self.o_lora_rank), - device=o.device, - dtype=torch.bfloat16, + z = _allocate_deepseek_v4_wo_a_output( + num_tokens, + self.n_local_groups, + self.o_lora_rank, + torch.bfloat16, + hidden_states.device, ) torch.ops.vllm.deepseek_v4_fp8_einsum( o_fp8, @@ -494,21 +605,8 @@ def wq_b_kv_insert() -> torch.Tensor: # Handle dummy run (no metadata). if not isinstance(attn_metadata, dict): - # Reserve _forward_prefill's bf16-gather workspace; the dummy - # run returns before mla_attn runs, so without this the shared - # workspace locks below the real prefill size. 
- sub = self.mla_attn - swa_only = sub.compress_ratio <= 1 - N = ( - 0 - if swa_only - else (sub.max_model_len + sub.compress_ratio - 1) // sub.compress_ratio - ) - M = N + sub.window_size + sub.max_num_batched_tokens - current_workspace_manager().get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - ) out.zero_() + self.mla_attn._reserve_prefill_workspace() return # Pad q to FlashMLA-required head count (64 or 128) @@ -594,6 +692,68 @@ def deepseek_v4_fp8_einsum( equation: str, recipe: list[int], ) -> None: + if equation == "bhr,hdr->bhd" and b.dim() == 2: + num_groups = out.shape[1] + out_rank = out.shape[2] + hidden_size = a.shape[2] + if b.shape[0] % out_rank != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum weight rows must be divisible by " + f"out_rank={out_rank}, got {b.shape[0]}" + ) + b_groups = b.shape[0] // out_rank + group_start = 0 + if b_groups != num_groups: + if b_groups % num_groups != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum weight groups must match the " + "TP-local output groups or be an integer multiple of " + f"them, got weight_groups={b_groups}, " + f"output_groups={num_groups}" + ) + group_partitions = b_groups // num_groups + group_start = ( + get_tensor_model_parallel_rank() % group_partitions + ) * num_groups + b = b.view(b_groups, out_rank, hidden_size) + if group_start != 0 or b_groups != num_groups: + b = b.narrow(0, group_start, num_groups) + + if b_scale.dim() == 2: + scale_mn = recipe[1] + scale_k_pack = 4 if b_scale.dtype == torch.int32 else 1 + scale_k = recipe[2] * scale_k_pack + scale_out_blocks = (out_rank + scale_mn - 1) // scale_mn + scale_hidden_blocks = (hidden_size + scale_k - 1) // scale_k + if b_scale.shape[0] % scale_out_blocks != 0: + raise RuntimeError( + "DeepSeek V4 fp8 einsum scale rows must be divisible by " + f"scale_out_blocks={scale_out_blocks}, " + f"got {b_scale.shape[0]}" + ) + scale_groups = b_scale.shape[0] // scale_out_blocks + if scale_groups not in (num_groups, b_groups): + raise RuntimeError( + "DeepSeek V4 fp8 einsum scale groups must match the " + "TP-local output groups or weight groups, got " + f"scale_groups={scale_groups}, output_groups={num_groups}, " + f"weight_groups={b_groups}" + ) + b_scale = b_scale.view( + scale_groups, + scale_out_blocks, + scale_hidden_blocks, + ) + if scale_groups == b_groups and scale_groups != num_groups: + b_scale = b_scale.narrow(0, group_start, num_groups) + elif b_scale.dim() == 3 and b_scale.shape[0] == b_groups: + if b_groups != num_groups: + b_scale = b_scale.narrow(0, group_start, num_groups) + + if _use_deepseek_v4_sm12x_triton_fp8_einsum(equation, recipe, b_scale): + deepseek_v4_sm12x_fp8_einsum(a, a_scale, b, b_scale, out) + return + fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) @@ -711,7 +871,11 @@ def __init__( assert cache_config is not None cache_config.cache_dtype = "fp8_ds_mla" kv_cache_dtype = "fp8_ds_mla" - logger.info_once("Using DeepSeek's fp8_ds_mla KV cache format.") + logger.info_once( + "Using DeepSeek's fp8_ds_mla KV cache format. 
To use standard " + "fp8 kv-cache format, please set `--attention-backend " + "FLASHINFER_MLA_SPARSE`" + ) self.kv_cache_dtype = kv_cache_dtype @@ -724,6 +888,73 @@ def __init__( self.kv_cache = torch.tensor([]) + def _prefill_workspace_topk_bound(self) -> int: + if self.compress_ratio <= 1: + return 0 + if ( + self.topk_indices_buffer is not None + and self.topk_indices_buffer.ndim > 0 + and self.topk_indices_buffer.shape[-1] > 0 + ): + return int(self.topk_indices_buffer.shape[-1]) + indexer_topk = getattr(self.indexer, "topk_tokens", None) + if indexer_topk is not None: + return int(indexer_topk) + return _DEFAULT_SPARSE_MLA_TOPK_TOKENS + + def _prefill_workspace_reservation_specs( + self, + ) -> tuple[tuple[tuple[int, ...], torch.dtype], ...]: + max_model_len = max(1, int(self.max_model_len)) + max_num_batched_tokens = max(1, int(self.max_num_batched_tokens)) + window_size = max(1, int(self.window_size)) + compress_ratio = max(1, int(self.compress_ratio)) + head_dim = int(self.head_dim) + num_heads = int(self.num_heads) + + max_query_chunk_tokens, max_gather_len = ( + _sparse_mla_prefill_gather_len_upper_bound( + max_model_len=max_model_len, + max_num_batched_tokens=max_num_batched_tokens, + window_size=window_size, + ) + ) + if compress_ratio <= 1: + m_bound = max_gather_len + else: + compressed_region_size = max_model_len // compress_ratio + m_bound = compressed_region_size + max_gather_len + + combined_topk = sparse_prefill_combined_topk_size( + DeepseekV4MLAAttention._prefill_workspace_topk_bound(self), + window_size, + ) + specs: list[tuple[tuple[int, ...], torch.dtype]] = [ + ((PREFILL_CHUNK_SIZE, m_bound, head_dim), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ] + if is_triton_sparse_mla_enabled_for_platform(): + query_chunk_size = min( + max_query_chunk_tokens, + triton_sparse_mla_query_chunk_size(), + ) + specs.extend( + [ + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads), torch.float32), + ((query_chunk_size, num_heads, head_dim), torch.float32), + ] + ) + return tuple(specs) + + def _reserve_prefill_workspace(self) -> None: + try: + workspace_manager = current_workspace_manager() + except AssertionError: + return + workspace_manager.get_simultaneous(*self._prefill_workspace_reservation_specs()) + def get_attn_backend(self) -> type[AttentionBackend]: return DeepseekV4FlashMLASparseBackend @@ -743,6 +974,332 @@ def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None: model_version="deepseek_v4", ) + def _forward_sparse_mla_swa_decode_triton( + self, + q: torch.Tensor, + swa_k_cache: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + output: torch.Tensor, + ) -> None: + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + if not mtp_decode: + fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q=q, + k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + block_size=swa_metadata.block_size, + candidate_offset=0, + num_candidates=max_swa_len, + scale=self.scale, + attn_sink=self.attn_sink, + 
output=output, + head_block_size=head_block_size, + num_heads=self.num_heads, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + return + + ( + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ) + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_sparse_mla_attention_with_sink( + swa_max_score, + swa_denom, + swa_acc, + self.attn_sink, + output=output, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + + def _forward_sparse_mla_compressed_decode_triton( + self, + q: torch.Tensor, + compressed_k_cache: torch.Tensor, + swa_k_cache: torch.Tensor, + topk_indices: torch.Tensor, + topk_lens: torch.Tensor, + swa_metadata: "DeepseekSparseSWAMetadata", + attn_metadata: FlashMLASparseMetadata, + output: torch.Tensor, + ) -> None: + if self.compress_ratio not in (4, 128): + raise NotImplementedError( + "Triton sparse MLA compressed decode currently supports " + f"compress_ratio=4 or 128, got {self.compress_ratio}" + ) + + num_decodes = swa_metadata.num_decodes + num_decode_tokens = swa_metadata.num_decode_tokens + mtp_decode = num_decode_tokens != num_decodes + + max_swa_len = swa_metadata.decode_swa_indices.shape[-1] + compressed_block_size = attn_metadata.block_size // self.compress_ratio + compressed_topk = topk_indices.shape[-1] + topk_chunk_size = min( + compressed_topk, + triton_sparse_mla_topk_chunk_size(), + ) + compressed_slot_ids = topk_indices[:, 0, :] + swa_lens = swa_metadata.decode_swa_lens[:num_decode_tokens] + swa_indices = swa_metadata.decode_swa_indices[:num_decode_tokens] + head_block_size = sparse_mla_decode_head_block_size(num_decode_tokens) + if ( + not mtp_decode + and compressed_topk <= topk_chunk_size + and triton_sparse_mla_matmul_decode_enabled() + ): + total_candidates = compressed_topk + max_swa_len + ( + combined_kv, + valid_tokens, + score_buffer, + ) = current_workspace_manager().get_simultaneous( + ( + (num_decode_tokens, total_candidates, q.shape[-1]), + torch.bfloat16, + ), + ((num_decode_tokens, total_candidates), torch.bool), + ((num_decode_tokens, self.num_heads, total_candidates), torch.bfloat16), + ) + dequantize_combined_sparse_mla_decode_kv( + combined_kv, + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + swa_k_cache, + swa_metadata.seq_lens[:num_decodes], + swa_lens, + swa_metadata.block_table[:num_decodes], + swa_metadata.block_size, + ) + + build_combined_sparse_mla_decode_valid_mask( + valid_tokens, + compressed_slot_ids, + topk_lens, + swa_lens, + ) + use_dot_finish = num_decode_tokens <= 16 + matmul_sparse_mla_attention_with_sink( + q=q, + kv=combined_kv, + valid_tokens=valid_tokens, + scale=self.scale, + attn_sink=self.attn_sink, + output=output, + num_heads=self.num_heads, + score_buffer=score_buffer, + value_block_size=512 if use_dot_finish else 256, + candidate_block_size=128 if use_dot_finish else None, + ) + return + + if not mtp_decode and compressed_topk <= topk_chunk_size: + 
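+            # Single fused pass: all compressed topk candidates fit in one
+            # chunk, so compressed-cache and SWA scoring run in one kernel.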
fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q=q, + compressed_k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids, + topk_lens=topk_lens, + compressed_block_size=compressed_block_size, + swa_k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + swa_block_size=swa_metadata.block_size, + num_compressed_candidates=compressed_topk, + num_swa_candidates=max_swa_len, + scale=self.scale, + attn_sink=self.attn_sink, + output=output, + head_block_size=head_block_size, + num_heads=self.num_heads, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + return + + ( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + ) = current_workspace_manager().get_simultaneous( + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads), torch.float32), + ((num_decode_tokens, self.num_heads, q.shape[-1]), torch.float32), + ) + comp_max_score.fill_(float("-inf")) + comp_denom.zero_() + comp_acc.zero_() + swa_max_score.fill_(float("-inf")) + swa_denom.zero_() + swa_acc.zero_() + + for chunk_start in range(0, compressed_topk, topk_chunk_size): + chunk_end = min(chunk_start + topk_chunk_size, compressed_topk) + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=compressed_k_cache, + slot_ids=compressed_slot_ids[:, chunk_start:chunk_end], + lens=topk_lens, + block_size=compressed_block_size, + candidate_offset=chunk_start, + scale=self.scale, + max_score=comp_max_score, + denom=comp_denom, + acc=comp_acc, + head_block_size=head_block_size, + ) + if mtp_decode: + accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + slot_ids=swa_indices, + lens=swa_lens, + block_size=swa_metadata.block_size, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + else: + accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q=q, + k_cache=swa_k_cache, + seq_lens=swa_metadata.seq_lens[:num_decodes], + gather_lens=swa_lens, + block_table=swa_metadata.block_table[:num_decodes], + block_size=swa_metadata.block_size, + candidate_offset=0, + num_candidates=max_swa_len, + scale=self.scale, + max_score=swa_max_score, + denom=swa_denom, + acc=swa_acc, + head_block_size=head_block_size, + ) + finish_two_sparse_mla_attention_states_with_sink( + comp_max_score, + comp_denom, + comp_acc, + swa_max_score, + swa_denom, + swa_acc, + self.attn_sink, + output=output, + ) + if output.shape[1] > self.num_heads: + output[:, self.num_heads :].zero_() + + def _forward_sparse_mla_prefill_triton( + self, + q: torch.Tensor, + kv: torch.Tensor, + combined_indices: torch.Tensor, + combined_lens: torch.Tensor, + output: torch.Tensor, + state_buffers: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, + ) -> None: + kv_flat = kv.reshape(-1, q.shape[-1]) + topk_chunk_size = min( + combined_indices.shape[-1], + triton_sparse_mla_topk_chunk_size(), + ) + query_chunk_size = min( + q.shape[0], + triton_sparse_mla_query_chunk_size(), + ) + if state_buffers is None: + ( + max_score_buffer, + denom_buffer, + output_buffer, + ) = current_workspace_manager().get_simultaneous( + ((query_chunk_size, self.num_heads), 
torch.float32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads, q.shape[-1]), torch.float32), + ) + else: + max_score_buffer, denom_buffer, output_buffer = state_buffers + + for token_start in range(0, q.shape[0], query_chunk_size): + token_end = min(token_start + query_chunk_size, q.shape[0]) + q_chunk = q[token_start:token_end] + indices_chunk_full = combined_indices[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + num_tokens = token_end - token_start + max_score = max_score_buffer[:num_tokens] + denom = denom_buffer[:num_tokens] + subset_acc = output_buffer[:num_tokens] + max_score.fill_(float("-inf")) + denom.zero_() + subset_acc.zero_() + + for index_start in range(0, combined_indices.shape[-1], topk_chunk_size): + index_end = min( + index_start + topk_chunk_size, + combined_indices.shape[-1], + ) + accumulate_indexed_sparse_mla_attention_chunk( + q=q_chunk, + kv_flat=kv_flat, + indices=indices_chunk_full[:, index_start:index_end], + lens=lens_chunk, + candidate_offset=index_start, + scale=self.scale, + max_score=max_score, + denom=denom, + acc=subset_acc, + ) + + finish_sparse_mla_attention_with_sink( + max_score, + denom, + subset_acc, + self.attn_sink, + output=output[token_start:token_end], + ) + if output.shape[1] > self.num_heads: + output[token_start:token_end, self.num_heads :].zero_() + def forward( self, q: torch.Tensor, @@ -823,12 +1380,14 @@ def _forward_decode( if self.compress_ratio == 4: # C4A: local indices differ per layer (filled by Indexer). assert self.topk_indices_buffer is not None + local_topk_indices = self.topk_indices_buffer[:num_decode_tokens] global_indices, topk_lens = compute_global_topk_indices_and_lens( - self.topk_indices_buffer[:num_decode_tokens], + local_topk_indices, swa_metadata.token_to_req_indices, attn_metadata.block_table[:num_decodes], block_size, is_valid, + global_topk_indices=local_topk_indices, ) topk_indices = global_indices.view(num_decode_tokens, 1, -1) else: @@ -867,9 +1426,35 @@ def _forward_decode( # Use unsqueeze to preserve strides (handles padded blocks correctly) swa_cache = self.swa_cache_layer.kv_cache.unsqueeze(-2) # Reshape KV cache to (num_blocks, block_size, 1, head_bytes) + compressed_k_cache = kv_cache if kv_cache is not None: kv_cache = kv_cache.unsqueeze(-2) + if is_triton_sparse_mla_enabled(q.device): + if swa_only: + self._forward_sparse_mla_swa_decode_triton( + q=q, + swa_k_cache=self.swa_cache_layer.kv_cache, + swa_metadata=swa_metadata, + output=output, + ) + return + if self.compress_ratio in (4, 128): + assert compressed_k_cache is not None + assert attn_metadata is not None + assert topk_indices is not None + assert topk_lens is not None + self._forward_sparse_mla_compressed_decode_triton( + q=q, + compressed_k_cache=compressed_k_cache, + swa_k_cache=self.swa_cache_layer.kv_cache, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_metadata=swa_metadata, + attn_metadata=attn_metadata, + output=output, + ) + return # One FlashMLASchedMeta per layer type, shared across all same-type # layers within this decode step. The first forward call per type # triggers the in-kernel planner (allocating tile_scheduler_metadata @@ -932,8 +1517,12 @@ def _forward_prefill( # Use pre-computed prefill metadata. 
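+        # The CPU copies below let _sparse_mla_prefill_workspace_bounds
+        # derive workspace sizes from host-side maxima without a device sync.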
seq_lens = swa_metadata.prefill_seq_lens gather_lens = swa_metadata.prefill_gather_lens + seq_lens_cpu = swa_metadata.prefill_seq_lens_cpu + gather_lens_cpu = swa_metadata.prefill_gather_lens_cpu assert seq_lens is not None assert gather_lens is not None + assert seq_lens_cpu is not None + assert gather_lens_cpu is not None # Derive prefill-local token offsets from the full query_start_loc_cpu. query_start_loc_cpu = swa_metadata.query_start_loc_cpu @@ -952,24 +1541,69 @@ def _forward_prefill( assert attn_metadata is not None topk_indices = attn_metadata.c128a_prefill_topk_indices top_k = topk_indices.shape[-1] - # Compressed region must fit the full compressed pool (seq_len // - # compress_ratio), not just top_k. top_k bounds how many indices - # the indexer selects, not the pool size it indexes into. - N = (self.max_model_len + self.compress_ratio - 1) // self.compress_ratio else: # NOTE(woosuk): topk_indices will not be used for SWA-only layers. assert self.topk_indices_buffer is not None topk_indices = self.topk_indices_buffer[num_decode_tokens:] top_k = 0 - N = 0 - M = N + self.window_size + self.max_num_batched_tokens + N, M = _sparse_mla_prefill_workspace_bounds( + seq_lens_cpu=seq_lens_cpu, + gather_lens_cpu=gather_lens_cpu, + compress_ratio=self.compress_ratio, + swa_only=swa_only, + ) num_chunks = (num_prefills + PREFILL_CHUNK_SIZE - 1) // PREFILL_CHUNK_SIZE + max_query_chunk_tokens = 0 + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * PREFILL_CHUNK_SIZE + chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills) + query_start = ( + query_start_loc_cpu[num_decodes + chunk_start] - prefill_token_base + ) + query_end = ( + query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base + ) + max_query_chunk_tokens = max( + max_query_chunk_tokens, int(query_end - query_start) + ) + combined_topk = sparse_prefill_combined_topk_size(top_k, self.window_size) workspace_manager = current_workspace_manager() - kv = workspace_manager.get_simultaneous( - ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), - )[0] + triton_sparse_mla_enabled = is_triton_sparse_mla_enabled(q.device) + if triton_sparse_mla_enabled: + query_chunk_size = min(q.shape[0], triton_sparse_mla_query_chunk_size()) + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + max_score_buffer, + denom_buffer, + output_buffer, + ) = workspace_manager.get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads), torch.float32), + ((query_chunk_size, self.num_heads, q.shape[-1]), torch.float32), + ) + prefill_state_buffers = ( + max_score_buffer, + denom_buffer, + output_buffer, + ) + else: + ( + kv, + combined_indices_buffer, + combined_lens_buffer, + ) = workspace_manager.get_simultaneous( + ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), + ((max_query_chunk_tokens, combined_topk), torch.int32), + ((max_query_chunk_tokens,), torch.int32), + ) + prefill_state_buffers = None for chunk_idx in range(num_chunks): chunk_start = chunk_idx * PREFILL_CHUNK_SIZE chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills) @@ -1008,6 +1642,7 @@ def _forward_prefill( query_start_loc_cpu[num_decodes + chunk_end] - prefill_token_base ) + query_tokens = query_end - query_start combined_indices, combined_lens = combine_topk_swa_indices( topk_indices[query_start:query_end], query_start_loc[ @@ 
-1020,8 +1655,21 @@ def _forward_prefill( top_k, M, N, + combined_indices=combined_indices_buffer[:query_tokens], + combined_lens=combined_lens_buffer[:query_tokens], ) + if triton_sparse_mla_enabled: + self._forward_sparse_mla_prefill_triton( + q=q[query_start:query_end], + kv=kv[:chunk_size], + combined_indices=combined_indices, + combined_lens=combined_lens, + output=output[query_start:query_end], + state_buffers=prefill_state_buffers, + ) + continue + if current_platform.is_rocm(): rocm_sparse_attn_prefill( q=q[query_start:query_end], @@ -1033,16 +1681,17 @@ def _forward_prefill( attn_sink=self.attn_sink, output=output[query_start:query_end], ) - else: - output_chunk, _, _ = flash_mla_sparse_fwd( - q=q[query_start:query_end], - kv=kv.view(-1, 1, q.shape[-1]), - indices=combined_indices.unsqueeze(1), - sm_scale=self.scale, - attn_sink=self.attn_sink, - topk_length=combined_lens, - out=output[query_start:query_end], - ) + continue + + output_chunk, _, _ = flash_mla_sparse_fwd( + q=q[query_start:query_end], + kv=kv.view(-1, 1, q.shape[-1]), + indices=combined_indices.unsqueeze(1), + sm_scale=self.scale, + attn_sink=self.attn_sink, + topk_length=combined_lens, + out=output[query_start:query_end], + ) class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): diff --git a/vllm/model_executor/layers/deepseek_v4_triton_kernels.py b/vllm/model_executor/layers/deepseek_v4_triton_kernels.py new file mode 100644 index 000000000000..b5048c5fd013 --- /dev/null +++ b/vllm/model_executor/layers/deepseek_v4_triton_kernels.py @@ -0,0 +1,1282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Triton fallback kernels used by the local DeepSeek V4 path.""" + +import torch + +from vllm.triton_utils import LOG2E, tl, triton + +DEEPSEEK_V4_MLA_HEAD_DIM = 512 +FP8_DS_MLA_FP8_DIM = 448 +FP8_DS_MLA_SCALE_GROUP = 64 +FP8_DS_MLA_SCALE_BYTES = 8 +FP8_DS_MLA_TOKEN_BYTES = 576 + + +def _view_packed_fp8_paged_mqa_kv_cache( + kv_cache: torch.Tensor, + head_dim: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Return FP8 values and fp32 scales from indexer cache block storage.""" + if kv_cache.dtype != torch.uint8: + raise TypeError(f"Expected uint8 kv_cache, got {kv_cache.dtype}") + if kv_cache.dim() == 3: + num_blocks, block_size, head_dim_with_scale = kv_cache.shape + num_kv_heads = 1 + elif kv_cache.dim() == 4: + num_blocks, block_size, num_kv_heads, head_dim_with_scale = kv_cache.shape + else: + raise ValueError( + f"Expected 3D or 4D kv_cache, got {kv_cache.dim()} dimensions" + ) + if num_kv_heads != 1: + raise ValueError(f"Expected one KV head, got {num_kv_heads}") + + scale_bytes = head_dim_with_scale - head_dim + if scale_bytes <= 0 or scale_bytes % torch.float32.itemsize != 0: + raise ValueError( + "Expected kv_cache last dimension to contain FP8 values followed " + f"by fp32 scale bytes; got head_dim={head_dim}, " + f"last_dim={head_dim_with_scale}" + ) + + block_stride = kv_cache.stride(0) + base_storage_offset = kv_cache.storage_offset() + scale_elems = scale_bytes // torch.float32.itemsize + kv_values = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, head_dim), + stride=(block_stride, head_dim, head_dim, 1), + storage_offset=base_storage_offset, + ).view(torch.float8_e4m3fn) + kv_scale = torch.as_strided( + kv_cache, + size=(num_blocks, block_size, 1, scale_bytes), + stride=(block_stride, scale_bytes, scale_bytes, 1), + storage_offset=base_storage_offset + block_size * head_dim, + ).view(torch.float32) + return 
kv_values, kv_scale[..., :scale_elems] + + +@triton.jit +def _sparse_attention_bf16_kernel( + q_ptr, + kv_ptr, + indices_ptr, + lengths_ptr, + sink_ptr, + out_ptr, + num_tokens: tl.constexpr, + num_heads: tl.constexpr, + seq_kv: tl.constexpr, + index_topk: tl.constexpr, + sm_scale_log2: tl.constexpr, + stride_qt: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_k: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + HAS_SINK: tl.constexpr, + LOG2E_CONST: tl.constexpr, +): + token_id = tl.program_id(0) + head_block = tl.program_id(1) + heads = head_block * BLOCK_H + tl.arange(0, BLOCK_H) + offs_d = tl.arange(0, BLOCK_D) + mask_h = heads < num_heads + + q = tl.load( + q_ptr + + token_id * stride_qt + + heads[:, None] * stride_qh + + offs_d[None, :] * stride_qd, + mask=mask_h[:, None], + other=0.0, + ) + + if HAS_SINK: + sink = tl.load(sink_ptr + heads, mask=mask_h, other=-float("inf")) + e_max = sink * LOG2E_CONST + e_sum = tl.where(mask_h, 1.0, 0.0) + else: + e_max = tl.full((BLOCK_H,), -float("inf"), dtype=tl.float32) + e_sum = tl.zeros((BLOCK_H,), dtype=tl.float32) + acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32) + + length = tl.load(lengths_ptr + token_id) + for start in range(0, index_topk, BLOCK_N): + offs_n = start + tl.arange(0, BLOCK_N) + idx = tl.load( + indices_ptr + token_id * stride_indices_t + offs_n * stride_indices_k, + mask=offs_n < index_topk, + other=-1, + ) + mask_kv = (offs_n < length) & (idx >= 0) & (idx < seq_kv) + k = tl.load( + kv_ptr + idx[None, :] * stride_kv_t + offs_d[:, None] * stride_kv_d, + mask=mask_kv[None, :], + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) * sm_scale_log2 + qk = tl.where( + mask_h[:, None] & mask_kv[None, :], + qk, + -3.4028234663852886e38, + ) + + v = tl.load( + kv_ptr + idx[:, None] * stride_kv_t + offs_d[None, :] * stride_kv_d, + mask=mask_kv[:, None], + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp2(e_max - n_e_max) + p = tl.exp2(qk - n_e_max[:, None]) + p = tl.where(mask_h[:, None] & mask_kv[None, :], p, 0.0) + acc = acc * re_scale[:, None] + tl.dot(p.to(v.dtype), v) + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + acc = acc / tl.maximum(e_sum, 1.0e-20)[:, None] + tl.store( + out_ptr + + token_id * stride_out_t + + heads[:, None] * stride_out_h + + offs_d[None, :] * stride_out_d, + acc.to(tl.bfloat16), + mask=mask_h[:, None], + ) + + +def sparse_attention_triton( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + lengths: torch.Tensor, + scale: float, + attn_sink: torch.Tensor | None, + out: torch.Tensor, +) -> None: + if indices.ndim == 3: + indices = indices.squeeze(1) + if kv.ndim == 3: + kv = kv.squeeze(1) + + num_tokens, num_heads, head_dim = q.shape + if num_tokens == 0: + return + if head_dim != DEEPSEEK_V4_MLA_HEAD_DIM: + raise ValueError( + "DeepSeek V4 sparse Triton fallback expects " + f"D={DEEPSEEK_V4_MLA_HEAD_DIM}, got {head_dim}" + ) + assert kv.shape[-1] == head_dim + assert out.shape[-1] == head_dim + + grid = (num_tokens, triton.cdiv(num_heads, 8)) + _sparse_attention_bf16_kernel[grid]( + q, + kv, + indices, + lengths, + attn_sink if attn_sink is not None else q, + out, + num_tokens, + num_heads, + kv.shape[0], + indices.shape[-1], + scale * LOG2E, + q.stride(0), + q.stride(1), + q.stride(2), + 
kv.stride(0), + kv.stride(1), + indices.stride(0), + indices.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_H=8, + BLOCK_N=16, + BLOCK_D=DEEPSEEK_V4_MLA_HEAD_DIM, + HAS_SINK=attn_sink is not None, + LOG2E_CONST=LOG2E, + num_warps=8, + ) + + +@triton.jit +def _decode_sparse_attention_fp8_kernel( + q_ptr, + swa_cache_fp8_ptr, + swa_cache_bf16_ptr, + swa_cache_u8_ptr, + swa_indices_ptr, + swa_lens_ptr, + extra_cache_fp8_ptr, + extra_cache_bf16_ptr, + extra_cache_u8_ptr, + extra_indices_ptr, + extra_lens_ptr, + sink_ptr, + out_ptr, + num_tokens: tl.constexpr, + num_heads: tl.constexpr, + swa_index_topk: tl.constexpr, + extra_index_topk: tl.constexpr, + swa_num_blocks: tl.constexpr, + extra_num_blocks: tl.constexpr, + swa_block_size: tl.constexpr, + extra_block_size: tl.constexpr, + swa_stride_block_bytes: tl.constexpr, + extra_stride_block_bytes: tl.constexpr, + sm_scale_log2: tl.constexpr, + stride_qt: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_swa_indices_t: tl.constexpr, + stride_swa_indices_k: tl.constexpr, + stride_extra_indices_t: tl.constexpr, + stride_extra_indices_k: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + FP8_DIM: tl.constexpr, + SCALE_GROUP: tl.constexpr, + SCALE_BYTES: tl.constexpr, + TOKEN_BYTES: tl.constexpr, + HAS_EXTRA: tl.constexpr, + HAS_SINK: tl.constexpr, + LOG2E_CONST: tl.constexpr, +): + token_id = tl.program_id(0) + head_block = tl.program_id(1) + heads = head_block * BLOCK_H + tl.arange(0, BLOCK_H) + offs_d = tl.arange(0, BLOCK_D) + mask_h = heads < num_heads + + q = tl.load( + q_ptr + + token_id * stride_qt + + heads[:, None] * stride_qh + + offs_d[None, :] * stride_qd, + mask=mask_h[:, None], + other=0.0, + ) + + if HAS_SINK: + sink = tl.load(sink_ptr + heads, mask=mask_h, other=-float("inf")) + e_max = sink * LOG2E_CONST + e_sum = tl.where(mask_h, 1.0, 0.0) + else: + e_max = tl.full((BLOCK_H,), -float("inf"), dtype=tl.float32) + e_sum = tl.zeros((BLOCK_H,), dtype=tl.float32) + acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32) + + swa_len = tl.load(swa_lens_ptr + token_id) + extra_len = tl.load(extra_lens_ptr + token_id) if HAS_EXTRA else 0 + total_len = extra_len + swa_len + + for start in range(0, extra_index_topk + swa_index_topk, BLOCK_N): + offs_n = start + tl.arange(0, BLOCK_N) + use_extra = HAS_EXTRA & (offs_n < extra_len) + use_swa = (offs_n >= extra_len) & (offs_n < total_len) + + extra_cols = offs_n + swa_cols = offs_n - extra_len + extra_idx = tl.load( + extra_indices_ptr + + token_id * stride_extra_indices_t + + extra_cols * stride_extra_indices_k, + mask=HAS_EXTRA & (extra_cols < extra_index_topk), + other=-1, + ) + swa_idx = tl.load( + swa_indices_ptr + + token_id * stride_swa_indices_t + + swa_cols * stride_swa_indices_k, + mask=(swa_cols >= 0) & (swa_cols < swa_index_topk), + other=-1, + ) + idx = tl.where(use_extra, extra_idx, swa_idx) + + extra_block = idx // extra_block_size + extra_pos = idx - extra_block * extra_block_size + swa_block = idx // swa_block_size + swa_pos = idx - swa_block * swa_block_size + valid_extra = use_extra & (idx >= 0) & (extra_block < extra_num_blocks) + valid_swa = use_swa & (idx >= 0) & (swa_block < swa_num_blocks) + valid = valid_extra | valid_swa + + extra_token_base = extra_block * extra_stride_block_bytes + extra_token_base += extra_pos * TOKEN_BYTES + swa_token_base = swa_block * swa_stride_block_bytes + swa_token_base 
+= swa_pos * TOKEN_BYTES + token_base = tl.where(use_extra, extra_token_base, swa_token_base) + block_size = tl.where(use_extra, extra_block_size, swa_block_size) + stride_block_bytes = tl.where( + use_extra, extra_stride_block_bytes, swa_stride_block_bytes + ) + pos = tl.where(use_extra, extra_pos, swa_pos) + + is_fp8 = offs_d < FP8_DIM + scale_offsets = ( + tl.where(use_extra, extra_block, swa_block)[:, None] + * stride_block_bytes[:, None] + + block_size[:, None] * TOKEN_BYTES + + pos[:, None] * SCALE_BYTES + + (offs_d[None, :] // SCALE_GROUP) + ) + encoded_scale = tl.load( + tl.where(use_extra[:, None], extra_cache_u8_ptr, swa_cache_u8_ptr) + + scale_offsets, + mask=valid[:, None] & is_fp8[None, :], + other=127, + ).to(tl.float32) + fp8_scale = tl.exp2(encoded_scale - 127.0) + + fp8_offsets = token_base[:, None] + offs_d[None, :] + fp8_vals = ( + tl.load( + tl.where(use_extra[:, None], extra_cache_fp8_ptr, swa_cache_fp8_ptr) + + fp8_offsets, + mask=valid[:, None] & is_fp8[None, :], + other=0.0, + ).to(tl.float32) + * fp8_scale + ) + + bf16_offsets = (token_base[:, None] + FP8_DIM) // 2 + bf16_offsets += offs_d[None, :] - FP8_DIM + bf16_vals = tl.load( + tl.where(use_extra[:, None], extra_cache_bf16_ptr, swa_cache_bf16_ptr) + + bf16_offsets, + mask=valid[:, None] & (~is_fp8[None, :]), + other=0.0, + ).to(tl.float32) + k = tl.where(is_fp8[None, :], fp8_vals, bf16_vals) + + qk = tl.dot(q, tl.trans(k.to(q.dtype))) * sm_scale_log2 + qk = tl.where( + mask_h[:, None] & valid[None, :], + qk, + -3.4028234663852886e38, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp2(e_max - n_e_max) + p = tl.exp2(qk - n_e_max[:, None]) + p = tl.where(mask_h[:, None] & valid[None, :], p, 0.0) + acc = acc * re_scale[:, None] + tl.dot(p.to(k.dtype), k) + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + acc = acc / tl.maximum(e_sum, 1.0e-20)[:, None] + tl.store( + out_ptr + + token_id * stride_out_t + + heads[:, None] * stride_out_h + + offs_d[None, :] * stride_out_d, + acc.to(tl.bfloat16), + mask=mask_h[:, None], + ) + + +def decode_sparse_attention_triton( + q: torch.Tensor, + swa_cache: torch.Tensor, + swa_indices: torch.Tensor, + swa_lens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor | None, + out: torch.Tensor, + extra_cache: torch.Tensor | None = None, + extra_indices: torch.Tensor | None = None, + extra_lens: torch.Tensor | None = None, +) -> None: + if swa_indices.ndim == 3: + swa_indices = swa_indices.squeeze(1) + if extra_indices is not None and extra_indices.ndim == 3: + extra_indices = extra_indices.squeeze(1) + + num_tokens, num_heads, head_dim = q.shape + if num_tokens == 0: + return + if head_dim != DEEPSEEK_V4_MLA_HEAD_DIM: + raise ValueError( + "DeepSeek V4 decode Triton fallback expects " + f"D={DEEPSEEK_V4_MLA_HEAD_DIM}, got {head_dim}" + ) + has_extra = ( + extra_cache is not None and extra_indices is not None and extra_lens is not None + ) + if not has_extra: + extra_cache = swa_cache + extra_indices = swa_indices[:, :1] + extra_lens = swa_lens + + assert extra_cache is not None + assert extra_indices is not None + assert extra_lens is not None + grid = (num_tokens, triton.cdiv(num_heads, 8)) + _decode_sparse_attention_fp8_kernel[grid]( + q, + swa_cache.view(torch.float8_e4m3fn), + swa_cache.view(torch.bfloat16), + swa_cache, + swa_indices, + swa_lens, + extra_cache.view(torch.float8_e4m3fn), + extra_cache.view(torch.bfloat16), + extra_cache, + extra_indices, + extra_lens, + attn_sink if attn_sink is not None else q, + out, + num_tokens, + 
num_heads, + swa_indices.shape[-1], + extra_indices.shape[-1] if has_extra else 0, + swa_cache.shape[0], + extra_cache.shape[0], + swa_cache.shape[1], + extra_cache.shape[1], + swa_cache.stride(0), + extra_cache.stride(0), + scale * LOG2E, + q.stride(0), + q.stride(1), + q.stride(2), + swa_indices.stride(0), + swa_indices.stride(1), + extra_indices.stride(0), + extra_indices.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_H=8, + BLOCK_N=16, + BLOCK_D=DEEPSEEK_V4_MLA_HEAD_DIM, + FP8_DIM=FP8_DS_MLA_FP8_DIM, + SCALE_GROUP=FP8_DS_MLA_SCALE_GROUP, + SCALE_BYTES=FP8_DS_MLA_SCALE_BYTES, + TOKEN_BYTES=FP8_DS_MLA_TOKEN_BYTES, + HAS_EXTRA=has_extra, + HAS_SINK=attn_sink is not None, + LOG2E_CONST=LOG2E, + num_warps=8, + ) + + +@triton.jit +def _deepseek_v4_fp8_einsum_triton_kernel( + a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + out_ptr, + B: tl.constexpr, + G: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + a_stride_b: tl.constexpr, + a_stride_g: tl.constexpr, + a_stride_k: tl.constexpr, + as_stride_b: tl.constexpr, + as_stride_g: tl.constexpr, + as_stride_kb: tl.constexpr, + b_stride_g: tl.constexpr, + b_stride_n: tl.constexpr, + b_stride_k: tl.constexpr, + bs_stride_g: tl.constexpr, + bs_stride_nb: tl.constexpr, + bs_stride_kb: tl.constexpr, + out_stride_b: tl.constexpr, + out_stride_g: tl.constexpr, + out_stride_n: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_g = tl.program_id(1) + pid_n = tl.program_id(2) + + offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + acc = tl.zeros((BLOCK_B, BLOCK_N), dtype=tl.float32) + for k0 in range(0, K, BLOCK_K): + k = k0 + offs_k + kb = k0 // BLOCK_K + + a = tl.load( + a_ptr + + offs_b[:, None] * a_stride_b + + pid_g * a_stride_g + + k[None, :] * a_stride_k, + mask=(offs_b[:, None] < B) & (k[None, :] < K), + other=0.0, + ) + b = tl.load( + b_ptr + + pid_g * b_stride_g + + offs_n[:, None] * b_stride_n + + k[None, :] * b_stride_k, + mask=(offs_n[:, None] < N) & (k[None, :] < K), + other=0.0, + ) + a_s = tl.load( + a_scale_ptr + + offs_b * as_stride_b + + pid_g * as_stride_g + + kb * as_stride_kb, + mask=offs_b < B, + other=0.0, + ).to(tl.float32) + b_s = tl.load( + b_scale_ptr + + pid_g * bs_stride_g + + (offs_n // BLOCK_K) * bs_stride_nb + + kb * bs_stride_kb, + mask=offs_n < N, + other=0.0, + ).to(tl.float32) + acc += ( + tl.dot(a, tl.trans(b), out_dtype=tl.float32) * a_s[:, None] * b_s[None, :] + ) + + tl.store( + out_ptr + + offs_b[:, None] * out_stride_b + + pid_g * out_stride_g + + offs_n[None, :] * out_stride_n, + acc, + mask=(offs_b[:, None] < B) & (offs_n[None, :] < N), + ) + + +def _e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: + return (scale.view(torch.uint8).to(torch.int32) << 23).view(torch.float32) + + +def _unpack_int32_e8m0_scales( + packed_scale: torch.Tensor, + num_blocks: int, +) -> torch.Tensor: + shifts = torch.arange(4, device=packed_scale.device, dtype=torch.int32) * 8 + unpacked = (packed_scale.to(torch.int32).unsqueeze(-1) >> shifts) & 0xFF + unpacked = unpacked.reshape(*packed_scale.shape[:-1], -1)[..., :num_blocks] + return (unpacked << 23).view(torch.float32) + + +def _normalize_deepseek_v4_fp8_einsum_inputs( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + B, G, K = a.shape + _, out_g, N = out.shape + assert 
out_g == G + k_blocks = triton.cdiv(K, 128) + n_blocks = triton.cdiv(N, 128) + + if b.ndim == 2: + b = b.view(G, N, K) + if b_scale.ndim == 2: + b_scale = b_scale.view(G, n_blocks, k_blocks) + + if a_scale.dtype == torch.int32: + a_scale = _unpack_int32_e8m0_scales(a_scale, k_blocks) + if b_scale.dtype == torch.int32: + b_scale = _unpack_int32_e8m0_scales(b_scale, k_blocks) + + if a_scale.dtype == torch.float8_e8m0fnu: + a_scale = _e8m0_to_fp32(a_scale) + if b_scale.dtype == torch.float8_e8m0fnu: + b_scale = _e8m0_to_fp32(b_scale) + + return a, a_scale.contiguous(), b, b_scale.contiguous() + + +def deepseek_v4_fp8_einsum_triton( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + a, a_scale, b, b_scale = _normalize_deepseek_v4_fp8_einsum_inputs( + a, a_scale, b, b_scale, out + ) + B, G, K = a.shape + N = out.shape[-1] + grid = (triton.cdiv(B, 16), G, triton.cdiv(N, 32)) + _deepseek_v4_fp8_einsum_triton_kernel[grid]( + a, + a_scale, + b, + b_scale, + out, + B, + G, + N, + K, + a.stride(0), + a.stride(1), + a.stride(2), + a_scale.stride(0), + a_scale.stride(1), + a_scale.stride(2), + b.stride(0), + b.stride(1), + b.stride(2), + b_scale.stride(0), + b_scale.stride(1), + b_scale.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_B=16, + BLOCK_N=32, + BLOCK_K=128, + num_warps=4, + ) + + +@triton.jit +def _fp8_mqa_logits_kernel( + q_ptr, + k_ptr, + scale_ptr, + weights_ptr, + cu_seqlen_ks_ptr, + cu_seqlen_ke_ptr, + logits_ptr, + num_q: tl.constexpr, + seq_len_kv: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + stride_qm: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kn: tl.constexpr, + stride_kd: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + valid_m = offs_m < num_q + valid_n = offs_n < seq_len_kv + seq_start = tl.load(cu_seqlen_ks_ptr + offs_m, mask=valid_m, other=0) + seq_end = tl.load(cu_seqlen_ke_ptr + offs_m, mask=valid_m, other=0) + seq_mask = (offs_n[None, :] >= seq_start[:, None]) & ( + offs_n[None, :] < seq_end[:, None] + ) + + logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for h in tl.range(0, num_heads): + scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + offs_m[:, None] * stride_qm + + h * stride_qh + + d[None, :] * stride_qd, + mask=valid_m[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + k_ptr + offs_n[:, None] * stride_kn + d[None, :] * stride_kd, + mask=valid_n[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.dot(q, tl.trans(k), input_precision="tf32") + scale = tl.load(scale_ptr + offs_n, mask=valid_n, other=0.0) + weighted = tl.maximum(scores * scale[None, :], 0.0) + weight = tl.load( + weights_ptr + offs_m * stride_wm + h * stride_wh, + mask=valid_m, + other=0.0, + ) + logits += weighted * weight[:, None] + + store_mask = valid_m[:, None] & valid_n[None, :] + logits = tl.where(seq_mask & store_mask, logits, float("-inf")) + tl.store( + logits_ptr + offs_m[:, None] * stride_lm + offs_n[None, :] * stride_ln, + logits, + 
mask=store_mask, + ) + + +def fp8_mqa_logits_triton( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + k_fp8, scale = kv + num_q, num_heads, head_dim = q.shape + seq_len_kv = k_fp8.shape[0] + logits = torch.empty( + (num_q, seq_len_kv), + device=q.device, + dtype=torch.float32, + ) + if num_q == 0 or seq_len_kv == 0: + return logits + + grid = (triton.cdiv(num_q, 8), triton.cdiv(seq_len_kv, 64)) + _fp8_mqa_logits_kernel[grid]( + q, + k_fp8, + scale, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + logits, + num_q, + seq_len_kv, + num_heads, + head_dim, + q.stride(0), + q.stride(1), + q.stride(2), + k_fp8.stride(0), + k_fp8.stride(1), + weights.stride(0), + weights.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_M=8, + BLOCK_N=64, + BLOCK_D=64, + num_warps=4, + ) + return logits + + +@triton.jit +def _fp8_paged_mqa_logits_kernel( + q_ptr, + kv_ptr, + scale_ptr, + weights_ptr, + context_lens_ptr, + block_tables_ptr, + logits_ptr, + token_start, + num_rows: tl.constexpr, + logits_width: tl.constexpr, + next_n: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + block_size: tl.constexpr, + stride_qb: tl.constexpr, + stride_qn: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kvb: tl.constexpr, + stride_kvs: tl.constexpr, + stride_kvd: tl.constexpr, + stride_sb: tl.constexpr, + stride_ss: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_clb: tl.constexpr, + stride_cln: tl.constexpr, + stride_btb: tl.constexpr, + stride_btk: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = token_start + offs_local_n + offs_d = tl.arange(0, BLOCK_D) + + valid_m = offs_m < num_rows + valid_n = offs_local_n < logits_width + batch = offs_m // next_n + q_pos = offs_m - batch * next_n + context_len = tl.load( + context_lens_ptr + batch * stride_clb + q_pos * stride_cln, + mask=valid_m, + other=0, + ) + context_mask = valid_n[None, :] & (offs_n[None, :] < context_len[:, None]) + + block_rank = offs_n // block_size + block_offset = offs_n - block_rank * block_size + block_idx = tl.load( + block_tables_ptr + + batch[:, None] * stride_btb + + block_rank[None, :] * stride_btk, + mask=valid_m[:, None] & valid_n[None, :], + other=0, + ) + + logits = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + scale = tl.load( + scale_ptr + block_idx * stride_sb + block_offset[None, :] * stride_ss, + mask=context_mask, + other=0.0, + ) + for h in tl.range(0, num_heads): + scores = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + batch[:, None] * stride_qb + + q_pos[:, None] * stride_qn + + h * stride_qh + + d[None, :] * stride_qd, + mask=valid_m[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + kv_ptr + + block_idx[:, :, None] * stride_kvb + + block_offset[None, :, None] * stride_kvs + + d[None, None, :] * stride_kvd, + mask=context_mask[:, :, None] & (d[None, None, :] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.sum(q[:, None, :] * k, axis=2) + weighted = tl.maximum(scores * scale, 0.0) + weight = tl.load( + weights_ptr + offs_m * stride_wm + 
h * stride_wh, + mask=valid_m, + other=0.0, + ) + logits += weighted * weight[:, None] + + store_mask = valid_m[:, None] & valid_n[None, :] + logits = tl.where(context_mask & store_mask, logits, float("-inf")) + tl.store( + logits_ptr + offs_m[:, None] * stride_lm + offs_local_n[None, :] * stride_ln, + logits, + mask=store_mask, + ) + + +def fp8_paged_mqa_logits_triton( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + token_start: int = 0, + token_count: int | None = None, +) -> torch.Tensor: + batch_size, next_n, num_heads, head_dim = q.size() + if next_n == 1 and head_dim % 64 == 0 and num_heads % 4 == 0: + return fp8_paged_mqa_logits_rowwise_triton( + q, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + token_start=token_start, + token_count=token_count, + ) + + kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_size, _, _ = kv_values.size() + num_rows = batch_size * next_n + if token_count is None: + token_count = max_model_len - token_start + assert token_start >= 0 + assert token_count >= 0 + assert token_start + token_count <= max_model_len + logits = torch.empty( + (num_rows, token_count), + device=q.device, + dtype=torch.float32, + ) + if num_rows == 0 or token_count == 0: + return logits + + context_lens_2d = context_lens.reshape(batch_size, -1) + if context_lens_2d.shape[1] == 1 and next_n != 1: + context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous() + grid = (triton.cdiv(num_rows, 4), triton.cdiv(token_count, 64)) + _fp8_paged_mqa_logits_kernel[grid]( + q, + kv_values, + kv_scale, + weights, + context_lens_2d, + block_tables, + logits, + token_start, + num_rows, + token_count, + next_n, + num_heads, + head_dim, + block_size, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + kv_values.stride(0), + kv_values.stride(1), + kv_values.stride(3), + kv_scale.stride(0), + kv_scale.stride(1), + weights.stride(0), + weights.stride(1), + context_lens_2d.stride(0), + context_lens_2d.stride(1), + block_tables.stride(0), + block_tables.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_M=4, + BLOCK_N=64, + BLOCK_D=64, + num_warps=4, + ) + return logits + + +@triton.jit +def _fp8_paged_mqa_logits_rowwise_kernel( + q_ptr, + kv_ptr, + scale_ptr, + weights_ptr, + context_lens_ptr, + block_tables_ptr, + logits_ptr, + token_start, + num_rows: tl.constexpr, + logits_width: tl.constexpr, + next_n: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + block_size: tl.constexpr, + stride_qb: tl.constexpr, + stride_qn: tl.constexpr, + stride_qh: tl.constexpr, + stride_qd: tl.constexpr, + stride_kvb: tl.constexpr, + stride_kvs: tl.constexpr, + stride_kvd: tl.constexpr, + stride_sb: tl.constexpr, + stride_ss: tl.constexpr, + stride_wm: tl.constexpr, + stride_wh: tl.constexpr, + stride_clb: tl.constexpr, + stride_cln: tl.constexpr, + stride_btb: tl.constexpr, + stride_btk: tl.constexpr, + stride_lm: tl.constexpr, + stride_ln: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_H: tl.constexpr, +): + row = tl.program_id(0) + pid_n = tl.program_id(1) + offs_local_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_n = token_start + offs_local_n + offs_d = tl.arange(0, BLOCK_D) + + valid_row = row < num_rows + valid_n = offs_local_n < logits_width + batch = row // next_n + q_pos = row - batch * next_n + context_len = tl.load( + context_lens_ptr + batch * stride_clb + q_pos 
* stride_cln, + mask=valid_row, + other=0, + ) + context_mask = valid_n & (offs_n < context_len) + + block_rank = offs_n // block_size + block_offset = offs_n - block_rank * block_size + block_idx = tl.load( + block_tables_ptr + batch * stride_btb + block_rank * stride_btk, + mask=valid_row & valid_n, + other=0, + ) + + scale = tl.load( + scale_ptr + block_idx * stride_sb + block_offset * stride_ss, + mask=context_mask, + other=0.0, + ) + logits = tl.zeros((BLOCK_N,), dtype=tl.float32) + + for h0 in tl.range(0, num_heads, BLOCK_H): + heads = h0 + tl.arange(0, BLOCK_H) + valid_h = heads < num_heads + scores = tl.zeros((BLOCK_H, BLOCK_N), dtype=tl.float32) + for d0 in tl.range(0, head_dim, BLOCK_D): + d = d0 + offs_d + q = tl.load( + q_ptr + + batch * stride_qb + + q_pos * stride_qn + + heads[:, None] * stride_qh + + d[None, :] * stride_qd, + mask=valid_row & valid_h[:, None] & (d[None, :] < head_dim), + other=0.0, + ).to(tl.float32) + k = tl.load( + kv_ptr + + block_idx[None, :] * stride_kvb + + block_offset[None, :] * stride_kvs + + d[:, None] * stride_kvd, + mask=context_mask[None, :] & (d[:, None] < head_dim), + other=0.0, + ).to(tl.float32) + scores += tl.dot(q, k, input_precision="tf32") + + weighted = tl.maximum(scores * scale[None, :], 0.0) + weight = tl.load( + weights_ptr + row * stride_wm + heads * stride_wh, + mask=valid_row & valid_h, + other=0.0, + ) + logits += tl.sum(weighted * weight[:, None], axis=0) + + logits = tl.where(context_mask & valid_row, logits, float("-inf")) + tl.store( + logits_ptr + row * stride_lm + offs_local_n * stride_ln, + logits, + mask=valid_row & valid_n, + ) + + +def fp8_paged_mqa_logits_rowwise_triton( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + token_start: int = 0, + token_count: int | None = None, +) -> torch.Tensor: + batch_size, next_n, num_heads, head_dim = q.size() + kv_values, kv_scale = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_size, _, _ = kv_values.size() + num_rows = batch_size * next_n + if token_count is None: + token_count = max_model_len - token_start + assert token_start >= 0 + assert token_count >= 0 + assert token_start + token_count <= max_model_len + logits = torch.empty( + (num_rows, token_count), + device=q.device, + dtype=torch.float32, + ) + if num_rows == 0 or token_count == 0: + return logits + + context_lens_2d = context_lens.reshape(batch_size, -1) + if context_lens_2d.shape[1] == 1 and next_n != 1: + context_lens_2d = context_lens_2d.expand(batch_size, next_n).contiguous() + block_n = 128 + grid = (num_rows, triton.cdiv(token_count, block_n)) + _fp8_paged_mqa_logits_rowwise_kernel[grid]( + q, + kv_values, + kv_scale, + weights, + context_lens_2d, + block_tables, + logits, + token_start, + num_rows, + token_count, + next_n, + num_heads, + head_dim, + block_size, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + kv_values.stride(0), + kv_values.stride(1), + kv_values.stride(3), + kv_scale.stride(0), + kv_scale.stride(1), + weights.stride(0), + weights.stride(1), + context_lens_2d.stride(0), + context_lens_2d.stride(1), + block_tables.stride(0), + block_tables.stride(1), + logits.stride(0), + logits.stride(1), + BLOCK_N=block_n, + BLOCK_D=64, + BLOCK_H=8, + num_warps=4, + ) + return logits + + +@triton.jit +def _tf32_hc_prenorm_gemm_kernel( + x_ptr, + fn_ptr, + out_ptr, + sqrsum_ptr, + M: tl.constexpr, + K: tl.constexpr, + N: tl.constexpr, + stride_xm: tl.constexpr, + stride_xk: 
tl.constexpr, + stride_fnn: tl.constexpr, + stride_fnk: tl.constexpr, + stride_outs: tl.constexpr, + stride_outm: tl.constexpr, + stride_outn: tl.constexpr, + stride_sqs: tl.constexpr, + stride_sqm: tl.constexpr, + NUM_SPLIT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + pid_s = tl.program_id(2) + + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + split_k = tl.cdiv(K, NUM_SPLIT) + split_begin = pid_s * split_k + split_end = tl.minimum(split_begin + split_k, K) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + sq = tl.zeros((BLOCK_M,), dtype=tl.float32) + + for k0 in tl.range(0, split_k, BLOCK_K): + k = split_begin + k0 + offs_k + k_mask = k < split_end + x = tl.load( + x_ptr + offs_m[:, None] * stride_xm + k[None, :] * stride_xk, + mask=(offs_m[:, None] < M) & k_mask[None, :], + other=0.0, + ).to(tl.float32) + fn = tl.load( + fn_ptr + offs_n[None, :] * stride_fnn + k[:, None] * stride_fnk, + mask=(offs_n[None, :] < N) & k_mask[:, None], + other=0.0, + ).to(tl.float32) + + acc += tl.dot(x, fn, input_precision="tf32", out_dtype=tl.float32) + sq += tl.sum(x * x, axis=1) + + tl.store( + out_ptr + + pid_s * stride_outs + + offs_m[:, None] * stride_outm + + offs_n[None, :] * stride_outn, + acc, + mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), + ) + + if pid_n == 0: + tl.store( + sqrsum_ptr + pid_s * stride_sqs + offs_m * stride_sqm, + sq, + mask=offs_m < M, + ) + + +def tf32_hc_prenorm_gemm_triton( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> None: + assert x.dim() == 2 + assert fn.dim() == 2 + assert out.dim() == 3 + assert sqrsum.dim() == 2 + + m, k = x.shape + n = fn.shape[0] + assert fn.shape[1] == k + assert out.shape == (num_split, m, n) + assert sqrsum.shape == (num_split, m) + + if m == 0: + return + + block_m = 16 + block_n = triton.next_power_of_2(n) + block_n = min(max(block_n, 16), 32) + block_k = 64 + grid = (triton.cdiv(m, block_m), triton.cdiv(n, block_n), num_split) + _tf32_hc_prenorm_gemm_kernel[grid]( + x, + fn, + out, + sqrsum, + m, + k, + n, + x.stride(0), + x.stride(1), + fn.stride(0), + fn.stride(1), + out.stride(0), + out.stride(1), + out.stride(2), + sqrsum.stride(0), + sqrsum.stride(1), + num_split, + BLOCK_M=block_m, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=4, + ) diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 3487ac1766e6..416d871e24f4 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -769,6 +769,7 @@ def apply( sort_indices2=self.w2_g_idx_sort_indices, is_k_full=self.is_k_full, input_dtype=self.input_dtype, + clamp_limit=self.gemm1_clamp_limit, ) return diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 456f40bbf7a3..79edfa6f2d92 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -786,6 +786,19 @@ def update_expert_map(self): dp_size=get_dp_group().world_size, ) + @staticmethod + def _normalize_loaded_weight_for_copy( + expert_data: torch.Tensor, loaded_weight: torch.Tensor + ) -> torch.Tensor: + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if ( + e8m0_dtype is not None + and expert_data.dtype 
== torch.uint8 + and loaded_weight.dtype == e8m0_dtype + ): + return loaded_weight.view(torch.uint8) + return loaded_weight + def _load_per_tensor_weight_scale( self, shard_id: str, @@ -799,10 +812,12 @@ def _load_per_tensor_weight_scale( # We have to keep the weight scales of w1 and w3 because # we need to re-quantize w1/w3 weights after weight loading. idx = 0 if shard_id == "w1" else 1 - param_data[expert_id][idx] = loaded_weight + target = param_data[expert_id][idx] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) # If we are in the row parallel case (down_proj) elif shard_id == "w2": - param_data[expert_id] = loaded_weight + target = param_data[expert_id] + target.copy_(self._normalize_loaded_weight_for_copy(target, loaded_weight)) def _load_combined_w13_weight_scale( self, @@ -819,7 +834,7 @@ def _load_combined_w13_weight_scale( loaded_weight = loaded_weight.narrow( shard_dim, shard_size * tp_rank, shard_size ) - param.copy_(loaded_weight) + param.copy_(self._normalize_loaded_weight_for_copy(param, loaded_weight)) def _load_model_weight_or_group_weight_scale( self, @@ -986,7 +1001,9 @@ def _load_w13( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_w2( self, @@ -1022,7 +1039,9 @@ def _load_w2( hidden_dim=hidden_dim, shard_dim=shard_dim, ) - expert_data.copy_(loaded_weight) + expert_data.copy_( + self._normalize_loaded_weight_for_copy(expert_data, loaded_weight) + ) def _load_single_value( self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index d9aab35c25f4..e2df12488652 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -817,6 +817,35 @@ def get_w8a8_block_fp8_configs( return None +def _get_default_w8a8_block_fp8_config( + M: int, + block_n: int, + block_k: int, +) -> dict[str, Any]: + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_n and + # BLOCK_SIZE_K must be divisible by block_k. + # M-aware tuning for low-M decode: BLOCK_SIZE_M=64 wastes most of the + # M-dim for single-request decode and short MTP-style draft batches. SM12x + # keeps benefiting from the low-M tile through M=32 on DeepSeek V4 shapes. 
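+ # Worked example (illustrative, not from the tuned configs): single-request
+ # decode on SM120 has M <= 32, so the low-M branch below selects
+ # BLOCK_SIZE_M=16 with 3 stages on CUDA (2 on ROCm); a 256-token prefill
+ # falls through to the 64-row tile.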
+ capability = current_platform.get_device_capability() + capability_major = getattr(capability, "major", None) + if capability_major is None and capability is not None: + capability_major = capability[0] + low_m_limit = 32 if capability_major == 12 else 8 + if low_m_limit >= M: + block_m, num_stages = 16, (2 if current_platform.is_rocm() else 3) + else: + block_m, num_stages = 64, 2 + return { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": num_stages, + } + + def w8a8_triton_block_scaled_mm( A: torch.Tensor, B: torch.Tensor, @@ -861,6 +890,12 @@ def w8a8_triton_block_scaled_mm( N, K = B.shape assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if e8m0_dtype is not None: + if As.dtype == e8m0_dtype: + As = _upcast_e8m0_to_fp32(As) + if Bs.dtype == e8m0_dtype: + Bs = _upcast_e8m0_to_fp32(Bs) C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) @@ -870,17 +905,7 @@ def w8a8_triton_block_scaled_mm( # Get the optimal config if there is one config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: - # Default config - # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] - # BLOCK_SIZE_K must be divisible by block_size[1] - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": block_size[0], - "BLOCK_SIZE_K": block_size[1], - "GROUP_SIZE_M": 32, - "num_warps": 4, - "num_stages": 2, - } + config = _get_default_w8a8_block_fp8_config(M, block_size[0], block_size[1]) def grid(META): return ( @@ -1215,6 +1240,8 @@ def create_fp8_scale_parameter( if dtype == torch.float32: scale[:] = torch.finfo(torch.float32).min + elif dtype == getattr(torch, "float8_e8m0fnu", None): + scale[:] = 0 set_weight_attrs(scale, {"scale_type": "weight_scale"}) return scale diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index 4bf52a49c43f..f974c873fdc1 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -4,7 +4,6 @@ import torch -import vllm.envs as envs from vllm._aiter_ops import rocm_aiter_ops from vllm.forward_context import get_forward_context from vllm.logger import init_logger @@ -12,7 +11,9 @@ from vllm.platforms import current_platform from vllm.utils.deep_gemm import ( fp8_fp4_mqa_logits, + fp8_fp4_mqa_topk_indices, fp8_fp4_paged_mqa_logits, + fp8_fp4_paged_mqa_topk_indices, has_deep_gemm, ) from vllm.utils.torch_utils import ( @@ -23,6 +24,7 @@ ) from vllm.v1.attention.backends.mla.indexer import ( DeepseekV32IndexerMetadata, + sparse_indexer_max_logits_bytes, ) from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton from vllm.v1.worker.workspace import current_workspace_manager @@ -35,11 +37,60 @@ logger = init_logger(__name__) RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024 +SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH = 4096 +SM120_SHORT_ROW_TOPK_MAX_WIDTH = 12288 # MXFP4 layout: 2 values packed per byte, ue8m0 (1-byte) scale per block of 32. 
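# Illustrative byte math (assumes the packed values and scale sit contiguously,
# which the layout comment above implies but does not spell out): one block is
# 32 // 2 = 16 value bytes plus 1 ue8m0 scale byte, i.e. 17 bytes per 32 values.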
MXFP4_BLOCK_SIZE = 32
+def _should_use_sm120_short_row_topk_decode(
+ topk_tokens: int,
+ logits_width: int,
+ num_rows: int,
+ is_cuda_sm120: bool,
+) -> bool:
+ if not is_cuda_sm120 or topk_tokens != 512:
+ return False
+ if logits_width <= SM120_SHORT_ROW_TOPK_ALWAYS_WIDTH:
+ return True
+ return logits_width < SM120_SHORT_ROW_TOPK_MAX_WIDTH
+
+
+def _use_sm120_short_row_topk_decode(
+ logits: torch.Tensor,
+ topk_tokens: int,
+) -> bool:
+ return _should_use_sm120_short_row_topk_decode(
+ topk_tokens,
+ logits.shape[1],
+ logits.shape[0],
+ current_platform.is_cuda()
+ and current_platform.is_device_capability_family(120),
+ )
+
+
+def _decode_logits_width(max_model_len: int, max_seq_len: int) -> int:
+ if max_model_len <= 0:
+ return 0
+ if max_seq_len <= 0:
+ return max_model_len
+ return min(max_model_len, max_seq_len)
+
+
+def _decode_topk_logits_width(
+ max_model_len: int, max_seq_len: int, topk_tokens: int
+) -> int:
+ logits_width = _decode_logits_width(max_model_len, max_seq_len)
+ return min(max_model_len, max(logits_width, topk_tokens))
+
+
+def _sparse_indexer_requires_deep_gemm() -> bool:
+ return current_platform.is_cuda() and not (
+ current_platform.is_device_capability_family(120)
+ )
+
+
 def _gather_workspace_shapes(
 total_seq_lens: int,
 head_dim: int,
@@ -118,7 +169,7 @@ def sparse_attn_indexer(
 # Dummy allocation to simulate peak logits-tensor memory during inference.
 # FP8 elements, so elements == bytes.
- max_logits_elems = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
+ max_logits_elems = sparse_indexer_max_logits_bytes()
 _ = torch.empty(
 max_logits_elems, dtype=torch.uint8, device=hidden_states.device
 )
@@ -220,6 +271,19 @@ def sparse_attn_indexer(
 q_slice_cast = q_slice
 k_quant_cast = k_quant
 k_scale_cast = k_scale.view(torch.float32).squeeze(-1)
+ topk_indices = topk_indices_buffer[
+ chunk.token_start : chunk.token_end, :topk_tokens
+ ]
+ if fp8_fp4_mqa_topk_indices(
+ (q_slice_cast, q_scale_slice),
+ (k_quant_cast, k_scale_cast),
+ weights[chunk.token_start : chunk.token_end],
+ chunk.cu_seqlen_ks,
+ chunk.cu_seqlen_ke,
+ topk_indices,
+ ):
+ continue
+
 logits = fp8_fp4_mqa_logits(
 (q_slice_cast, q_scale_slice),
 (k_quant_cast, k_scale_cast),
@@ -230,10 +294,6 @@ def sparse_attn_indexer(
 )
 num_rows = logits.shape[0]
- topk_indices = topk_indices_buffer[
- chunk.token_start : chunk.token_end, :topk_tokens
- ]
-
 if current_platform.is_xpu():
 xpu_ops.top_k_per_row_prefill( # type: ignore[attr-defined]
 logits,
@@ -307,35 +367,38 @@ def sparse_attn_indexer(
 if use_fp4_cache
 else padded_q_quant_decode_tokens
 )
- logits = fp8_fp4_paged_mqa_logits(
- (padded_q_quant_cast, padded_q_scale),
- kv_cache,
- weights[:num_padded_tokens],
- seq_lens,
- decode_metadata.block_table,
- decode_metadata.schedule_metadata,
- max_model_len=max_model_len,
- clean_logits=False,
- )
- num_rows = logits.shape[0]
 topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
-
- if current_platform.is_cuda() and topk_tokens in (512, 1024, 2048):
- workspace_manager = current_workspace_manager()
- (topk_workspace,) = workspace_manager.get_simultaneous(
- ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
- )
- torch.ops._C.persistent_topk(
- logits,
+ logits_width = _decode_topk_logits_width(
+ max_model_len, attn_metadata_narrowed.max_seq_len, topk_tokens
+ )
+ logits_bytes = num_padded_tokens * logits_width * torch.float32.itemsize
+ used_direct_topk = False
+ if logits_bytes > sparse_indexer_max_logits_bytes():
+ used_direct_topk = fp8_fp4_paged_mqa_topk_indices(
+
(padded_q_quant_cast, padded_q_scale), + kv_cache, + weights[:num_padded_tokens], seq_lens, + decode_metadata.block_table, + logits_width, topk_indices, - topk_workspace, - topk_tokens, - attn_metadata_narrowed.max_seq_len, ) - else: - if current_platform.is_xpu(): - xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined] + + if not used_direct_topk: + logits = fp8_fp4_paged_mqa_logits( + (padded_q_quant_cast, padded_q_scale), + kv_cache, + weights[:num_padded_tokens], + seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=logits_width, + clean_logits=False, + ) + num_rows = logits.shape[0] + + if _use_sm120_short_row_topk_decode(logits, topk_tokens): + torch.ops._C.top_k_per_row_decode( logits, next_n, seq_lens, @@ -345,17 +408,42 @@ def sparse_attn_indexer( logits.stride(1), topk_tokens, ) - else: - torch.ops._C.top_k_per_row_decode( + elif current_platform.is_cuda() and topk_tokens in (512, 2048): + workspace_manager = current_workspace_manager() + (topk_workspace,) = workspace_manager.get_simultaneous( + ((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8), + ) + torch.ops._C.persistent_topk( logits, - next_n, seq_lens, topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), + topk_workspace, topk_tokens, + logits_width, ) + else: + if current_platform.is_xpu(): + xpu_ops.top_k_per_row_decode( # type: ignore[attr-defined] + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) + else: + torch.ops._C.top_k_per_row_decode( + logits, + next_n, + seq_lens, + topk_indices, + num_rows, + logits.stride(0), + logits.stride(1), + topk_tokens, + ) if decode_metadata.requires_padding: # if padded, we need to unpack @@ -438,7 +526,7 @@ def __init__( self.topk_indices_buffer = topk_indices_buffer self.skip_k_cache_insert = skip_k_cache_insert self.use_fp4_cache = use_fp4_cache - if current_platform.is_cuda() and not has_deep_gemm(): + if _sparse_indexer_requires_deep_gemm() and not has_deep_gemm(): raise RuntimeError( "Sparse Attention Indexer CUDA op requires DeepGEMM to be installed." ) diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index cef4038dc2e6..1c3adb3ac4b0 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -8,10 +8,12 @@ import torch import torch.nn as nn +from vllm import envs from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import ( get_ep_group, + get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) @@ -55,10 +57,14 @@ from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op +from .interfaces import SupportsPP from .utils import ( AutoWeightsLoader, + PPMissingLayer, WeightsMapper, extract_layer_index, + is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, ) @@ -703,6 +709,25 @@ def _deepseek_v4_mega_moe_experts_op_fake( ) +def _use_deepseek_v4_mega_moe(vllm_config: VllmConfig) -> bool: + use_mega_moe = ( + vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" + ) + + env_name = "VLLM_DEEPSEEK_V4_USE_MEGA_MOE" + if envs.is_set(env_name): + use_mega_moe = envs.VLLM_DEEPSEEK_V4_USE_MEGA_MOE + + if use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: + raise NotImplementedError( + "DeepSeek V4 MegaMoE currently requires expert parallel. 
" + "Enable it with --enable-expert-parallel, or pick a different " + "moe backend." + ) + + return use_mega_moe + + class DeepseekV4MoE(nn.Module): def __init__( self, @@ -715,15 +740,7 @@ def __init__( config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.prefix = prefix - self.use_mega_moe = ( - vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" - ) - if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: - raise NotImplementedError( - "DeepSeek V4 MegaMoE currently requires expert parallel. " - "Enable it with --enable-expert-parallel, or pick a different " - "moe backend." - ) + self.use_mega_moe = _use_deepseek_v4_mega_moe(vllm_config) self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) self.hidden_size = config.hidden_size @@ -1226,15 +1243,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config - self.use_mega_moe = ( - vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" - ) - if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel: - raise NotImplementedError( - "DeepSeek V4 MegaMoE currently requires expert parallel. " - "Enable it with --enable-expert-parallel, or pick a different " - "moe backend." - ) + self.use_mega_moe = _use_deepseek_v4_mega_moe(vllm_config) self.vocab_size = config.vocab_size self.hc_eps = config.hc_eps self.hc_mult = config.hc_mult @@ -1261,12 +1270,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): device=self.device, ) - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens", - ) + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", + ) + else: + self.embed_tokens = PPMissingLayer() self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, @@ -1279,7 +1291,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.layers", ) - self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( + ["hidden_states"], self.hc_dim + ) self.hc_head_fn = nn.Parameter( torch.empty( @@ -1304,26 +1323,37 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Pre-hc_head residual stream buffer for the MTP draft. Stable # address (outside the cudagraph pool) so the copy_ in forward() # refreshes it correctly across captured shapes. 
- self._mtp_hidden_buffer = torch.empty( - vllm_config.scheduler_config.max_num_batched_tokens, - self.hc_dim, - dtype=vllm_config.model_config.dtype, - device=self.device, - ) + if get_pp_group().is_last_rank: + self._mtp_hidden_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + self.hc_dim, + dtype=vllm_config.model_config.dtype, + device=self.device, + ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( self, - input_ids: torch.Tensor, + input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: - hidden_states = self.embed_input_ids(input_ids) - hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1) - if self.use_mega_moe: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_input_ids(input_ids) + hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"].view( + -1, self.hc_mult, self.config.hidden_size + ) + + if self.use_mega_moe and input_ids is not None: input_ids = input_ids.to(torch.int64) for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer( @@ -1332,6 +1362,9 @@ def forward( input_ids, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states.flatten(1)}) + # Stash pre-hc_head residual for the MTP draft (captured copy_). num_tokens = hidden_states.shape[0] self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1)) @@ -1379,6 +1412,8 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + break param = params_dict[name] weight_loader = param.weight_loader @@ -1396,11 +1431,15 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: and loaded_weight.dtype == torch.float8_e8m0fnu ): loaded_weight = loaded_weight.view(torch.uint8) + skip_expert_weight = False for mapping in expert_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name_mapped = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name_mapped, self): + skip_expert_weight = True + break param = params_dict[name_mapped] # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other @@ -1419,15 +1458,21 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if success: name = name_mapped break + if skip_expert_weight: + continue loaded_params.add(name_mapped) continue elif "attn_sink" in name: + if is_pp_missing_parameter(name, self): + continue narrow_weight = loaded_weight[head_rank_start:head_rank_end] n = narrow_weight.shape[0] params_dict[name][:n].copy_(narrow_weight) loaded_params.add(name) continue else: + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader @@ -1525,7 +1570,7 @@ def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: ) -class DeepseekV4ForCausalLM(nn.Module): +class DeepseekV4ForCausalLM(nn.Module, SupportsPP): model_cls = DeepseekV4Model 
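# SupportsPP marks the model as pipeline-parallel capable; the wiring below
# (make_empty_intermediate_tensors, the IntermediateTensors hand-off in
# forward, and the PPMissingLayer placeholders) is what that interface
# relies on.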
# Default mapper assumes the original FP4-expert checkpoint layout. @@ -1544,12 +1589,18 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.model = self.model_cls( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - self.lm_head = ParallelLMHead( - config.vocab_size, - config.hidden_size, - prefix=maybe_prefix(prefix, "lm_head"), - ) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + prefix=maybe_prefix(prefix, "lm_head"), + ) + else: + self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) @@ -1563,7 +1614,7 @@ def compute_logits( def forward( self, - input_ids: torch.Tensor, + input_ids: torch.Tensor | None, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py index cd51f106503a..e5962c7c0664 100644 --- a/vllm/reasoning/__init__.py +++ b/vllm/reasoning/__init__.py @@ -30,7 +30,7 @@ ), "deepseek_v4": ( "deepseek_v3_reasoning_parser", - "DeepSeekV3ReasoningParser", + "DeepSeekV3ReasoningWithThinkingParser", ), "poolside_v1": ( "poolside_v1_reasoning_parser", diff --git a/vllm/tokenizers/deepseek_v4_encoding.py b/vllm/tokenizers/deepseek_v4_encoding.py index 6895771e2f59..74f01ce017e4 100644 --- a/vllm/tokenizers/deepseek_v4_encoding.py +++ b/vllm/tokenizers/deepseek_v4_encoding.py @@ -155,10 +155,15 @@ def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str: p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}' P_dsml_strs = [] - if isinstance(tool_call["arguments"], str): - arguments = json.loads(tool_call["arguments"]) + raw_arguments = tool_call.get("arguments") + if raw_arguments is None or raw_arguments == "": + arguments = {} + elif isinstance(raw_arguments, str): + arguments = json.loads(raw_arguments) + if arguments is None: + arguments = {} else: - arguments = tool_call["arguments"] + arguments = raw_arguments for k, v in arguments.items(): p_dsml_str = p_dsml_template.format( diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 6b89f5c33203..7df618e58e1a 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -338,6 +338,253 @@ def transform_sf_into_required_layout(*args, **kwargs): ) +_SM120_MQA_LOGITS_MAX_SCORE_BYTES = 64 * 1024 * 1024 +_SM120_PAGED_MQA_TOPK_CHUNK_SIZE = 8192 + + +def _fp8_mqa_logits_head_chunk_size( + seq_len: int, + seq_len_kv: int, + num_heads: int, +) -> int: + # The SM120 torch path is used on long prefill paths where materializing + # [head_chunk, M, N] scores can otherwise allocate multiple GiB. Keep the + # transient score tensor bounded, while still using larger head chunks for + # short prompts where they are faster. 
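+ # Example bound: a 4096 x 8192 prefill needs 4096 * 8192 * 4 B = 128 MiB of
+ # fp32 scores per head, over the 64 MiB cap, so the chunk collapses to a
+ # single head; a 256 x 4096 prompt (4 MiB per head) keeps the full 8-head
+ # chunk.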
+ score_elems_per_head = max(1, seq_len * seq_len_kv) + max_heads = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_head * 4) + return max(1, min(8, num_heads, max_heads)) + + +def _fp8_mqa_logits_k_chunk_size( + seq_len: int, + seq_len_kv: int, + head_chunk_size: int, +) -> int: + score_elems_per_key = max(1, seq_len * head_chunk_size) + max_keys = _SM120_MQA_LOGITS_MAX_SCORE_BYTES // (score_elems_per_key * 4) + return max(1, min(seq_len_kv, max_keys)) + + +def _fp8_mqa_logits_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA logits torch path only supports FP8 Q") + + k_values, k_scales = kv + k_f32 = k_values.to(torch.float32) + k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32)) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_values.shape + seq_len_kv = k_f32.shape[0] + logits = torch.zeros( + (seq_len, seq_len_kv), device=q_values.device, dtype=torch.float32 + ) + head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads) + + for head_start in range(0, num_heads, head_chunk_size): + head_end = min(head_start + head_chunk_size, num_heads) + q_chunk = q_values[:, head_start:head_end, :].to(torch.float32) + q_chunk = q_chunk.transpose(0, 1).contiguous() + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + k_chunk_size = _fp8_mqa_logits_k_chunk_size( + seq_len, seq_len_kv, head_end - head_start + ) + for k_start in range(0, seq_len_kv, k_chunk_size): + k_end = min(k_start + k_chunk_size, seq_len_kv) + scores = torch.matmul(q_chunk, k_t[:, k_start:k_end]) + scores.relu_() + scores.mul_(head_weights) + logits[:, k_start:k_end].add_( + scores[0] if scores.shape[0] == 1 else scores.sum(dim=0) + ) + + if clean_logits: + offsets = torch.arange(seq_len_kv, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + logits = logits.masked_fill(~valid, float("-inf")) + + return logits + + +def _fp8_mqa_logits_topk_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_tokens: int, + out: torch.Tensor | None = None, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 MQA top-k torch path only supports FP8 Q") + + k_values, k_scales = kv + k_f32 = k_values.to(torch.float32) + k_f32.mul_(k_scales.reshape(-1, 1).to(torch.float32)) + k_t = k_f32.transpose(0, 1).contiguous() + + seq_len, num_heads, _ = q_values.shape + seq_len_kv = k_f32.shape[0] + if out is None: + out = torch.empty( + (seq_len, topk_tokens), device=q_values.device, dtype=torch.int32 + ) + else: + assert out.shape == (seq_len, topk_tokens) + assert out.dtype == torch.int32 + out.fill_(-1) + + best_values = torch.full( + (seq_len, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + head_chunk_size = _fp8_mqa_logits_head_chunk_size(seq_len, seq_len_kv, num_heads) + k_chunk_size = _fp8_mqa_logits_k_chunk_size(seq_len, seq_len_kv, head_chunk_size) + max_chunk_topk = min(topk_tokens, k_chunk_size) + chunk_values_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + 
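+ # Merge invariant (sketch): `out`/`best_values` always hold the running
+ # top-k. Each KV chunk contributes at most topk_tokens candidates, and a
+ # second top-k over the <= 2 * topk_tokens merged candidates restores the
+ # invariant, so peak memory follows topk_tokens and the capped chunk width
+ # rather than the full seq_len_kv.
+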
chunk_indices_buf = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (seq_len, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (seq_len, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (seq_len, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + for k_start in range(0, seq_len_kv, k_chunk_size): + k_end = min(k_start + k_chunk_size, seq_len_kv) + chunk_logits = torch.zeros( + (seq_len, k_end - k_start), + device=q_values.device, + dtype=torch.float32, + ) + for head_start in range(0, num_heads, head_chunk_size): + head_end = min(head_start + head_chunk_size, num_heads) + q_chunk = q_values[:, head_start:head_end, :].to(torch.float32) + q_chunk = q_chunk.transpose(0, 1).contiguous() + head_weights = weights[:, head_start:head_end].transpose(0, 1).unsqueeze(-1) + scores = torch.matmul(q_chunk, k_t[:, k_start:k_end]) + scores.relu_() + scores.mul_(head_weights) + chunk_logits.add_(scores[0] if scores.shape[0] == 1 else scores.sum(dim=0)) + + offsets = torch.arange(k_start, k_end, device=q_values.device) + valid = (offsets[None, :] >= cu_seqlen_ks[:, None]) & ( + offsets[None, :] < cu_seqlen_ke[:, None] + ) + chunk_logits.masked_fill_(~valid, float("-inf")) + + chunk_topk = min(topk_tokens, k_end - k_start) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(k_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(out) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=out) + best_values, next_best_values = next_best_values, best_values + out.masked_fill_(~torch.isfinite(best_values), -1) + + return out + + +def fp8_fp4_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 MQA top-k indices without materializing full logits.""" + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q[1] is None + ): + return False + _fp8_mqa_logits_topk_torch( + q, + kv, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + topk_indices.shape[1], + out=topk_indices, + ) + return True + + +def _fp8_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + clean_logits: bool, +) -> torch.Tensor: + q_values, q_scale = q 
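+ # Dispatch note: the Triton kernel below only covers the cleaned-logits
+ # FP8 case with 3-D Q and 2-D K; any other combination takes the chunked
+ # torch fallback.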
+ if clean_logits and q_scale is None and q_values.dim() == 3 and kv[0].dim() == 2: + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + fp8_mqa_logits_triton, + ) + + return fp8_mqa_logits_triton(q_values, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + return _fp8_mqa_logits_torch( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) + + def fp8_fp4_mqa_logits( q: tuple[torch.Tensor, torch.Tensor | None], kv: tuple[torch.Tensor, torch.Tensor], @@ -370,6 +617,10 @@ def fp8_fp4_mqa_logits( Returns: Logits tensor of shape [M, N], dtype `torch.float32`. """ + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_mqa_logits_sm12x( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits + ) _lazy_init() if _fp8_fp4_mqa_logits_impl is None: return _missing() @@ -404,6 +655,215 @@ def get_paged_mqa_logits_metadata( return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms) +def _fp8_paged_mqa_logits_torch( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + q_values, q_scale = q + if q_scale is not None: + raise NotImplementedError("SM120 paged MQA torch path only supports FP8 Q") + + batch_size, next_n, num_heads, head_dim = q_values.shape + head_dim_with_scale = kv_cache.shape[-1] + assert head_dim_with_scale > head_dim + assert weights.shape == (batch_size * next_n, num_heads) + assert context_lens.shape == (batch_size, next_n) + + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + _view_packed_fp8_paged_mqa_kv_cache, + ) + + kv_values, kv_scales = _view_packed_fp8_paged_mqa_kv_cache(kv_cache, head_dim) + _, block_kv, _, _ = kv_values.shape + logits = torch.full( + (batch_size * next_n, max_model_len), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + + q_f32 = q_values.float() + score_bytes = _SM120_MQA_LOGITS_MAX_SCORE_BYTES + max_tokens_per_chunk = max(1, score_bytes // max(1, num_heads * 4)) + token_offsets_cache: dict[int, torch.Tensor] = {} + + for batch_idx in range(batch_size): + for next_idx in range(next_n): + row = batch_idx * next_n + next_idx + context_len = int(context_lens[batch_idx, next_idx].item()) + if context_len <= 0: + continue + + q_row = q_f32[batch_idx, next_idx] + row_weights = weights[row] + for token_start in range(0, context_len, max_tokens_per_chunk): + token_end = min(context_len, token_start + max_tokens_per_chunk) + chunk_len = token_end - token_start + token_offsets = token_offsets_cache.get(chunk_len) + if token_offsets is None or token_offsets.device != q_values.device: + token_offsets = torch.arange( + chunk_len, device=q_values.device, dtype=torch.long + ) + token_offsets_cache[chunk_len] = token_offsets + token_ids = token_start + token_offsets + logical_blocks = token_ids // block_kv + token_in_block = token_ids - logical_blocks * block_kv + physical_blocks = block_tables[batch_idx, logical_blocks] + kv_chunk = kv_values[physical_blocks, token_in_block, 0].float() + scale_chunk = kv_scales[physical_blocks, token_in_block, 0].squeeze(-1) + kv_chunk.mul_(scale_chunk[:, None]) + scores = torch.matmul(q_row, kv_chunk.T) + scores.relu_() + scores.mul_(row_weights[:, None]) + logits[row, token_start:token_end] = scores.sum(dim=0) + + return logits + + +def _fp8_paged_mqa_logits_sm12x( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: 
torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + q_values, q_scale = q + if ( + q_scale is None + and q_values.dim() == 4 + and kv_cache.dtype == torch.uint8 + and kv_cache.shape[-1] == q_values.shape[-1] + 4 + ): + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + fp8_paged_mqa_logits_triton, + ) + + return fp8_paged_mqa_logits_triton( + q_values, kv_cache, weights, context_lens, block_tables, max_model_len + ) + return _fp8_paged_mqa_logits_torch( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) + + +def fp8_fp4_paged_mqa_topk_indices( + q: tuple[torch.Tensor, torch.Tensor | None], + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, + topk_indices: torch.Tensor, +) -> bool: + """Write SM120 FP8 paged MQA top-k indices without full logits.""" + q_values, q_scale = q + if not ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + and q_scale is None + and q_values.dim() == 4 + and kv_cache.dtype == torch.uint8 + and kv_cache.shape[-1] == q_values.shape[-1] + 4 + ): + return False + + num_rows = q_values.shape[0] * q_values.shape[1] + topk_tokens = topk_indices.shape[1] + assert topk_indices.shape == (num_rows, topk_tokens) + assert topk_indices.dtype == torch.int32 + topk_indices.fill_(-1) + if num_rows == 0 or topk_tokens == 0 or max_model_len == 0: + return True + + best_values = torch.full( + (num_rows, topk_tokens), + float("-inf"), + device=q_values.device, + dtype=torch.float32, + ) + chunk_size = max(1, _SM120_PAGED_MQA_TOPK_CHUNK_SIZE) + max_chunk_topk = min(topk_tokens, chunk_size) + chunk_values_buf = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + chunk_indices_buf = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.int64, + ) + chunk_indices_i32 = torch.empty( + (num_rows, max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + candidate_values = torch.empty( + (num_rows, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.float32, + ) + candidate_indices = torch.empty( + (num_rows, topk_tokens + max_chunk_topk), + device=q_values.device, + dtype=torch.int32, + ) + next_best_values = torch.empty_like(best_values) + selected = torch.empty( + (num_rows, topk_tokens), + device=q_values.device, + dtype=torch.int64, + ) + + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + fp8_paged_mqa_logits_triton, + ) + + for token_start in range(0, max_model_len, chunk_size): + token_count = min(chunk_size, max_model_len - token_start) + chunk_logits = fp8_paged_mqa_logits_triton( + q_values, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + token_start=token_start, + token_count=token_count, + ) + chunk_topk = min(topk_tokens, token_count) + chunk_values = chunk_values_buf[:, :chunk_topk] + chunk_indices = chunk_indices_buf[:, :chunk_topk] + torch.topk(chunk_logits, chunk_topk, dim=1, out=(chunk_values, chunk_indices)) + chunk_indices_out = chunk_indices_i32[:, :chunk_topk] + chunk_indices_out.copy_(chunk_indices) + chunk_indices_out.add_(token_start) + + candidate_cols = topk_tokens + chunk_topk + candidate_values_view = candidate_values[:, :candidate_cols] + candidate_indices_view = candidate_indices[:, :candidate_cols] + candidate_values_view[:, :topk_tokens].copy_(best_values) + candidate_values_view[:, 
topk_tokens:candidate_cols].copy_(chunk_values) + candidate_indices_view[:, :topk_tokens].copy_(topk_indices) + candidate_indices_view[:, topk_tokens:candidate_cols].copy_(chunk_indices_out) + torch.topk( + candidate_values_view, + topk_tokens, + dim=1, + out=(next_best_values, selected), + ) + torch.gather(candidate_indices_view, 1, selected, out=topk_indices) + best_values, next_best_values = next_best_values, best_values + topk_indices.masked_fill_(~torch.isfinite(best_values), -1) + + return True + + def fp8_fp4_paged_mqa_logits( q: tuple[torch.Tensor, torch.Tensor | None], kv_cache: torch.Tensor, @@ -425,9 +885,10 @@ def fp8_fp4_paged_mqa_logits( [B, next_n, H, D] float8_e4m3fn and q_scale is None. FP4 path: q_values is packed uint8 and q_scale is the companion block-scale tensor. - kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, 1, - D+4], dtype `torch.uint8`, with the last 4 bytes per (block, pos) - storing the float dequant scale. + kv_cache: Paged KV-cache. FP8 layout is [num_blocks, block_size, D+4] + or [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. Within + each block, the D-byte FP8 values for every token are stored first, + followed by per-token fp32 scale bytes. weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. context_lens: Tensor of shape [B], dtype int32; effective context length for each batch element. @@ -442,6 +903,10 @@ def fp8_fp4_paged_mqa_logits( Logits tensor of shape [B * next_n, max_model_len], dtype `torch.float32`. """ + if current_platform.is_device_capability_family(120) and q[1] is None: + return _fp8_paged_mqa_logits_sm12x( + q, kv_cache, weights, context_lens, block_tables, max_model_len + ) _lazy_init() if _fp8_fp4_paged_mqa_logits_impl is None: return _missing() @@ -457,6 +922,52 @@ def fp8_fp4_paged_mqa_logits( ) +def _tf32_hc_prenorm_gemm_torch( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + """Portable SM12x HyperConnection prenorm GEMM fallback. + + DeepGEMM's split ABI only requires that downstream consumers recover the + full result by summing over the split dimension. Keep the implementation + simple by writing the full product to split zero and clearing the rest. 
+ """ + del num_split + product = x.float() @ fn.float().T + norm = x.float().square().sum(dim=-1) + + if out.dim() == 3: + out.zero_() + sqrsum.zero_() + out[0].copy_(product) + sqrsum[0].copy_(norm) + else: + out.copy_(product) + sqrsum.copy_(norm) + return out + + +def _tf32_hc_prenorm_gemm_sm12x( + x: torch.Tensor, + fn: torch.Tensor, + out: torch.Tensor, + sqrsum: torch.Tensor, + num_split: int, +) -> torch.Tensor: + if out.dim() == 3 and sqrsum.dim() == 2: + from vllm.model_executor.layers.deepseek_v4_triton_kernels import ( + tf32_hc_prenorm_gemm_triton, + ) + + tf32_hc_prenorm_gemm_triton(x, fn, out, sqrsum, num_split) + return out + + return _tf32_hc_prenorm_gemm_torch(x, fn, out, sqrsum, num_split) + + def tf32_hc_prenorm_gemm( x: torch.Tensor, fn: torch.Tensor, @@ -471,6 +982,8 @@ def tf32_hc_prenorm_gemm( See the caller function for shape requirement """ + if current_platform.is_device_capability_family(120): + return _tf32_hc_prenorm_gemm_sm12x(x, fn, out, sqrsum, num_split) _lazy_init() if _tf32_hc_prenorm_gemm_impl is None: return _missing() @@ -570,7 +1083,9 @@ def should_use_deepgemm_for_fp8_linear( "m_grouped_fp8_fp4_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", "fp8_fp4_mqa_logits", + "fp8_fp4_mqa_topk_indices", "fp8_fp4_paged_mqa_logits", + "fp8_fp4_paged_mqa_topk_indices", "get_paged_mqa_logits_metadata", "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 474a5b2d421e..b2399c683eb5 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -30,6 +30,10 @@ SparseMLAAttentionImpl, ) from vllm.v1.attention.backends.mla.compressor_utils import get_compressed_slot_mapping +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled_for_platform, + triton_sparse_mla_cudagraphs_allowed, +) from vllm.v1.attention.backends.mla.sparse_utils import ( triton_convert_req_index_to_global_index, ) @@ -266,6 +270,20 @@ def get_prefill_workspace_size(max_model_len: int): class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]): _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + @classmethod + def get_cudagraph_support( + cls, + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + if ( + getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" + and is_triton_sparse_mla_enabled_for_platform() + and not triton_sparse_mla_cudagraphs_allowed(vllm_config) + ): + return AttentionCGSupport.NEVER + return cls._cudagraph_support + def __init__( self, kv_cache_spec: AttentionSpec, diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 7c0715a9e8b6..3f43b65d26fb 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -1,10 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from dataclasses import dataclass import torch -import vllm.envs as envs from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform @@ -32,6 +32,28 @@ logger = init_logger(__name__) +def sparse_indexer_max_logits_bytes(is_sm12x: bool | None = None) -> int: + configured_mb = os.getenv("VLLM_SPARSE_INDEXER_MAX_LOGITS_MB") + if configured_mb is not None: + return int(configured_mb) * 
1024 * 1024 + + if is_sm12x is None: + is_sm12x = ( + current_platform.is_cuda() + and current_platform.is_device_capability_family(120) + ) + default_mb = 256 if is_sm12x else 512 + return default_mb * 1024 * 1024 + + +def _uses_deep_gemm_scheduler_metadata() -> bool: + return ( + current_platform.is_cuda() + and has_deep_gemm() + and not current_platform.is_device_capability_family(120) + ) + + @triton.jit def _prepare_uniform_decode_kernel( seq_lens_ptr, @@ -269,13 +291,12 @@ def __init__(self, *args, **kwargs): self.reorder_batch_threshold += self.num_speculative_tokens # NOTE(zyongye) fp4 indexer cache only natively supports next_n in # natively_supported_next_n_fp4; for other next_n values we fall back - # to the flattening path. Outside the SM100 datacenter family the FP8 - # paged MQA logits kernel has the same [1, 2] constraint (deepgemm - # smxx_fp8_fp4_paged_mqa_logits.hpp:233), so flatten there too. + # to the flattening path. When fp4 indexer cache is disabled, the + # native (non-flattening) path handles all next_n values. self.use_flattening = ( self.use_fp4_indexer_cache - or not current_platform.is_device_capability_family(100) - ) and next_n not in self.natively_supported_next_n_fp4 + and next_n not in self.natively_supported_next_n_fp4 + ) sm_count = num_compute_units(self.device.index) self.num_sms = sm_count @@ -520,7 +541,7 @@ def build( prefill_query_lens_cpu = torch.diff( query_start_loc_cpu[num_decodes : num_decodes + num_prefills + 1] ) - max_logits_bytes = envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024 + max_logits_bytes = sparse_indexer_max_logits_bytes() # Upper bound is exact for prefill rows (the `[num_decodes:]` # slice below). assert common_attn_metadata.seq_lens_cpu_upper_bound is not None @@ -610,8 +631,7 @@ def build( if seq_lens.dim() == 1: seq_lens = seq_lens.unsqueeze(-1) - # DeepGEMM is required for the paged MQA logits on CUDA devices - if current_platform.is_cuda() and has_deep_gemm(): + if _uses_deep_gemm_scheduler_metadata(): self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata( seq_lens, self.kv_cache_spec.storage_block_size, diff --git a/vllm/v1/attention/backends/mla/sparse_mla_env.py b/vllm/v1/attention/backends/mla/sparse_mla_env.py new file mode 100644 index 000000000000..52af38e4cea3 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_env.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Environment controls for the portable Triton sparse MLA path.""" + +import os + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +_TRITON_MLA_SPARSE_ENV = "VLLM_TRITON_MLA_SPARSE" +_TRITON_MLA_SPARSE_TOPK_CHUNK_ENV = "VLLM_TRITON_MLA_SPARSE_TOPK_CHUNK_SIZE" +_TRITON_MLA_SPARSE_QUERY_CHUNK_ENV = "VLLM_TRITON_MLA_SPARSE_QUERY_CHUNK_SIZE" +_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV = ( + "VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH" +) +_TRITON_MLA_SPARSE_HEAD_BLOCK_ENV = "VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE" +_TRITON_MLA_SPARSE_MATMUL_DECODE_ENV = "VLLM_TRITON_MLA_SPARSE_MATMUL_DECODE" + +_ENV_TRUE_VALUES = {"1", "true", "yes", "on"} +_ENV_FALSE_VALUES = {"0", "false", "no", "off"} + +logger = init_logger(__name__) + + +def _optional_env_flag(name: str) -> bool | None: + raw_value = os.getenv(name) + if raw_value is None: + return None + value = raw_value.lower() + if value in _ENV_TRUE_VALUES: + return True + if value in _ENV_FALSE_VALUES: + return False + return None + + +def _is_sm12x_device(device: 
torch.device) -> bool:
+    if not torch.cuda.is_available():
+        return False
+    index = device.index if device.index is not None else torch.cuda.current_device()
+    return torch.cuda.get_device_capability(index)[0] == 12
+
+
+def triton_sparse_mla_configured() -> bool | None:
+    return _optional_env_flag(_TRITON_MLA_SPARSE_ENV)
+
+
+def is_triton_sparse_mla_enabled_for_platform() -> bool:
+    configured = triton_sparse_mla_configured()
+    if configured is not None:
+        return configured
+    return current_platform.is_device_capability_family(120)
+
+
+def is_triton_sparse_mla_enabled(device: torch.device) -> bool:
+    configured = triton_sparse_mla_configured()
+    if configured is not None:
+        return configured
+    return _is_sm12x_device(device)
+
+
+def _uses_speculative_decoding(vllm_config) -> bool:
+    return bool(getattr(vllm_config, "speculative_config", None))
+
+
+def triton_sparse_mla_cudagraphs_allowed(vllm_config=None) -> bool:
+    configured = _optional_env_flag(_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV)
+    if configured is not None:
+        return configured
+    return not (
+        vllm_config is not None and _uses_speculative_decoding(vllm_config)
+    )
+
+
+def disable_triton_sparse_mla_cudagraphs_if_enabled(vllm_config) -> None:
+    if not is_triton_sparse_mla_enabled_for_platform():
+        return
+    if triton_sparse_mla_cudagraphs_allowed(vllm_config):
+        logger.warning_once(
+            "Keeping vLLM compile and CUDA graphs enabled for the DeepSeek V4 "
+            "Triton sparse MLA path because "
+            f"{_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH_ENV} is truthy or "
+            "speculative decoding is not configured. This is an "
+            "experimental performance mode."
+        )
+        return
+
+    from vllm.config.compilation import CompilationMode, CUDAGraphMode
+
+    compilation_config = vllm_config.compilation_config
+    if (
+        compilation_config.mode == CompilationMode.NONE
+        and compilation_config.cudagraph_mode == CUDAGraphMode.NONE
+    ):
+        return
+
+    logger.warning_once(
+        "Disabling vLLM compile and CUDA graphs for the DeepSeek V4 Triton "
+        "sparse MLA path: either this path is not compile/graph-safe yet, or "
+        "speculative decoding would use multi-token sparse MLA decode."
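+        # Reached when VLLM_TRITON_MLA_SPARSE_ALLOW_CUDAGRAPH is explicitly
+        # falsy, or is unset while speculative decoding is configured; see
+        # triton_sparse_mla_cudagraphs_allowed() above.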
+ ) + compilation_config.mode = CompilationMode.NONE + compilation_config.compile_sizes = [] + compilation_config.compile_ranges_endpoints = [] + compilation_config.cudagraph_mode = CUDAGraphMode.NONE + compilation_config.cudagraph_capture_sizes = [] + compilation_config.max_cudagraph_capture_size = 0 + + +def triton_sparse_mla_topk_chunk_size() -> int: + raw_value = os.getenv(_TRITON_MLA_SPARSE_TOPK_CHUNK_ENV) + if raw_value is None: + return 512 + try: + return max(1, int(raw_value)) + except ValueError: + return 512 + + +def triton_sparse_mla_query_chunk_size() -> int: + raw_value = os.getenv(_TRITON_MLA_SPARSE_QUERY_CHUNK_ENV) + if raw_value is None: + return 256 + try: + return max(1, int(raw_value)) + except ValueError: + return 256 + + +def triton_sparse_mla_head_block_size() -> int | None: + raw_value = os.getenv(_TRITON_MLA_SPARSE_HEAD_BLOCK_ENV) + if raw_value is None: + return None + try: + value = int(raw_value) + except ValueError: + return None + if value in (1, 2, 4): + return value + return None + + +def triton_sparse_mla_matmul_decode_enabled() -> bool: + configured = _optional_env_flag(_TRITON_MLA_SPARSE_MATMUL_DECODE_ENV) + if configured is not None: + return configured + return current_platform.is_device_capability_family(120) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_kernels.py b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py new file mode 100644 index 000000000000..834ecda43032 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_kernels.py @@ -0,0 +1,2694 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Portable sparse MLA Triton kernels.""" + +import torch + +from vllm.triton_utils import tl, triton +from vllm.v1.attention.backends.mla.sparse_mla_env import ( + triton_sparse_mla_head_block_size, +) + + +def sparse_mla_decode_head_block_size(num_decode_tokens: int) -> int: + """Choose the SM12x sparse MLA head grouping for decode kernels. + + Single-token decode is latency sensitive and does best with one head per + program. Once there are enough query tokens, grouping heads lets the kernel + reuse each dequantized KV row across multiple heads. 
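+
+    With VLLM_TRITON_MLA_SPARSE_HEAD_BLOCK_SIZE unset, the heuristic below
+    maps 1-4 decode tokens to single-head programs, 5-15 tokens to pairs of
+    heads, and 16 or more tokens to groups of four heads.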
+ """ + + configured_head_block_size = triton_sparse_mla_head_block_size() + if configured_head_block_size is not None: + return configured_head_block_size + if num_decode_tokens <= 4: + return 1 + if num_decode_tokens < 16: + return 2 + return 4 + + +@triton.jit +def _merge_two_subsets_with_sink_kernel( + out0_ptr, + lse0_ptr, + out1_ptr, + lse1_ptr, + sink_ptr, + output_ptr, + stride_out0_t: tl.constexpr, + stride_out0_h: tl.constexpr, + stride_out0_d: tl.constexpr, + stride_lse0_t: tl.constexpr, + stride_lse0_h: tl.constexpr, + stride_out1_t: tl.constexpr, + stride_out1_h: tl.constexpr, + stride_out1_d: tl.constexpr, + stride_lse1_t: tl.constexpr, + stride_lse1_h: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offsets < head_dim + + lse0 = tl.load(lse0_ptr + token_idx * stride_lse0_t + head_idx * stride_lse0_h) + lse1 = tl.load(lse1_ptr + token_idx * stride_lse1_t + head_idx * stride_lse1_h) + sink = tl.load(sink_ptr + head_idx) + merge_max = tl.maximum(tl.maximum(lse0, lse1), sink) + + weight0 = tl.exp(lse0 - merge_max) + weight1 = tl.exp(lse1 - merge_max) + weight_sink = tl.exp(sink - merge_max) + denom = weight0 + weight1 + weight_sink + + out0 = tl.load( + out0_ptr + + token_idx * stride_out0_t + + head_idx * stride_out0_h + + offsets * stride_out0_d, + mask=mask, + other=0.0, + ).to(tl.float32) + out1 = tl.load( + out1_ptr + + token_idx * stride_out1_t + + head_idx * stride_out1_h + + offsets * stride_out1_d, + mask=mask, + other=0.0, + ).to(tl.float32) + merged = (out0 * weight0 + out1 * weight1) / denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + merged, + mask=mask, + ) + + +def merge_two_sparse_mla_subsets_with_sink( + subset0_output: torch.Tensor, + subset0_lse: torch.Tensor, + subset1_output: torch.Tensor, + subset1_lse: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset0_output.shape == subset1_output.shape + assert subset0_output.shape == output.shape + assert subset0_lse.shape == subset1_lse.shape + assert subset0_lse.shape == subset0_output.shape[:2] + assert attn_sink.shape[0] == subset0_output.shape[1] + assert subset0_output.is_cuda + assert subset1_output.is_cuda + assert output.is_cuda + + num_tokens, num_heads, head_dim = subset0_output.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_two_subsets_with_sink_kernel[grid]( + subset0_output, + subset0_lse, + subset1_output, + subset1_lse, + attn_sink, + output, + subset0_output.stride(0), + subset0_output.stride(1), + subset0_output.stride(2), + subset0_lse.stride(0), + subset0_lse.stride(1), + subset1_output.stride(0), + subset1_output.stride(1), + subset1_output.stride(2), + subset1_lse.stride(0), + subset1_lse.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _merge_single_subset_with_sink_kernel( + subset_output_ptr, + subset_lse_ptr, + sink_ptr, + output_ptr, + stride_subset_t: tl.constexpr, + stride_subset_h: tl.constexpr, + stride_subset_d: tl.constexpr, + 
stride_lse_t: tl.constexpr, + stride_lse_h: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offsets < head_dim + + subset_lse = tl.load( + subset_lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h + ) + sink = tl.load(sink_ptr + head_idx) + merge_max = tl.maximum(subset_lse, sink) + + subset_weight = tl.exp(subset_lse - merge_max) + sink_weight = tl.exp(sink - merge_max) + denom = subset_weight + sink_weight + subset_output = tl.load( + subset_output_ptr + + token_idx * stride_subset_t + + head_idx * stride_subset_h + + offsets * stride_subset_d, + mask=mask, + other=0.0, + ).to(tl.float32) + merged = subset_output * subset_weight / denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + merged, + mask=mask, + ) + + +def merge_sparse_mla_subset_with_sink( + subset_output: torch.Tensor, + subset_lse: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset_output.shape == output.shape + assert subset_lse.shape == subset_output.shape[:2] + assert attn_sink.shape[0] == subset_output.shape[1] + assert subset_output.is_cuda + assert subset_lse.is_cuda + assert attn_sink.is_cuda + assert output.is_cuda + + num_tokens, num_heads, head_dim = subset_output.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _merge_single_subset_with_sink_kernel[grid]( + subset_output, + subset_lse, + attn_sink, + output, + subset_output.stride(0), + subset_output.stride(1), + subset_output.stride(2), + subset_lse.stride(0), + subset_lse.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _build_combined_decode_valid_mask_kernel( + output_ptr, + slot_ids_ptr, + topk_lens_ptr, + swa_lens_ptr, + stride_output_t: tl.constexpr, + stride_output_c: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + num_compressed_candidates: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_C: tl.constexpr, +): + token_idx = tl.program_id(0) + offsets = tl.arange(0, BLOCK_C) + candidate_mask = offsets < num_candidates + + topk_lens = tl.load(topk_lens_ptr + token_idx) + swa_lens = tl.load(swa_lens_ptr + token_idx) + is_compressed = offsets < num_compressed_candidates + swa_offsets = offsets - num_compressed_candidates + slot_ids = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + offsets * stride_slot_c, + mask=is_compressed, + other=-1, + ) + valid_compressed = is_compressed & (offsets < topk_lens) & (slot_ids >= 0) + valid_swa = (~is_compressed) & (swa_offsets < swa_lens) + valid = valid_compressed | valid_swa + tl.store( + output_ptr + token_idx * stride_output_t + offsets * stride_output_c, + valid, + mask=candidate_mask, + ) + + +def build_combined_sparse_mla_decode_valid_mask( + output: torch.Tensor, + compressed_slot_ids: torch.Tensor, + topk_lens: torch.Tensor, + swa_lens: torch.Tensor, +) -> None: + """Build `[compressed, SWA]` validity mask for SM12x decode.""" + if compressed_slot_ids.dim() == 3: + assert compressed_slot_ids.shape[1] == 1 + 
compressed_slot_ids = compressed_slot_ids[:, 0, :] + + assert output.dim() == 2 + assert output.dtype == torch.bool + assert compressed_slot_ids.dim() == 2 + assert output.shape[0] == compressed_slot_ids.shape[0] + assert output.shape[0] == topk_lens.shape[0] + assert output.shape[0] == swa_lens.shape[0] + assert output.shape[1] >= compressed_slot_ids.shape[1] + assert output.is_cuda + assert compressed_slot_ids.is_cuda + assert topk_lens.is_cuda + assert swa_lens.is_cuda + + num_candidates = output.shape[1] + block_c = triton.next_power_of_2(num_candidates) + _build_combined_decode_valid_mask_kernel[(output.shape[0],)]( + output, + compressed_slot_ids, + topk_lens, + swa_lens, + output.stride(0), + output.stride(1), + compressed_slot_ids.stride(0), + compressed_slot_ids.stride(1), + compressed_slot_ids.shape[1], + num_candidates, + BLOCK_C=block_c, + num_warps=4, + ) + + +def matmul_sparse_mla_attention_with_sink( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + num_heads: int | None = None, + score_buffer: torch.Tensor | None = None, + head_block_size: int = 1, + value_block_size: int | None = None, + candidate_block_size: int | None = None, +) -> None: + """Compute sink-aware sparse MLA over materialized BF16 KV. + + This path intentionally dequantizes/gathers KV once, computes scores with + batched matrix multiplication, and finishes the sink-aware value reduction + in Triton. It is useful for the SM12x decode path where the direct Triton + kernel otherwise repeats fp8_ds_mla dequantization once per head group. + """ + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert valid_tokens.shape == kv.shape[:2] + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert q.is_cuda and kv.is_cuda and valid_tokens.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + + q_active = q[:, :active_heads] + num_tokens = q.shape[0] + num_candidates = kv.shape[1] + if score_buffer is None: + score_buffer = torch.empty( + (num_tokens, active_heads, num_candidates), + dtype=torch.float32, + device=q.device, + ) + assert score_buffer.shape == (num_tokens, active_heads, num_candidates) + assert score_buffer.device == q.device + assert score_buffer.dtype in (torch.float32, torch.bfloat16) + if score_buffer.dtype == torch.float32: + q_score = q_active.float() + kv_score = kv.float() + else: + q_score = q_active.to(score_buffer.dtype) + kv_score = kv.to(score_buffer.dtype) + torch.bmm(q_score, kv_score.transpose(1, 2), out=score_buffer) + score_buffer.mul_(scale) + finish_materialized_sparse_mla_scores_with_sink( + score_buffer, + kv, + valid_tokens, + attn_sink, + output, + num_heads=active_heads, + head_block_size=head_block_size, + value_block_size=value_block_size, + candidate_block_size=candidate_block_size, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: 
tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + running_max = tl.load(attn_sink_ptr + head_offsets, mask=head_mask, other=0.0).to( + tl.float32 + ) + running_denom = tl.full((HEAD_BLOCK,), 1.0, tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + for candidate_idx in range(0, num_candidates): + is_valid = tl.load( + valid_tokens_ptr + + token_idx * stride_valid_t + + candidate_idx * stride_valid_c + ) + if is_valid: + score = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_offsets * stride_scores_h + + candidate_idx * stride_scores_c, + mask=head_mask, + other=-float("inf"), + ).to(tl.float32) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + dim_offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + result = running_acc / running_denom[:, None] + tl.store( + output_ptr + + token_idx * stride_out_t + + head_offsets[:, None] * stride_out_h + + dim_offsets[None, :] * stride_out_d, + result, + mask=matrix_mask, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_candidate_block_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + dim_block_idx = tl.program_id(2) + candidate_offsets = tl.arange(0, BLOCK_K) + dim_offsets = dim_block_idx * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = dim_offsets < head_dim + + max_score = tl.load(attn_sink_ptr + head_idx).to(tl.float32) + for candidate_start in range(0, num_candidates, BLOCK_K): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < num_candidates + is_valid = tl.load( + valid_tokens_ptr + token_idx * stride_valid_t + candidates * stride_valid_c, + mask=candidate_mask, + other=0, + ).to(tl.int1) + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidates * stride_scores_c, + mask=candidate_mask & is_valid, + other=-float("inf"), + ).to(tl.float32) + max_score = tl.maximum(max_score, tl.max(scores, axis=0)) + + denom = tl.exp(tl.load(attn_sink_ptr + 
head_idx).to(tl.float32) - max_score) + acc = tl.zeros((BLOCK_D,), tl.float32) + for candidate_start in range(0, num_candidates, BLOCK_K): + candidates = candidate_start + candidate_offsets + candidate_mask = candidates < num_candidates + is_valid = tl.load( + valid_tokens_ptr + token_idx * stride_valid_t + candidates * stride_valid_c, + mask=candidate_mask, + other=0, + ).to(tl.int1) + scores = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidates * stride_scores_c, + mask=candidate_mask & is_valid, + other=-float("inf"), + ).to(tl.float32) + weights = tl.exp(scores - max_score) + denom += tl.sum(weights, axis=0) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidates[:, None] * stride_kv_c + + dim_offsets[None, :] * stride_kv_d, + mask=(candidate_mask & is_valid)[:, None] & dim_mask[None, :], + other=0.0, + ) + acc += tl.sum(kv.to(tl.float32) * weights[:, None], axis=0) + + tl.store( + output_ptr + + token_idx * stride_out_t + + head_idx * stride_out_h + + dim_offsets * stride_out_d, + acc / denom, + mask=dim_mask, + ) + + +@triton.jit +def _finish_materialized_scores_with_sink_value_block_kernel( + scores_ptr, + kv_ptr, + valid_tokens_ptr, + attn_sink_ptr, + output_ptr, + stride_scores_t: tl.constexpr, + stride_scores_h: tl.constexpr, + stride_scores_c: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_valid_t: tl.constexpr, + stride_valid_c: tl.constexpr, + stride_out_t: tl.constexpr, + stride_out_h: tl.constexpr, + stride_out_d: tl.constexpr, + head_dim: tl.constexpr, + num_candidates: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + dim_block_idx = tl.program_id(2) + dim_offsets = dim_block_idx * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = dim_offsets < head_dim + + running_max = tl.load(attn_sink_ptr + head_idx).to(tl.float32) + running_denom = tl.full((), 1.0, tl.float32) + running_acc = tl.zeros((BLOCK_D,), tl.float32) + + for candidate_idx in range(0, num_candidates): + is_valid = tl.load( + valid_tokens_ptr + + token_idx * stride_valid_t + + candidate_idx * stride_valid_c + ) + if is_valid: + score = tl.load( + scores_ptr + + token_idx * stride_scores_t + + head_idx * stride_scores_h + + candidate_idx * stride_scores_c + ).to(tl.float32) + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + dim_offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + result = running_acc / running_denom + tl.store( + output_ptr + + token_idx * stride_out_t + + head_idx * stride_out_h + + dim_offsets * stride_out_d, + result, + mask=dim_mask, + ) + + +def finish_materialized_sparse_mla_scores_with_sink( + scores: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, + num_heads: int | None = None, + head_block_size: int = 1, + value_block_size: int | None = None, + candidate_block_size: int | None = None, +) -> None: + assert scores.dim() == 3 + assert kv.dim() == 3 + assert valid_tokens.shape == kv.shape[:2] + assert scores.shape[0] == kv.shape[0] + assert scores.shape[2] == kv.shape[1] + assert output.shape[0] 
== kv.shape[0] + assert output.shape[2] == kv.shape[2] + assert scores.dtype in (torch.float32, torch.bfloat16) + assert head_block_size in (1, 2, 4) + if value_block_size is not None: + assert value_block_size in (64, 128, 256, 512) + if candidate_block_size is not None: + assert candidate_block_size in (16, 32, 64, 128) + assert scores.is_cuda and kv.is_cuda and valid_tokens.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= scores.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + + num_tokens, _, num_candidates = scores.shape + head_dim = kv.shape[2] + if candidate_block_size is not None: + block_d = value_block_size if value_block_size is not None else 128 + candidate_grid = (num_tokens, active_heads, triton.cdiv(head_dim, block_d)) + _finish_materialized_scores_with_sink_candidate_block_kernel[candidate_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + head_dim, + num_candidates, + BLOCK_K=candidate_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + return + + if value_block_size is not None and value_block_size < head_dim: + value_grid = ( + num_tokens, + active_heads, + triton.cdiv(head_dim, value_block_size), + ) + _finish_materialized_scores_with_sink_value_block_kernel[value_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + head_dim, + num_candidates, + BLOCK_D=value_block_size, + num_warps=4, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + return + + block_d = min(1024, triton.next_power_of_2(head_dim)) + head_grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _finish_materialized_scores_with_sink_kernel[head_grid]( + scores, + kv, + valid_tokens, + attn_sink, + output, + scores.stride(0), + scores.stride(1), + scores.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + valid_tokens.stride(0), + valid_tokens.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + active_heads, + head_dim, + num_candidates, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + if output.shape[1] > active_heads: + output[:, active_heads:].zero_() + + +@triton.jit +def _accumulate_gathered_attention_chunk_kernel( + q_ptr, + kv_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t: tl.constexpr, + stride_kv_c: tl.constexpr, + stride_kv_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HAS_SLOT_IDS: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, 
BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + for candidate_idx in range(0, num_candidates): + is_valid = (candidate_offset + candidate_idx) < valid_len + if HAS_SLOT_IDS: + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = is_valid & (slot_id >= 0) + + if is_valid: + kv = tl.load( + kv_ptr + + token_idx * stride_kv_t + + candidate_idx * stride_kv_c + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_gathered_sparse_mla_attention_chunk( + q: torch.Tensor, + kv: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + slot_ids: torch.Tensor | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv.dim() == 3, f"Expected kv shape [T, K, D], got {kv.shape}" + assert q.shape[0] == kv.shape[0] + assert q.shape[-1] == kv.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + if slot_ids is not None: + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + assert slot_ids.dim() == 2 + assert slot_ids.shape == kv.shape[:2] + assert slot_ids.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = kv.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_gathered_attention_chunk_kernel[grid]( + q, + kv, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + kv.stride(0), + kv.stride(1), + kv.stride(2), + slot_ids.stride(0) if slot_ids is not None else 0, + slot_ids.stride(1) if slot_ids is not None else 0, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HAS_SLOT_IDS=slot_ids is not None, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def 
_accumulate_indexed_attention_chunk_kernel( + q_ptr, + kv_flat_ptr, + indices_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_kv_t, + stride_kv_d: tl.constexpr, + stride_indices_t: tl.constexpr, + stride_indices_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + for candidate_idx in range(0, num_candidates): + kv_index = tl.load( + indices_ptr + + token_idx * stride_indices_t + + candidate_idx * stride_indices_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (kv_index >= 0) + + if is_valid: + kv = tl.load( + kv_flat_ptr + + kv_index.to(tl.int64) * stride_kv_t + + offsets * stride_kv_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_indexed_sparse_mla_attention_chunk( + q: torch.Tensor, + kv_flat: torch.Tensor, + indices: torch.Tensor, + lens: torch.Tensor, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert kv_flat.dim() == 2 + assert indices.dim() == 2 + assert indices.shape[0] == q.shape[0] + assert kv_flat.shape[-1] == q.shape[-1] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert q.is_cuda and kv_flat.is_cuda and indices.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = indices.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_indexed_attention_chunk_kernel[grid]( + q, + kv_flat, + indices, + lens, + max_score, + denom, + acc, + 
q.stride(0), + q.stride(1), + q.stride(2), + kv_flat.stride(0), + kv_flat.stride(1), + indices.stride(0), + indices.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + valid_len = tl.load(lens_ptr + token_idx) + + fp8_mask = offsets < fp8_dim + rope_mask = (offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, 
running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_fp8ds_global_slots_attention_chunk_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + slot_ids_ptr, + lens_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + 
token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + valid_len = tl.load(lens_ptr + token_idx) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = ((candidate_offset + candidate_idx) < valid_len) & (slot_id >= 0) + + if is_valid: + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_global_slots_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + lens: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int = 0, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert lens.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda and slot_ids.is_cuda and lens.is_cuda + assert 
max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + num_candidates = slot_ids.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_global_slots_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + slot_ids, + lens, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_paged_attention_chunk_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_idx = tl.program_id(1) + offsets = tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + q = tl.load( + q_ptr + token_idx * stride_q_t + head_idx * stride_q_h + offsets * stride_q_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + acc_offset = ( + token_idx * stride_acc_t + head_idx * stride_acc_h + offsets * stride_acc_d + ) + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + running_acc = tl.load(acc_ptr + acc_offset, mask=dim_mask, other=0.0).to(tl.float32) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = offsets < fp8_dim + rope_mask = (offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, 
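+                # other=127 decodes to a neutral scale for masked (rope)
+                # lanes, since the dequant factor below is exp2(e - 127).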
+ mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv, axis=0) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = running_acc * previous_weight + kv * candidate_weight + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offset, running_max) + tl.store(denom_ptr + state_offset, running_denom) + tl.store(acc_ptr + acc_offset, running_acc, mask=dim_mask) + + +def accumulate_fp8ds_paged_sparse_mla_attention_chunk( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + candidate_offset: int, + num_candidates: int, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, num_heads) + _accumulate_fp8ds_paged_attention_chunk_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _accumulate_fp8ds_paged_attention_chunk_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + max_score_ptr, + denom_ptr, + acc_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + 
head_dim: tl.constexpr, + num_candidates, + candidate_offset, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + + state_offsets = token_idx * stride_state_t + head_offsets * stride_state_h + acc_offsets = ( + token_idx * stride_acc_t + + head_offsets[:, None] * stride_acc_h + + dim_offsets[None, :] * stride_acc_d + ) + running_max = tl.load( + max_score_ptr + state_offsets, + mask=head_mask, + other=-float("inf"), + ) + running_denom = tl.load(denom_ptr + state_offsets, mask=head_mask, other=0.0) + running_acc = tl.load(acc_ptr + acc_offsets, mask=matrix_mask, other=0.0).to( + tl.float32 + ) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + tl.store(max_score_ptr + state_offsets, running_max, mask=head_mask) + tl.store(denom_ptr + state_offsets, running_denom, mask=head_mask) + tl.store(acc_ptr + acc_offsets, running_acc, mask=matrix_mask) + + +def accumulate_fp8ds_paged_sparse_mla_attention_chunk_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + scale: float, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: 
torch.Tensor, + candidate_offset: int, + num_candidates: int, + head_block_size: int = 2, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert max_score.shape[0] == q.shape[0] + assert max_score.shape[1] <= q.shape[1] + assert denom.shape == max_score.shape + assert acc.shape == (*max_score.shape, q.shape[-1]) + assert head_block_size in (1, 2, 4) + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert k_cache.dtype == torch.uint8 + assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + num_heads = max_score.shape[1] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(num_heads, head_block_size)) + _accumulate_fp8ds_paged_attention_chunk_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + max_score, + denom, + acc, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + num_heads, + head_dim, + num_candidates, + candidate_offset, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _fp8ds_paged_attention_with_sink_multihead_kernel( + q_ptr, + k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + sink_ptr, + output_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_block_table_t, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + candidate_offset: tl.constexpr, + num_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + + for 
candidate_idx in range(0, num_candidates): + gather_idx = candidate_offset + candidate_idx + is_valid = gather_idx < gather_len + if is_valid: + pos = start_pos + gather_idx + block_in_seq = pos // cache_block_size + pos_in_block = pos % cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = k_cache_ptr + physical_block.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = running_denom * subset_scale + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + final = running_acc * subset_scale[:, None] * inv_total[:, None] + + tl.store( + output_ptr + + token_idx * stride_output_t + + head_offsets[:, None] * stride_output_h + + dim_offsets[None, :] * stride_output_d, + final, + mask=matrix_mask, + ) + + +def fp8ds_paged_sparse_mla_attention_with_sink_multihead( + q: torch.Tensor, + k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + block_size: int, + candidate_offset: int, + num_candidates: int, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + head_block_size: int = 1, + num_heads: int | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert head_block_size in (1, 2, 4) + assert k_cache.dtype == torch.uint8 
+ assert q.is_cuda and k_cache.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _fp8ds_paged_attention_with_sink_multihead_kernel[grid]( + q, + k_cache, + seq_lens, + gather_lens, + block_table, + attn_sink, + output, + q.stride(0), + q.stride(1), + q.stride(2), + block_table.stride(0), + output.stride(0), + output.stride(1), + output.stride(2), + block_size, + token_data_size, + k_cache.stride(0), + token_fp8_dim, + token_scale_dim, + quant_block_size, + active_heads, + head_dim, + candidate_offset, + num_candidates, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _fp8ds_global_paged_attention_with_sink_multihead_kernel( + q_ptr, + compressed_k_cache_ptr, + slot_ids_ptr, + topk_lens_ptr, + swa_k_cache_ptr, + seq_lens_ptr, + gather_lens_ptr, + block_table_ptr, + sink_ptr, + output_ptr, + stride_q_t: tl.constexpr, + stride_q_h: tl.constexpr, + stride_q_d: tl.constexpr, + stride_slot_t: tl.constexpr, + stride_slot_c: tl.constexpr, + stride_block_table_t, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + compressed_cache_block_size: tl.constexpr, + compressed_block_stride: tl.constexpr, + swa_cache_block_size: tl.constexpr, + swa_block_stride: tl.constexpr, + token_data_size: tl.constexpr, + fp8_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + num_compressed_candidates: tl.constexpr, + num_swa_candidates: tl.constexpr, + scale: tl.constexpr, + HEAD_BLOCK: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + head_block_idx = tl.program_id(1) + head_offsets = head_block_idx * HEAD_BLOCK + tl.arange(0, HEAD_BLOCK) + dim_offsets = tl.arange(0, BLOCK_D) + head_mask = head_offsets < num_heads + dim_mask = dim_offsets < head_dim + matrix_mask = head_mask[:, None] & dim_mask[None, :] + + q = tl.load( + q_ptr + + token_idx * stride_q_t + + head_offsets[:, None] * stride_q_h + + dim_offsets[None, :] * stride_q_d, + mask=matrix_mask, + other=0.0, + ).to(tl.float32) + running_max = tl.full((HEAD_BLOCK,), -float("inf"), tl.float32) + running_denom = tl.zeros((HEAD_BLOCK,), tl.float32) + running_acc = tl.zeros((HEAD_BLOCK, BLOCK_D), tl.float32) + + fp8_mask = dim_offsets < fp8_dim + rope_mask = (dim_offsets >= fp8_dim) & dim_mask + rope_offsets = tl.maximum(dim_offsets - fp8_dim, 0) + topk_len = tl.load(topk_lens_ptr + token_idx) + + for candidate_idx in range(0, num_compressed_candidates): + slot_id = tl.load( + slot_ids_ptr + token_idx * stride_slot_t + candidate_idx * stride_slot_c + ) + is_valid = (candidate_idx < topk_len) & (slot_id >= 0) + if is_valid: + block_idx = slot_id // compressed_cache_block_size + pos_in_block = slot_id % compressed_cache_block_size + cache_block_ptr = ( + compressed_k_cache_ptr + + block_idx.to(tl.int64) * compressed_block_stride + ) + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size 
+ token_scale_ptr = ( + cache_block_ptr + + compressed_cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + seq_len = tl.load(seq_lens_ptr + token_idx) + gather_len = tl.load(gather_lens_ptr + token_idx) + start_pos = seq_len - gather_len + for candidate_idx in range(0, num_swa_candidates): + is_valid = candidate_idx < gather_len + if is_valid: + pos = start_pos + candidate_idx + block_in_seq = pos // swa_cache_block_size + pos_in_block = pos % swa_cache_block_size + physical_block = tl.load( + block_table_ptr + token_idx * stride_block_table_t + block_in_seq + ) + cache_block_ptr = ( + swa_k_cache_ptr + physical_block.to(tl.int64) * swa_block_stride + ) + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + swa_cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + + x_uint8 = tl.load(token_data_ptr + dim_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + scale_offsets = dim_offsets // quant_block + encoded_scale = tl.load( + token_scale_ptr + scale_offsets, + mask=fp8_mask, + other=127, + ) + dequant_scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * dequant_scale + rope_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + rope = tl.load(rope_ptr + rope_offsets, mask=rope_mask, other=0.0).to( + tl.float32 + ) + kv = tl.where(fp8_mask, x_dequant, rope) + kv = tl.where(dim_mask, kv, 0.0) + + score = tl.sum(q * kv[None, :], axis=1) * scale + next_max = tl.maximum(running_max, score) + previous_weight = tl.exp(running_max - next_max) + candidate_weight = tl.exp(score - next_max) + running_acc = ( + running_acc * previous_weight[:, None] + + kv[None, :] * candidate_weight[:, None] + ) + running_denom = running_denom * previous_weight + candidate_weight + running_max = next_max + + sink = tl.load(sink_ptr + head_offsets, mask=head_mask, other=-float("inf")) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max = tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) 
+ sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = running_denom * subset_scale + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + final = running_acc * subset_scale[:, None] * inv_total[:, None] + + tl.store( + output_ptr + + token_idx * stride_output_t + + head_offsets[:, None] * stride_output_h + + dim_offsets[None, :] * stride_output_d, + final, + mask=matrix_mask, + ) + + +def fp8ds_global_paged_sparse_mla_attention_with_sink_multihead( + q: torch.Tensor, + compressed_k_cache: torch.Tensor, + slot_ids: torch.Tensor, + topk_lens: torch.Tensor, + compressed_block_size: int, + swa_k_cache: torch.Tensor, + seq_lens: torch.Tensor, + gather_lens: torch.Tensor, + block_table: torch.Tensor, + swa_block_size: int, + num_compressed_candidates: int, + num_swa_candidates: int, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + head_block_size: int = 1, + num_heads: int | None = None, +) -> None: + if q.dim() == 4: + assert q.shape[1] == 1 + q = q[:, 0] + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0] + + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + assert q.shape[-1] == 512 + assert slot_ids.dim() == 2 + assert slot_ids.shape[0] == q.shape[0] + assert topk_lens.shape[0] == q.shape[0] + assert seq_lens.shape[0] == q.shape[0] + assert gather_lens.shape[0] == q.shape[0] + assert block_table.shape[0] == q.shape[0] + assert output.shape[0] == q.shape[0] + assert output.shape[2] == q.shape[-1] + assert head_block_size in (1, 2, 4) + assert compressed_k_cache.dtype == torch.uint8 + assert swa_k_cache.dtype == torch.uint8 + assert q.is_cuda and compressed_k_cache.is_cuda and swa_k_cache.is_cuda + assert slot_ids.is_cuda and topk_lens.is_cuda + assert seq_lens.is_cuda and gather_lens.is_cuda and block_table.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + token_fp8_dim = 448 + token_bf16_dim = 64 + token_scale_dim = 8 + quant_block_size = 64 + token_data_size = token_fp8_dim + token_bf16_dim * 2 + + num_tokens, _, head_dim = q.shape + active_heads = num_heads if num_heads is not None else output.shape[1] + assert active_heads <= q.shape[1] + assert active_heads <= output.shape[1] + assert active_heads <= attn_sink.shape[0] + block_d = min(1024, triton.next_power_of_2(head_dim)) + grid = (num_tokens, triton.cdiv(active_heads, head_block_size)) + _fp8ds_global_paged_attention_with_sink_multihead_kernel[grid]( + q, + compressed_k_cache, + slot_ids, + topk_lens, + swa_k_cache, + seq_lens, + gather_lens, + block_table, + attn_sink, + output, + q.stride(0), + q.stride(1), + q.stride(2), + slot_ids.stride(0), + slot_ids.stride(1), + block_table.stride(0), + output.stride(0), + output.stride(1), + output.stride(2), + compressed_block_size, + compressed_k_cache.stride(0), + swa_block_size, + swa_k_cache.stride(0), + token_data_size, + token_fp8_dim, + token_scale_dim, + quant_block_size, + active_heads, + head_dim, + num_compressed_candidates, + num_swa_candidates, + scale, + HEAD_BLOCK=head_block_size, + BLOCK_D=block_d, + num_warps=8, + ) + + +@triton.jit +def _finish_attention_state_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + output_ptr, + lse_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + stride_lse_t: tl.constexpr, + stride_lse_h: 
tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + is_valid = running_denom > 0.0 + inv_denom = tl.where(is_valid, 1.0 / running_denom, 0.0) + subset_lse = tl.where( + is_valid, + running_max + tl.log(running_denom), + -float("inf"), + ) + + acc = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + subset_output = acc * inv_denom + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + subset_output, + mask=dim_mask, + ) + if block_d == 0: + tl.store( + lse_ptr + token_idx * stride_lse_t + head_idx * stride_lse_h, + subset_lse, + ) + + +def finish_gathered_sparse_mla_attention( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + output: torch.Tensor, + lse: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape == acc.shape + assert lse.shape == max_score.shape + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert output.dtype == torch.float32 + assert lse.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert output.is_cuda and lse.is_cuda + + num_tokens, num_heads, head_dim = acc.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_kernel[grid]( + max_score, + denom, + acc, + output, + lse, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + lse.stride(0), + lse.stride(1), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +@triton.jit +def _finish_attention_state_with_sink_kernel( + max_score_ptr, + denom_ptr, + acc_ptr, + sink_ptr, + output_ptr, + stride_state_t: tl.constexpr, + stride_state_h: tl.constexpr, + stride_acc_t: tl.constexpr, + stride_acc_h: tl.constexpr, + stride_acc_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state_offset = token_idx * stride_state_t + head_idx * stride_state_h + running_max = tl.load(max_score_ptr + state_offset) + running_denom = tl.load(denom_ptr + state_offset) + sink = tl.load(sink_ptr + head_idx) + has_tokens = running_denom > 0.0 + has_sink = sink > -float("inf") + valid_max = tl.where(has_tokens, running_max, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(valid_max, valid_sink) + has_any = has_tokens | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_running_max 
= tl.where(has_tokens, running_max, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + subset_scale = tl.where(has_tokens, tl.exp(safe_running_max - safe_merge_max), 0.0) + subset_weight = running_denom * subset_scale + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = subset_weight + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc_values = tl.load( + acc_ptr + + token_idx * stride_acc_t + + head_idx * stride_acc_h + + offsets * stride_acc_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc_values = tl.where(has_tokens, acc_values, 0.0) + output = acc_values * subset_scale * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +@triton.jit +def _finish_two_attention_states_with_sink_kernel( + max0_ptr, + denom0_ptr, + acc0_ptr, + max1_ptr, + denom1_ptr, + acc1_ptr, + sink_ptr, + output_ptr, + stride_state0_t: tl.constexpr, + stride_state0_h: tl.constexpr, + stride_acc0_t: tl.constexpr, + stride_acc0_h: tl.constexpr, + stride_acc0_d: tl.constexpr, + stride_state1_t: tl.constexpr, + stride_state1_h: tl.constexpr, + stride_acc1_t: tl.constexpr, + stride_acc1_h: tl.constexpr, + stride_acc1_d: tl.constexpr, + stride_output_t: tl.constexpr, + stride_output_h: tl.constexpr, + stride_output_d: tl.constexpr, + num_heads: tl.constexpr, + head_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_head = tl.program_id(0) + block_d = tl.program_id(1) + token_idx = token_head // num_heads + head_idx = token_head - token_idx * num_heads + offsets = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + dim_mask = offsets < head_dim + + state0_offset = token_idx * stride_state0_t + head_idx * stride_state0_h + state1_offset = token_idx * stride_state1_t + head_idx * stride_state1_h + max0 = tl.load(max0_ptr + state0_offset) + denom0 = tl.load(denom0_ptr + state0_offset) + max1 = tl.load(max1_ptr + state1_offset) + denom1 = tl.load(denom1_ptr + state1_offset) + sink = tl.load(sink_ptr + head_idx) + + has0 = denom0 > 0.0 + has1 = denom1 > 0.0 + has_sink = sink > -float("inf") + valid_max0 = tl.where(has0, max0, -float("inf")) + valid_max1 = tl.where(has1, max1, -float("inf")) + valid_sink = tl.where(has_sink, sink, -float("inf")) + merge_max = tl.maximum(tl.maximum(valid_max0, valid_max1), valid_sink) + has_any = has0 | has1 | has_sink + safe_merge_max = tl.where(has_any, merge_max, 0.0) + safe_max0 = tl.where(has0, max0, safe_merge_max) + safe_max1 = tl.where(has1, max1, safe_merge_max) + safe_sink = tl.where(has_sink, sink, safe_merge_max) + scale0 = tl.where(has0, tl.exp(safe_max0 - safe_merge_max), 0.0) + scale1 = tl.where(has1, tl.exp(safe_max1 - safe_merge_max), 0.0) + sink_weight = tl.where(has_sink, tl.exp(safe_sink - safe_merge_max), 0.0) + total_weight = denom0 * scale0 + denom1 * scale1 + sink_weight + inv_total = tl.where(total_weight > 0.0, 1.0 / total_weight, 0.0) + + acc0 = tl.load( + acc0_ptr + + token_idx * stride_acc0_t + + head_idx * stride_acc0_h + + offsets * stride_acc0_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc1 = tl.load( + acc1_ptr + + token_idx * stride_acc1_t + + head_idx * stride_acc1_h + + offsets * stride_acc1_d, + mask=dim_mask, + other=0.0, + ).to(tl.float32) + acc0 = tl.where(has0, acc0, 0.0) + acc1 = tl.where(has1, acc1, 0.0) + output = (acc0 * scale0 + acc1 * scale1) * inv_total + tl.store( + output_ptr + + token_idx * stride_output_t + + head_idx * 
stride_output_h + + offsets * stride_output_d, + output, + mask=dim_mask, + ) + + +def finish_two_sparse_mla_attention_states_with_sink( + max_score0: torch.Tensor, + denom0: torch.Tensor, + acc0: torch.Tensor, + max_score1: torch.Tensor, + denom1: torch.Tensor, + acc1: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score0.shape == denom0.shape + assert max_score1.shape == denom1.shape + assert max_score0.shape == max_score1.shape + assert acc0.shape == acc1.shape + assert acc0.shape[:2] == max_score0.shape + assert output.shape[0] == acc0.shape[0] + assert output.shape[1] >= acc0.shape[1] + assert output.shape[2] == acc0.shape[2] + assert attn_sink.shape[0] >= acc0.shape[1] + assert max_score0.dtype == torch.float32 + assert denom0.dtype == torch.float32 + assert acc0.dtype == torch.float32 + assert max_score1.dtype == torch.float32 + assert denom1.dtype == torch.float32 + assert acc1.dtype == torch.float32 + assert max_score0.is_cuda and denom0.is_cuda and acc0.is_cuda + assert max_score1.is_cuda and denom1.is_cuda and acc1.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc0.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_two_attention_states_with_sink_kernel[grid]( + max_score0, + denom0, + acc0, + max_score1, + denom1, + acc1, + attn_sink, + output, + max_score0.stride(0), + max_score0.stride(1), + acc0.stride(0), + acc0.stride(1), + acc0.stride(2), + max_score1.stride(0), + max_score1.stride(1), + acc1.stride(0), + acc1.stride(1), + acc1.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) + + +def finish_sparse_mla_attention_with_sink( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert max_score.shape == denom.shape + assert acc.shape[:2] == max_score.shape + assert output.shape[0] == acc.shape[0] + assert output.shape[1] >= acc.shape[1] + assert output.shape[2] == acc.shape[2] + assert attn_sink.shape[0] >= acc.shape[1] + assert max_score.dtype == torch.float32 + assert denom.dtype == torch.float32 + assert acc.dtype == torch.float32 + assert max_score.is_cuda and denom.is_cuda and acc.is_cuda + assert attn_sink.is_cuda and output.is_cuda + + num_tokens, num_heads, head_dim = acc.shape + block_d = min(128, triton.next_power_of_2(head_dim)) + grid = (num_tokens * num_heads, triton.cdiv(head_dim, block_d)) + _finish_attention_state_with_sink_kernel[grid]( + max_score, + denom, + acc, + attn_sink, + output, + max_score.stride(0), + max_score.stride(1), + acc.stride(0), + acc.stride(1), + acc.stride(2), + output.stride(0), + output.stride(1), + output.stride(2), + num_heads, + head_dim, + BLOCK_D=block_d, + num_warps=4, + ) diff --git a/vllm/v1/attention/backends/mla/sparse_mla_reference.py b/vllm/v1/attention/backends/mla/sparse_mla_reference.py new file mode 100644 index 000000000000..203b64188202 --- /dev/null +++ b/vllm/v1/attention/backends/mla/sparse_mla_reference.py @@ -0,0 +1,242 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Reference sparse MLA attention helpers. + +The helpers in this module intentionally use PyTorch tensor operations. 
They +are the correctness-first contract for portable sparse MLA fallbacks and tests; +optimized Triton/CUDA kernels should preserve these semantics. +""" + +import torch + + +def new_reference_attention_state( + q: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + if q.dim() == 4: + q_bhd = q[:, 0, :, :].float() + else: + assert q.dim() == 3, f"Expected q shape [T, H, D], got {q.shape}" + q_bhd = q.float() + + num_tokens = q_bhd.shape[0] + num_heads = q_bhd.shape[1] + head_dim = q_bhd.shape[2] + max_score = torch.full( + (num_tokens, num_heads), + float("-inf"), + dtype=torch.float32, + device=q.device, + ) + denom = torch.zeros_like(max_score) + acc = torch.zeros( + (num_tokens, num_heads, head_dim), + dtype=torch.float32, + device=q.device, + ) + return q_bhd, max_score, denom, acc + + +def accumulate_reference_attention_chunk( + q_bhd: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, + scale: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + kv_btd = kv.float() + kv_btd = torch.where( + valid_tokens[:, :, None], + kv_btd, + torch.zeros((), dtype=kv_btd.dtype, device=kv_btd.device), + ) + scores = torch.einsum("bhd,btd->bht", q_bhd, kv_btd) * scale + scores = scores.masked_fill(~valid_tokens[:, None, :], float("-inf")) + + chunk_max = scores.amax(dim=-1) + next_max = torch.maximum(max_score, chunk_max) + + previous_scale = torch.exp(max_score - next_max) + previous_scale = torch.nan_to_num(previous_scale) + weights = torch.exp(scores - next_max[:, :, None]) + weights = torch.where( + valid_tokens[:, None, :], + weights, + torch.zeros((), dtype=weights.dtype, device=weights.device), + ) + weights = torch.nan_to_num(weights) + + acc = acc * previous_scale[:, :, None] + denom = denom * previous_scale + acc = acc + torch.einsum("bht,btd->bhd", weights, kv_btd) + denom = denom + weights.sum(dim=-1) + return next_max, denom, acc + + +def finish_reference_attention_no_sink( + max_score: torch.Tensor, + denom: torch.Tensor, + acc: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + valid = denom > 0 + safe_denom = torch.where(valid, denom, torch.ones_like(denom)) + subset_output = acc / safe_denom[:, :, None] + subset_output = torch.where( + valid[:, :, None], + subset_output, + torch.zeros((), dtype=subset_output.dtype, device=subset_output.device), + ) + subset_lse = torch.where( + valid, + max_score + torch.log(safe_denom), + torch.full_like(max_score, float("-inf")), + ) + return subset_output, subset_lse + + +def reference_attention_no_sink( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, +) -> tuple[torch.Tensor, torch.Tensor]: + q_bhd, max_score, denom, acc = new_reference_attention_state(q) + max_score, denom, acc = accumulate_reference_attention_chunk( + q_bhd=q_bhd, + kv=kv, + valid_tokens=valid_tokens, + max_score=max_score, + denom=denom, + acc=acc, + scale=scale, + ) + return finish_reference_attention_no_sink(max_score, denom, acc) + + +def merge_reference_attention_with_sink( + subset_outputs: list[torch.Tensor], + subset_lses: list[torch.Tensor], + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + assert subset_outputs, "At least one attention subset is required" + assert len(subset_outputs) == len(subset_lses) + + sink = attn_sink[None, :].float() + merge_max = sink + for subset_lse in subset_lses: + merge_max = torch.maximum(merge_max, subset_lse) + + safe_merge_max = torch.where( + 
torch.isfinite(merge_max), merge_max, torch.zeros_like(merge_max) + ) + merged_acc = torch.zeros_like(subset_outputs[0], dtype=torch.float32) + sink_weight = torch.exp(sink - safe_merge_max) + sink_weight = torch.nan_to_num(sink_weight) + merged_denom = sink_weight + for subset_output, subset_lse in zip(subset_outputs, subset_lses): + subset_weight = torch.exp(subset_lse - safe_merge_max) + subset_weight = torch.nan_to_num(subset_weight) + merged_acc = merged_acc + subset_output.float() * subset_weight[:, :, None] + merged_denom = merged_denom + subset_weight + + safe_denom = torch.where( + merged_denom > 0, merged_denom, torch.ones_like(merged_denom) + ) + reference_output = merged_acc / safe_denom[:, :, None] + reference_output = torch.where( + (merged_denom > 0)[:, :, None], + reference_output, + torch.zeros((), dtype=reference_output.dtype, device=reference_output.device), + ) + output.copy_(reference_output.to(dtype=output.dtype)) + + +def sink_aware_reference_attention( + q: torch.Tensor, + kv: torch.Tensor, + valid_tokens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, +) -> None: + subset_output, subset_lse = reference_attention_no_sink( + q=q, + kv=kv, + valid_tokens=valid_tokens, + scale=scale, + ) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=attn_sink, + output=output, + ) + + +def reference_sparse_mla_prefill( + q: torch.Tensor, + kv: torch.Tensor, + combined_indices: torch.Tensor, + combined_lens: torch.Tensor, + scale: float, + attn_sink: torch.Tensor, + output: torch.Tensor, + topk_chunk_size: int, + query_chunk_size: int, +) -> None: + kv_flat = kv.reshape(-1, q.shape[-1]) + topk_chunk_size = min(combined_indices.shape[-1], topk_chunk_size) + query_chunk_size = min(q.shape[0], query_chunk_size) + + for token_start in range(0, q.shape[0], query_chunk_size): + token_end = min(token_start + query_chunk_size, q.shape[0]) + q_chunk = q[token_start:token_end] + lens_chunk = combined_lens[token_start:token_end] + indices_chunk_full = combined_indices[token_start:token_end] + q_bhd, max_score, denom, acc = new_reference_attention_state(q_chunk) + + for index_start in range(0, combined_indices.shape[-1], topk_chunk_size): + index_end = min( + index_start + topk_chunk_size, + combined_indices.shape[-1], + ) + indices_chunk = indices_chunk_full[:, index_start:index_end] + index_offsets = torch.arange( + index_start, + index_end, + device=q.device, + ) + valid_tokens = ( + (index_offsets[None, :] < lens_chunk[:, None]) + & (indices_chunk >= 0) + ) + safe_indices = torch.where( + valid_tokens, + indices_chunk, + torch.zeros((), dtype=indices_chunk.dtype, device=q.device), + ).long() + gathered_kv = kv_flat[safe_indices] + max_score, denom, acc = accumulate_reference_attention_chunk( + q_bhd=q_bhd, + kv=gathered_kv, + valid_tokens=valid_tokens, + max_score=max_score, + denom=denom, + acc=acc, + scale=scale, + ) + + subset_output, subset_lse = finish_reference_attention_no_sink( + max_score, + denom, + acc, + ) + merge_reference_attention_with_sink( + subset_outputs=[subset_output], + subset_lses=[subset_lse], + attn_sink=attn_sink, + output=output[token_start:token_end], + ) diff --git a/vllm/v1/attention/backends/mla/sparse_swa.py b/vllm/v1/attention/backends/mla/sparse_swa.py index 28564e6a97d3..7689cf9e155a 100644 --- a/vllm/v1/attention/backends/mla/sparse_swa.py +++ b/vllm/v1/attention/backends/mla/sparse_swa.py @@ -16,9 +16,15 @@ CommonAttentionMetadata, MultipleOf, ) +from 
vllm.v1.attention.backends.mla.sparse_mla_env import ( + is_triton_sparse_mla_enabled, + is_triton_sparse_mla_enabled_for_platform, + triton_sparse_mla_cudagraphs_allowed, +) from vllm.v1.attention.backends.utils import split_decodes_and_prefills from vllm.v1.attention.ops.flashmla import FlashMLASchedMeta, get_mla_metadata from vllm.v1.kv_cache_interface import ( + AttentionSpec, KVCacheSpec, MLAAttentionSpec, SlidingWindowMLASpec, @@ -162,6 +168,8 @@ class DeepseekSparseSWAMetadata: # Pre-computed prefill metadata shared across all DeepseekV4 attention layers. prefill_seq_lens: torch.Tensor | None = None prefill_gather_lens: torch.Tensor | None = None + prefill_seq_lens_cpu: torch.Tensor | None = None + prefill_gather_lens_cpu: torch.Tensor | None = None # Per-layer-type FlashMLA tile-scheduler metadata. One FlashMLASchedMeta # per present DeepseekV4 layer type, shared across all ~60 layers of that type @@ -195,6 +203,20 @@ class DeepseekSparseSWAMetadataBuilder(AttentionMetadataBuilder): reorder_batch_threshold: int = 1 _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH + @classmethod + def get_cudagraph_support( + cls, + vllm_config: VllmConfig, + kv_cache_spec: AttentionSpec, + ) -> AttentionCGSupport: + if ( + getattr(kv_cache_spec, "model_version", None) == "deepseek_v4" + and is_triton_sparse_mla_enabled_for_platform() + and not triton_sparse_mla_cudagraphs_allowed(vllm_config) + ): + return AttentionCGSupport.NEVER + return cls._cudagraph_support + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) assert isinstance(self.kv_cache_spec, SlidingWindowMLASpec | MLAAttentionSpec) @@ -313,6 +335,8 @@ def build( num_prefills, seq_lens, query_start_loc, + query_start_loc_cpu, + common_attn_metadata.seq_lens_cpu_upper_bound, ) # Per-layer-type tile-scheduler plan holders. Empty FlashMLASchedMeta @@ -363,6 +387,8 @@ def build_tile_scheduler( } if num_decode_tokens == 0 or current_platform.is_rocm(): return out + if is_triton_sparse_mla_enabled(self.device): + return out for layer_type in self._layer_types: # get_mla_metadata() is the official FlashMLA entry point that # returns a fresh empty FlashMLASchedMeta; using it keeps this @@ -377,6 +403,8 @@ def _build_deepseek_v4_metadata( num_prefills: int, seq_lens: torch.Tensor, query_start_loc: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + seq_lens_cpu_upper_bound: torch.Tensor | None, ) -> dict[str, torch.Tensor | None]: """Pre-compute DeepseekV4 prefill metadata during the metadata build phase. 
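For reference, the CPU-side gather lengths added in the hunk below follow the same rule as the existing GPU path: each prefill request attends to its own new tokens plus at most `window_size - 1` cached prefix tokens. A minimal sketch with illustrative values (tensor names mirror the builder's locals, not its attributes):

    import torch

    window_size = 128  # sliding-window size of the SWA layers (illustrative)
    seq_lens_cpu = torch.tensor([4, 300], dtype=torch.int32)   # computed + new tokens
    query_lens_cpu = torch.tensor([4, 40], dtype=torch.int32)  # new tokens this step
    prefix_lens_cpu = seq_lens_cpu - query_lens_cpu            # already-cached prefix

    # Each prefill token gathers its query tokens plus at most
    # window_size - 1 prefix tokens, matching the GPU-side gather lens.
    gather_lens_cpu = query_lens_cpu + torch.minimum(
        prefix_lens_cpu,
        torch.full_like(prefix_lens_cpu, window_size - 1),
    )
    # -> tensor([  4, 167], dtype=torch.int32)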
@@ -403,8 +431,27 @@ def _build_deepseek_v4_metadata( BLOCK_SIZE=triton.next_power_of_2(num_prefills), ) + assert seq_lens_cpu_upper_bound is not None + seq_lens_cpu = seq_lens_cpu_upper_bound + prefill_seq_lens_cpu = seq_lens_cpu[ + num_decodes : num_decodes + num_prefills + ] + query_lens_cpu = ( + query_start_loc_cpu[ + num_decodes + 1 : num_decodes + num_prefills + 1 + ] + - query_start_loc_cpu[num_decodes : num_decodes + num_prefills] + ) + prefix_lens_cpu = prefill_seq_lens_cpu - query_lens_cpu + prefill_gather_lens_cpu = query_lens_cpu + torch.minimum( + prefix_lens_cpu, + torch.full_like(prefix_lens_cpu, self.window_size - 1), + ) + result["prefill_seq_lens"] = seq_lens[num_decodes:] result["prefill_gather_lens"] = pfx_gather_lens + result["prefill_seq_lens_cpu"] = prefill_seq_lens_cpu + result["prefill_gather_lens_cpu"] = prefill_gather_lens_cpu return result diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py index 959a79f292a5..da04498f384f 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/__init__.py @@ -5,7 +5,10 @@ combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + dequantize_combined_sparse_mla_decode_kv, + dequantize_global_slots_k_cache, quantize_and_insert_k_cache, + sparse_prefill_combined_topk_size, ) from .fused_indexer_q import MXFP4_BLOCK_SIZE, fused_indexer_q_rope_quant from .fused_inv_rope_fp8_quant import fused_inv_rope_fp8_quant @@ -16,8 +19,11 @@ "combine_topk_swa_indices", "compute_global_topk_indices_and_lens", "dequantize_and_gather_k_cache", + "dequantize_combined_sparse_mla_decode_kv", + "dequantize_global_slots_k_cache", "fused_indexer_q_rope_quant", "fused_inv_rope_fp8_quant", "fused_q_kv_rmsnorm", "quantize_and_insert_k_cache", + "sparse_prefill_combined_topk_size", ] diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py index 69d20c107e11..33cfe699236f 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py @@ -349,12 +349,166 @@ def dequantize_and_gather_k_cache( ) +@triton.jit +def _dequantize_global_slots_k_kernel( + out_ptr, + out_stride_token, + out_stride_slot, + k_cache_ptr, + slot_ids_ptr, + slot_ids_stride_token, + slot_ids_stride_slot, + cache_block_size: tl.constexpr, + token_data_size: tl.constexpr, + block_stride: tl.constexpr, + fp8_dim: tl.constexpr, + bf16_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + output_dim: tl.constexpr, + BLOCK_D: tl.constexpr, +): + token_idx = tl.program_id(0) + topk_idx = tl.program_id(1) + + slot_id = tl.load( + slot_ids_ptr + + token_idx * slot_ids_stride_token + + topk_idx * slot_ids_stride_slot + ) + offsets = tl.arange(0, BLOCK_D) + output_row = out_ptr + token_idx * out_stride_token + topk_idx * out_stride_slot + + if slot_id < 0: + tl.store( + output_row + offsets, + tl.zeros((BLOCK_D,), dtype=tl.float32).to(tl.bfloat16), + mask=offsets < output_dim, + ) + return + + block_idx = slot_id // cache_block_size + pos_in_block = slot_id % cache_block_size + cache_block_ptr = k_cache_ptr + block_idx.to(tl.int64) * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + cache_block_size * token_data_size + pos_in_block * scale_dim + ) + + fp8_offsets = tl.arange(0, 512) + fp8_mask = fp8_offsets < fp8_dim + x_uint8 = 
tl.load(token_data_ptr + fp8_offsets, mask=fp8_mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + + scale_offsets = fp8_offsets // quant_block + encoded_scale = tl.load(token_scale_ptr + scale_offsets, mask=fp8_mask, other=127) + scale = tl.exp2(encoded_scale.to(tl.float32) - 127.0) + x_dequant = x_float * scale + tl.store(output_row + fp8_offsets, x_dequant.to(tl.bfloat16), mask=fp8_mask) + + bf16_offsets = tl.arange(0, 64) + bf16_cache_ptr = (token_data_ptr + fp8_dim).to(tl.pointer_type(tl.bfloat16)) + bf16_vals = tl.load(bf16_cache_ptr + bf16_offsets, mask=bf16_offsets < bf16_dim) + tl.store( + output_row + fp8_dim + bf16_offsets, + bf16_vals, + mask=bf16_offsets < bf16_dim, + ) + + +def dequantize_global_slots_k_cache( + out: torch.Tensor, + k_cache: torch.Tensor, + slot_ids: torch.Tensor, + block_size: int, +) -> None: + """Dequantize fp8_ds_mla cache rows addressed by physical global slot ids.""" + if slot_ids.dim() == 3: + assert slot_ids.shape[1] == 1 + slot_ids = slot_ids[:, 0, :] + assert slot_ids.dim() == 2, ( + f"slot_ids must be [num_tokens, topk], got {slot_ids.shape}" + ) + assert out.shape[:2] == slot_ids.shape + assert out.shape[-1] == 512 + assert out.dtype == torch.bfloat16 + assert k_cache.dtype == torch.uint8 + + TOKEN_FP8_DIM = 448 + TOKEN_BF16_DIM = 64 + TOKEN_SCALE_DIM = 8 + QUANT_BLOCK_SIZE = 64 + TOKEN_DATA_SIZE = TOKEN_FP8_DIM + TOKEN_BF16_DIM * 2 + + grid = slot_ids.shape + _dequantize_global_slots_k_kernel[grid]( + out, + out.stride(0), + out.stride(1), + k_cache, + slot_ids, + slot_ids.stride(0), + slot_ids.stride(1), + cache_block_size=block_size, + token_data_size=TOKEN_DATA_SIZE, + block_stride=k_cache.stride(0), + fp8_dim=TOKEN_FP8_DIM, + bf16_dim=TOKEN_BF16_DIM, + scale_dim=TOKEN_SCALE_DIM, + quant_block=QUANT_BLOCK_SIZE, + output_dim=512, + BLOCK_D=triton.next_power_of_2(512), + ) + + +def dequantize_combined_sparse_mla_decode_kv( + combined_kv: torch.Tensor, + compressed_k_cache: torch.Tensor, + compressed_slot_ids: torch.Tensor, + compressed_block_size: int, + swa_k_cache: torch.Tensor, + seq_lens: torch.Tensor, + swa_lens: torch.Tensor, + block_table: torch.Tensor, + swa_block_size: int, +) -> None: + """Fill `[compressed, SWA]` decode candidates into one output buffer.""" + assert combined_kv.dim() == 3 + compressed_topk = compressed_slot_ids.shape[-1] + assert combined_kv.shape[0] == compressed_slot_ids.shape[0] + assert combined_kv.shape[-1] == 512 + assert combined_kv.dtype == torch.bfloat16 + assert combined_kv.shape[1] >= compressed_topk + + dequantize_global_slots_k_cache( + combined_kv[:, :compressed_topk], + compressed_k_cache, + compressed_slot_ids, + compressed_block_size, + ) + swa_out = combined_kv[:, compressed_topk:] + if swa_out.shape[1] == 0: + return + dequantize_and_gather_k_cache( + swa_out, + swa_k_cache, + seq_lens=seq_lens, + gather_lens=swa_lens, + block_table=block_table, + block_size=swa_block_size, + offset=0, + ) + + def compute_global_topk_indices_and_lens( topk_indices: torch.Tensor, token_to_req_indices: torch.Tensor, block_table: torch.Tensor, block_size: int, is_valid_token: torch.Tensor, + global_topk_indices: torch.Tensor | None = None, + topk_lens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Map local topk indices to global KV cache slots and count valid entries. @@ -364,8 +518,20 @@ def compute_global_topk_indices_and_lens( 3. 
Masking padding tokens to length 0 """ num_tokens = topk_indices.shape[0] - global_topk_indices = torch.empty_like(topk_indices) - topk_lens = torch.empty(num_tokens, dtype=torch.int32, device=topk_indices.device) + if global_topk_indices is None: + global_topk_indices = torch.empty_like(topk_indices) + else: + assert global_topk_indices.shape == topk_indices.shape + assert global_topk_indices.dtype == topk_indices.dtype + assert global_topk_indices.device == topk_indices.device + if topk_lens is None: + topk_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + else: + assert topk_lens.shape == (num_tokens,) + assert topk_lens.dtype == torch.int32 + assert topk_lens.device == topk_indices.device _compute_global_topk_indices_and_lens_kernel[(num_tokens,)]( global_topk_indices, global_topk_indices.stride(0), @@ -412,7 +578,7 @@ def _compute_global_topk_indices_and_lens_kernel( mask=mask, other=-1, ) - is_valid = local_idx >= 0 + is_valid = (local_idx >= 0) & is_valid_token block_indices = local_idx // block_size block_numbers = tl.load( @@ -442,6 +608,14 @@ def _compute_global_topk_indices_and_lens_kernel( _SPARSE_PREFILL_TOPK_ALIGNMENT = 128 +def sparse_prefill_combined_topk_size(topk: int, window_size: int) -> int: + return ( + (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1) + // _SPARSE_PREFILL_TOPK_ALIGNMENT + * _SPARSE_PREFILL_TOPK_ALIGNMENT + ) + + def combine_topk_swa_indices( topk_indices: torch.Tensor, query_start_loc: torch.Tensor, @@ -452,23 +626,35 @@ def combine_topk_swa_indices( topk: int, M: int, N: int, + combined_indices: torch.Tensor | None = None, + combined_lens: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: num_tokens = topk_indices.shape[0] num_reqs = seq_lens.shape[0] - combined_topk = ( - (topk + window_size + _SPARSE_PREFILL_TOPK_ALIGNMENT - 1) - // _SPARSE_PREFILL_TOPK_ALIGNMENT - * _SPARSE_PREFILL_TOPK_ALIGNMENT - ) - combined_indices = torch.full( - (num_tokens, combined_topk), - fill_value=-1, - dtype=torch.int32, - device=topk_indices.device, - ) - combined_lens = torch.empty( - num_tokens, dtype=torch.int32, device=topk_indices.device - ) + combined_topk = sparse_prefill_combined_topk_size(topk, window_size) + if combined_indices is None: + combined_indices = torch.full( + (num_tokens, combined_topk), + fill_value=-1, + dtype=torch.int32, + device=topk_indices.device, + ) + else: + assert combined_indices.shape[0] >= num_tokens + assert combined_indices.shape[1] >= combined_topk + assert combined_indices.dtype == torch.int32 + assert combined_indices.device == topk_indices.device + combined_indices = combined_indices[:num_tokens, :combined_topk] + combined_indices.fill_(-1) + if combined_lens is None: + combined_lens = torch.empty( + num_tokens, dtype=torch.int32, device=topk_indices.device + ) + else: + assert combined_lens.shape[0] >= num_tokens + assert combined_lens.dtype == torch.int32 + assert combined_lens.device == topk_indices.device + combined_lens = combined_lens[:num_tokens] NUM_WORKERS = 128 _combine_topk_swa_indices_kernel[(num_reqs, NUM_WORKERS)]( diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py new file mode 100644 index 000000000000..652dc7be7907 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fp8_einsum.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""SM12x Triton FP8 einsum kernels for DeepSeek V4.""" + +import torch 
+ +from vllm.triton_utils import tl, triton + + +def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: + exp_bits = scale.view(torch.uint8).to(torch.int32) + fp32_bits = exp_bits << 23 + return fp32_bits.view(torch.float32) + + +@triton.jit +def _deepseek_v4_sm12x_fp8_einsum_kernel( + a_ptr, + a_scale_ptr, + b_ptr, + b_scale_ptr, + out_ptr, + num_tokens: tl.constexpr, + num_groups: tl.constexpr, + out_rank: tl.constexpr, + hidden_size: tl.constexpr, + a_stride_token: tl.constexpr, + a_stride_group: tl.constexpr, + a_stride_hidden: tl.constexpr, + a_scale_stride_token: tl.constexpr, + a_scale_stride_group: tl.constexpr, + a_scale_stride_hidden: tl.constexpr, + b_stride_group: tl.constexpr, + b_stride_out: tl.constexpr, + b_stride_hidden: tl.constexpr, + b_scale_stride_group: tl.constexpr, + b_scale_stride_out: tl.constexpr, + b_scale_stride_hidden: tl.constexpr, + out_stride_token: tl.constexpr, + out_stride_group: tl.constexpr, + out_stride_rank: tl.constexpr, + BLOCK_TOKENS: tl.constexpr, + BLOCK_OUT: tl.constexpr, + BLOCK_HIDDEN: tl.constexpr, +) -> None: + token_block = tl.program_id(0) + out_block = tl.program_id(1) + group = tl.program_id(2) + + token_offsets = token_block * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS) + out_offsets = out_block * BLOCK_OUT + tl.arange(0, BLOCK_OUT) + hidden_offsets = tl.arange(0, BLOCK_HIDDEN) + accum = tl.zeros((BLOCK_TOKENS, BLOCK_OUT), dtype=tl.float32) + + for hidden_start in range(0, hidden_size, BLOCK_HIDDEN): + hidden = hidden_start + hidden_offsets + a = tl.load( + a_ptr + + token_offsets[:, None] * a_stride_token + + group * a_stride_group + + hidden[None, :] * a_stride_hidden, + mask=(token_offsets[:, None] < num_tokens) + & (hidden[None, :] < hidden_size), + other=0.0, + ) + b = tl.load( + b_ptr + + group * b_stride_group + + out_offsets[None, :] * b_stride_out + + hidden[:, None] * b_stride_hidden, + mask=(out_offsets[None, :] < out_rank) & (hidden[:, None] < hidden_size), + other=0.0, + ) + raw = tl.dot(a, b, out_dtype=tl.float32) + hidden_scale_block = hidden_start // BLOCK_HIDDEN + a_scale = tl.load( + a_scale_ptr + + token_offsets * a_scale_stride_token + + group * a_scale_stride_group + + hidden_scale_block * a_scale_stride_hidden, + mask=token_offsets < num_tokens, + other=0.0, + ) + b_scale = tl.load( + b_scale_ptr + + group * b_scale_stride_group + + (out_offsets // 128) * b_scale_stride_out + + hidden_scale_block * b_scale_stride_hidden, + mask=out_offsets < out_rank, + other=0.0, + ) + accum += raw * a_scale[:, None] * b_scale[None, :] + + tl.store( + out_ptr + + token_offsets[:, None] * out_stride_token + + group * out_stride_group + + out_offsets[None, :] * out_stride_rank, + accum, + mask=(token_offsets[:, None] < num_tokens) & (out_offsets[None, :] < out_rank), + ) + + +def deepseek_v4_sm12x_fp8_einsum( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, +) -> None: + """Compute ``bhr,hdr->bhd`` with FP32 block scales on SM12x. + + ``a`` is the transposed output of ``fused_inv_rope_fp8_quant`` with shape + ``[tokens, groups, hidden]``. ``b`` is ``wo_a`` reshaped to + ``[groups, out_rank, hidden]``. 
+ """ + num_tokens, num_groups, hidden_size = a.shape + b_groups, out_rank, b_hidden_size = b.shape + assert b_groups == num_groups + assert b_hidden_size == hidden_size + assert out.shape == (num_tokens, num_groups, out_rank) + assert hidden_size % 128 == 0 + assert out_rank % 128 == 0 + assert a.dtype == torch.float8_e4m3fn + assert b.dtype == torch.float8_e4m3fn + e8m0_dtype = getattr(torch, "float8_e8m0fnu", None) + if a_scale.dtype == e8m0_dtype: + a_scale = _upcast_e8m0_to_fp32(a_scale) + if b_scale.dtype == e8m0_dtype: + b_scale = _upcast_e8m0_to_fp32(b_scale) + assert a_scale.dtype == torch.float32 + assert b_scale.dtype == torch.float32 + + if num_tokens == 0: + return + + block_tokens = 16 + block_out = 128 + block_hidden = 128 + grid = ( + triton.cdiv(num_tokens, block_tokens), + triton.cdiv(out_rank, block_out), + num_groups, + ) + _deepseek_v4_sm12x_fp8_einsum_kernel[grid]( + a, + a_scale, + b, + b_scale, + out, + num_tokens, + num_groups, + out_rank, + hidden_size, + a.stride(0), + a.stride(1), + a.stride(2), + a_scale.stride(0), + a_scale.stride(1), + a_scale.stride(2), + b.stride(0), + b.stride(1), + b.stride(2), + b_scale.stride(0), + b_scale.stride(1), + b_scale.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + BLOCK_TOKENS=block_tokens, + BLOCK_OUT=block_out, + BLOCK_HIDDEN=block_hidden, + num_warps=4, + num_stages=3, + ) diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 65993e804153..3f2ac7cdedb7 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -250,6 +250,17 @@ def remove_skipped_blocks( for manager in self.single_type_managers: manager.remove_skipped_blocks(request_id, total_computed_tokens) + def release_protected_prompt_blocks( + self, target_free_blocks: int | None = None + ) -> None: + for manager in self.single_type_managers: + if ( + target_free_blocks is not None + and self.block_pool.get_num_free_blocks() >= target_free_blocks + ): + return + manager.release_protected_prompt_blocks(target_free_blocks) + def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]: """ Get the blocks for the request. @@ -475,6 +486,8 @@ def verify_and_split_kv_cache_groups(self) -> None: # block cache hit yet. block_sizes = [spec.block_size for spec, _, _ in attention_groups] self.lcm_block_size = lcm(*block_sizes) + for manager in self.single_type_managers: + manager.cache_alignment_tokens = self.lcm_block_size # Attention-group indices (into ``self.attention_groups``) that # contain at least one EAGLE/MTP KV cache group. 
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 431776870cf4..2f3c5abf3066 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -345,7 +345,7 @@ def allocate_slots( num_tokens_main_model=full_num_tokens, apply_admission_cap=True, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + if not self._has_enough_free_blocks(num_blocks_to_allocate): return None num_tokens_main_model = total_computed_tokens + num_new_tokens @@ -373,7 +373,7 @@ def allocate_slots( num_tokens_main_model=num_tokens_main_model, ) - if num_blocks_to_allocate > self.block_pool.get_num_free_blocks(): + if not self._has_enough_free_blocks(num_blocks_to_allocate): # Cannot allocate new blocks return None @@ -446,6 +446,12 @@ def evict_blocks(self, block_ids: set[int]) -> None: """ self.block_pool.evict_blocks(block_ids) + def _has_enough_free_blocks(self, num_blocks: int) -> bool: + if num_blocks <= self.block_pool.get_num_free_blocks(): + return True + self.coordinator.release_protected_prompt_blocks(num_blocks) + return num_blocks <= self.block_pool.get_num_free_blocks() + def reset_prefix_cache(self) -> bool: """Reset prefix cache. This function may be used in RLHF flows to invalidate prefix caching after the weights are updated, @@ -455,6 +461,7 @@ def reset_prefix_cache(self) -> bool: bool: True if the prefix cache is successfully reset, False otherwise. """ + self.coordinator.release_protected_prompt_blocks() if not self.block_pool.reset_prefix_cache(): return False if self.log_stats: diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index e8d3a6f75688..651f52f6b19d 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from abc import ABC, abstractmethod -from collections import defaultdict +from collections import defaultdict, deque from collections.abc import Sequence from vllm.utils.math_utils import cdiv @@ -42,6 +42,7 @@ def __init__( dcp_world_size: int = 1, pcp_world_size: int = 1, max_admission_blocks_per_request: int | None = None, + max_model_len: int | None = None, ) -> None: """ Initializes the SingleTypeKVCacheManager. 
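`max_model_len` threads down to the managers so the prompt-block protection added below can bound its memory overhead. A worked example of the cap, using the existing `cdiv` helper from `vllm.utils.math_utils` (the numbers are illustrative):

    from vllm.utils.math_utils import cdiv

    max_model_len = 163840  # model context length (illustrative)
    block_size = 64         # KV cache block size (illustrative)

    # At most two full model-length prompts' worth of blocks stay protected
    # per single-type manager; older entries are trimmed FIFO beyond this.
    max_protected = 2 * cdiv(max(1, max_model_len), block_size)
    assert max_protected == 5120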
@@ -65,6 +66,8 @@ def __init__( self.block_pool = block_pool self.enable_caching = enable_caching self._max_admission_blocks_per_request = max_admission_blocks_per_request + self.max_model_len = max_model_len + self.cache_alignment_tokens = self.block_size self.new_block_ids: list[int] = [] # Mapping from request ID to blocks to track the blocks allocated @@ -80,6 +83,8 @@ def __init__( self.kv_cache_group_id = kv_cache_group_id self._null_block = block_pool.null_block + self._protected_prompt_block_ids: set[int] = set() + self._protected_prompt_block_queue: deque[int] = deque() @classmethod def _get_num_evictable_blocks(cls, blocks: Sequence[KVCacheBlock]): @@ -274,6 +279,70 @@ def take_new_block_ids(self) -> list[int]: self.new_block_ids = [] return ids + def _max_protected_prompt_blocks(self) -> int | None: + if self.max_model_len is None: + return None + return 2 * cdiv(max(1, self.max_model_len), self.block_size) + + def _protect_prompt_blocks(self, blocks: Sequence[KVCacheBlock]) -> None: + if not self.enable_caching: + return + + protected: list[KVCacheBlock] = [] + for block in blocks: + if ( + block.is_null + or block.block_hash is None + or block.block_id in self._protected_prompt_block_ids + ): + continue + protected.append(block) + self._protected_prompt_block_ids.add(block.block_id) + self._protected_prompt_block_queue.append(block.block_id) + + if not protected: + return + + # Keep an extra reference for prompt blocks that must survive after + # their request releases its normal runtime reference. Later request + # reuse increments/decrements the runtime reference as usual. + self.block_pool.touch(protected) + self._trim_protected_prompt_blocks() + + def _trim_protected_prompt_blocks(self) -> None: + max_blocks = self._max_protected_prompt_blocks() + if max_blocks is None: + return + + while len(self._protected_prompt_block_ids) > max_blocks: + if not self._release_one_protected_prompt_block(): + return + + def _release_one_protected_prompt_block(self) -> bool: + while self._protected_prompt_block_queue: + block_id = self._protected_prompt_block_queue.popleft() + if block_id not in self._protected_prompt_block_ids: + continue + + self._protected_prompt_block_ids.remove(block_id) + block = self.block_pool.blocks[block_id] + if block.ref_cnt > 0: + self.block_pool.free_blocks([block]) + return True + return False + + def release_protected_prompt_blocks( + self, target_free_blocks: int | None = None + ) -> None: + while self._protected_prompt_block_ids: + if ( + target_free_blocks is not None + and self.block_pool.get_num_free_blocks() >= target_free_blocks + ): + return + if not self._release_one_protected_prompt_block(): + return + def cache_blocks(self, request: Request, num_tokens: int) -> None: """ Cache the blocks for the request. 
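The `MLAAttentionManager` below protects only the hybrid-aligned prefix of a completed prompt. A worked example of its alignment arithmetic (values are illustrative; `cache_alignment_tokens` is the LCM of the group block sizes, set by the coordinator earlier in this diff):

    num_prompt_tokens = 1000
    cache_alignment_tokens = 128  # lcm of hybrid group block sizes
    block_size = 64

    # A future request can hit at most num_prompt_tokens - 1 tokens:
    # the final prompt token must still be computed to produce an output.
    max_cache_hit_length = num_prompt_tokens - 1              # 999
    aligned = (
        max_cache_hit_length // cache_alignment_tokens
    ) * cache_alignment_tokens                                # 896
    num_hit_blocks = aligned // block_size                    # 14 blocks protected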
@@ -504,6 +573,54 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return num_common_blocks +class MLAAttentionManager(FullAttentionManager): + """KV cache manager for DeepSeek V4 compressed MLA cache.""" + + def _should_protect_prompt_blocks(self) -> bool: + return ( + self.kv_cache_spec.model_version == "deepseek_v4" + or self.kv_cache_spec.cache_dtype_str == "fp8_ds_mla" + or self.kv_cache_spec.compress_ratio > 1 + ) + + def cache_blocks(self, request: Request, num_tokens: int) -> None: + super().cache_blocks(request, num_tokens) + if ( + not self._should_protect_prompt_blocks() + or num_tokens < request.num_prompt_tokens + or request.num_prompt_tokens <= 1 + ): + return + + max_cache_hit_length = request.num_prompt_tokens - 1 + aligned_cache_hit_length = ( + max_cache_hit_length + // self.cache_alignment_tokens + * self.cache_alignment_tokens + ) + num_hit_blocks = aligned_cache_hit_length // self.block_size + if num_hit_blocks == 0: + return + + self._protect_prompt_blocks( + self.req_to_blocks[request.request_id][:num_hit_blocks] + ) + + def get_num_common_prefix_blocks(self, running_request_id: str) -> int: + blocks = self.req_to_blocks[running_request_id] + num_common_blocks = 0 + expected_ref_cnt = len(self.req_to_blocks) + for block in blocks: + ref_cnt = block.ref_cnt + if block.block_id in self._protected_prompt_block_ids: + ref_cnt -= 1 + if ref_cnt == expected_ref_cnt: + num_common_blocks += 1 + else: + break + return num_common_blocks + + class SlidingWindowManager(SingleTypeKVCacheManager): def __init__(self, kv_cache_spec: SlidingWindowSpec, **kwargs) -> None: super().__init__(kv_cache_spec, **kwargs) @@ -641,6 +758,42 @@ def get_num_common_prefix_blocks(self, running_request_id: str) -> int: return 0 +class SlidingWindowMLAManager(SlidingWindowManager): + """KV cache manager for DeepSeek V4's sliding-window MLA cache. + + During decode, the live sliding window can move past the prompt boundary. + The blocks around the hybrid-aligned prompt boundary are still the suffix + needed for a future prefix-cache hit of the same prompt. 
+    """
+
+    def cache_blocks(self, request: Request, num_tokens: int) -> None:
+        super().cache_blocks(request, num_tokens)
+        if not self.enable_caching or num_tokens < request.num_prompt_tokens:
+            return
+        if request.num_prompt_tokens <= 1:
+            return
+
+        max_cache_hit_length = request.num_prompt_tokens - 1
+        aligned_cache_hit_length = (
+            max_cache_hit_length
+            // self.cache_alignment_tokens
+            * self.cache_alignment_tokens
+        )
+        if aligned_cache_hit_length <= 0:
+            return
+
+        aligned_num_hit_blocks = aligned_cache_hit_length // self.block_size
+        last_full_prompt_block = max_cache_hit_length // self.block_size
+        contiguous_blocks = cdiv(self.sliding_window - 1, self.block_size)
+        first_protected_block = max(0, aligned_num_hit_blocks - contiguous_blocks)
+        last_protected_block = max(aligned_num_hit_blocks, last_full_prompt_block)
+        blocks = self.req_to_blocks[request.request_id]
+        protected_blocks = blocks[
+            first_protected_block : min(last_protected_block, len(blocks))
+        ]
+        self._protect_prompt_blocks(protected_blocks)
+
+
 class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
     def __init__(self, kv_cache_spec: ChunkedLocalAttentionSpec, **kwargs) -> None:
         super().__init__(kv_cache_spec, **kwargs)
@@ -1124,6 +1277,7 @@ def __init__(
         kv_cache_group_id: int,
         dcp_world_size: int = 1,
         pcp_world_size: int = 1,
+        max_model_len: int | None = None,
     ):
         super().__init__(
             kv_cache_spec,
@@ -1132,6 +1286,7 @@ def __init__(
             kv_cache_group_id,
             dcp_world_size,
             pcp_world_size,
+            max_model_len=max_model_len,
         )
         sink_len = kv_cache_spec.sink_len
         assert sink_len is not None and sink_len > 0 and sink_len % self.block_size == 0
@@ -1142,9 +1297,9 @@
 spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
     FullAttentionSpec: FullAttentionManager,
     TQFullAttentionSpec: FullAttentionManager,
-    MLAAttentionSpec: FullAttentionManager,
+    MLAAttentionSpec: MLAAttentionManager,
     SlidingWindowSpec: SlidingWindowManager,
-    SlidingWindowMLASpec: SlidingWindowManager,
+    SlidingWindowMLASpec: SlidingWindowMLAManager,
     ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
     MambaSpec: MambaManager,
     CrossAttentionSpec: CrossAttentionManager,
@@ -1159,6 +1314,7 @@ def get_manager_for_kv_cache_spec(
     **kwargs,
 ) -> SingleTypeKVCacheManager:
     manager_class = spec_manager_map[type(kv_cache_spec)]
+    kwargs["max_model_len"] = max_model_len
     # SlidingWindow / ChunkedLocalAttention managers recycle blocks across
     # chunks; the runtime admission cap must match the recycling-aware bound
     # the startup pool sizer uses (single source of truth: the spec method).
diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py
index 1541b24deaaf..5d3fe0dbd166 100644
--- a/vllm/v1/executor/ray_utils.py
+++ b/vllm/v1/executor/ray_utils.py
@@ -273,6 +273,22 @@ def assert_ray_available():
         )
 
 
+def _warn_if_insufficient_cluster_devices(
+    parallel_config: ParallelConfig, device_str: str
+) -> None:
+    num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
+    if parallel_config.world_size > num_devices_in_cluster:
+        logger.warning(
+            "The requested distributed world size (%d) exceeds the total "
+            "number of available %ss (%d) in the Ray cluster. This may result "
+            "in Ray placement group allocation failures. Check `ray status` "
+            "and `ray list nodes`, or reduce tensor/pipeline parallel size.",
+            parallel_config.world_size,
+            device_str,
+            num_devices_in_cluster,
+        )
+
+
 def _verify_bundles(
     placement_group: "PlacementGroup",
     parallel_config: ParallelConfig,
@@ -549,21 +565,6 @@ def initialize_ray_cluster(
     if os.environ.get("RAY_USAGE_STATS_ENABLED", "0") != "1":
         os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
 
-    # Prevalidate GPU requirements before Ray processing
-    if current_platform.is_cuda() and parallel_config.world_size > 1:
-        available_gpus = current_platform.device_count()
-        if parallel_config.world_size > available_gpus:
-            logger.warning(
-                "Tensor parallel size (%d) exceeds available GPUs (%d). "
-                "This may result in Ray placement group allocation failures. "
-                "Consider reducing tensor_parallel_size to %d or less, "
-                "or ensure your Ray cluster has %d GPUs available.",
-                parallel_config.world_size,
-                available_gpus,
-                available_gpus,
-                parallel_config.world_size,
-            )
-
     if ray.is_initialized():
         logger.info("Ray is already initialized. Skipping Ray initialization.")
     elif current_platform.is_rocm() or current_platform.is_xpu():
@@ -589,6 +590,9 @@
             f"current platform {current_platform.device_name} does not support ray."
         )
 
+    if parallel_config.world_size > 1:
+        _warn_if_insufficient_cluster_devices(parallel_config, device_str)
+
     # Create or get the placement group for worker processes
     if parallel_config.placement_group:
         current_placement_group = parallel_config.placement_group
@@ -619,17 +623,6 @@
         )
     else:
         logger.info("No current placement group found. Creating a new placement group.")
-        num_devices_in_cluster = ray.cluster_resources().get(device_str, 0)
-        # Log a warning message and delay resource allocation failure response.
-        # Avoid immediate rejection to allow user-initiated placement group
-        # created and wait cluster to be ready
-        if parallel_config.world_size > num_devices_in_cluster:
-            logger.warning(
-                "The number of required %ss exceeds the total "
-                "number of available %ss in the placement group.",
-                device_str,
-                device_str,
-            )
         # Create a new placement group
         placement_group_specs: list[dict[str, float]] = [
            {device_str: 1.0} for _ in range(parallel_config.world_size)
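For context on the ray_utils.py change above: the two removed warnings (the CUDA-only local device count and the placement-group-only cluster count) are folded into one cluster-wide probe that runs on every platform before placement-group creation. A rough usage sketch against the public Ray API; `warn_if_insufficient` is a simplified stand-in for the private helper and the message text is abbreviated.

import ray


def warn_if_insufficient(world_size: int, device_str: str) -> None:
    # ray.cluster_resources() returns totals such as {"CPU": 16.0, "GPU": 4.0}
    available = ray.cluster_resources().get(device_str, 0)
    if world_size > available:
        print(
            f"world size {world_size} exceeds {available:g} {device_str}s in "
            "the Ray cluster; placement group allocation may fail"
        )


ray.init(ignore_reinit_error=True)  # requires a local or attached Ray runtime
warn_if_insufficient(world_size=8, device_str="GPU")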