diff --git a/vllm/model_executor/layers/deepseek_compressor.py b/vllm/model_executor/layers/deepseek_compressor.py index af2783f604da..262467fa84db 100644 --- a/vllm/model_executor/layers/deepseek_compressor.py +++ b/vllm/model_executor/layers/deepseek_compressor.py @@ -329,7 +329,6 @@ def forward( TRITON_BLOCK_SIZE=triton.next_power_of_2(kv.shape[-1]), STATE_WIDTH=state_width, COMPRESS_RATIO=self.compress_ratio, - launch_pdl=False, ) # Fused: compress → RMSNorm → RoPE → FP8 quant → KV cache write. @@ -378,7 +377,6 @@ def forward( SCALE_DIM=self._scale_dim, KV_BLOCK_STRIDE=kv_cache.stride(0), num_warps=self._num_warps, - launch_pdl=False, ) diff --git a/vllm/model_executor/layers/deepseek_v4_attention.py b/vllm/model_executor/layers/deepseek_v4_attention.py index 43242eddb5b2..7f128c9db4bf 100644 --- a/vllm/model_executor/layers/deepseek_v4_attention.py +++ b/vllm/model_executor/layers/deepseek_v4_attention.py @@ -5,6 +5,7 @@ """ from dataclasses import dataclass +import os from typing import TYPE_CHECKING, cast import torch @@ -22,6 +23,7 @@ combine_topk_swa_indices, compute_global_topk_indices_and_lens, dequantize_and_gather_k_cache, + quantize_and_insert_k_cache, fused_indexer_q_rope_quant, fused_inv_rope_fp8_quant, fused_q_kv_rmsnorm, @@ -43,6 +45,7 @@ from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.deepseek_compressor import DeepseekCompressor +from vllm.model_executor.layers.deepseek_v4_debug import dump_tensor from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.input_quant_fp8 import ( @@ -72,6 +75,75 @@ logger = init_logger(__name__) +_DEBUG_COUNTS: dict[str, int] = {} + + +def _dsv4_debug_enabled(name: str) -> bool: + value = os.environ.get(name, "") + return value not in ("", "0", "false", "False", "no", 
"No") + + +def _dsv4_debug_log(name: str, msg: str, limit: int = 8) -> None: + count = _DEBUG_COUNTS.get(name, 0) + if count < limit: + logger.warning("[DSV4_DEBUG:%s] %s", name, msg) + _DEBUG_COUNTS[name] = count + 1 + + +def _dsv4_tensor_summary(t: torch.Tensor) -> str: + if t.numel() == 0: + return f"shape={tuple(t.shape)} dtype={t.dtype} empty" + tf = t.detach() + try: + finite_ok = torch.isfinite(tf).all().item() if tf.is_floating_point() else True + except NotImplementedError: + finite_ok = "unsupported" + sample = tf.float() if tf.is_floating_point() else tf.to(torch.float32) + return ( + f"shape={tuple(t.shape)} dtype={t.dtype} " + f"mean={sample.mean().item():.4e} std={sample.std(unbiased=False).item():.4e} " + f"amax={sample.abs().amax().item():.4e} finite={finite_ok}" + ) + + +def _dsv4_decode_scale(scale: torch.Tensor) -> torch.Tensor: + if hasattr(torch, "float8_e8m0fnu") and scale.dtype == torch.float8_e8m0fnu: + return torch.exp2(scale.view(torch.uint8).to(torch.float32) - 127.0) + return scale.to(torch.float32) + + +def _dsv4_expand_block_scale(scale: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + scale_f = _dsv4_decode_scale(scale) + if scale_f.shape == weight.shape: + return scale_f + if scale_f.ndim == 2 and weight.ndim == 2: + block_m = (weight.shape[0] + scale_f.shape[0] - 1) // scale_f.shape[0] + block_k = (weight.shape[1] + scale_f.shape[1] - 1) // scale_f.shape[1] + return scale_f.repeat_interleave(block_m, 0).repeat_interleave(block_k, 1)[ + : weight.shape[0], : weight.shape[1] + ] + if scale_f.numel() == 1: + return scale_f.reshape(1).expand_as(weight) + # Last-resort diagnostic path: preserve shape correctness over speed. 
+ flat = scale_f.reshape(-1) + repeat = (weight.numel() + flat.numel() - 1) // flat.numel() + return flat.repeat_interleave(repeat)[: weight.numel()].reshape_as(weight) + + +def _dsv4_linear_fp8_reference(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor: + weight = getattr(module, "weight", None) + scale = getattr(module, "weight_scale_inv", None) + if weight is None or scale is None: + raise RuntimeError(f"{module.__class__.__name__} has no FP8 weight/scale tensors") + w = weight.detach() + w_f = w.to(torch.float32) * _dsv4_expand_block_scale(scale.detach(), w) + y = F.linear(x.to(torch.float32), w_f) + bias = getattr(module, "bias", None) + if bias is not None: + y = y + bias.to(torch.float32) + return y.to(x.dtype) + + # Prefill is processed in fixed-size chunks; this bounds the bf16 kv-gather # workspace allocated at _forward_prefill (and the matching profile-time # reservation in attention_impl's dummy-run branch). @@ -277,8 +349,20 @@ def forward( hidden_states: torch.Tensor, llama_4_scaling: torch.Tensor | None = None, ) -> torch.Tensor: - qr_kv, _ = self.fused_wqa_wkv(hidden_states) + if _dsv4_debug_enabled("VLLM_DEBUG_DSV4_UNFUSE_QKV_DOWN"): + qr_kv = _dsv4_linear_fp8_reference(self.fused_wqa_wkv, hidden_states) + else: + qr_kv, _ = self.fused_wqa_wkv(hidden_states) + dump_tensor(self.layer_name + ".fused_wqa_wkv", qr_kv) + if _dsv4_debug_enabled("VLLM_DEBUG_DSV4_LOG_STATS"): + _dsv4_debug_log( + self.layer_name + ".fused_wqa_wkv", + "hidden=" + _dsv4_tensor_summary(hidden_states) + + " qr_kv=" + _dsv4_tensor_summary(qr_kv), + ) qr, kv = qr_kv.split([self.q_lora_rank, self.head_dim], dim=-1) + dump_tensor(self.layer_name + ".qr", qr) + dump_tensor(self.layer_name + ".kv", kv) # Pre-allocate attention output with FlashMLA-padded head count. # The op writes into `o_padded`; we slice to n_local_heads after. 
@@ -299,6 +383,7 @@ def forward( self.layer_name, ) o = o_padded[:, : self.n_local_heads, :] + dump_tensor(self.layer_name + ".mla_attn_out", o) # O projection: inverse RoPE + FP8 quant + einsum + wo_b o_fp8, o_scale = fused_inv_rope_fp8_quant( @@ -313,7 +398,18 @@ def forward( ) wo_a_fp8 = self.wo_a.weight + dump_tensor(self.layer_name + ".o_fp8", o_fp8) + dump_tensor(self.layer_name + ".o_scale", o_scale) wo_a_scale = self.wo_a.weight_scale_inv + if _dsv4_debug_enabled("VLLM_DEBUG_DSV4_WOA_FORCE_FP32_SCALE"): + wo_a_scale = _dsv4_decode_scale(wo_a_scale) + if _dsv4_debug_enabled("VLLM_DEBUG_DSV4_LOG_STATS"): + _dsv4_debug_log( + self.layer_name + ".wo_a_inputs", + "o_fp8=" + _dsv4_tensor_summary(o_fp8) + + " o_scale=" + _dsv4_tensor_summary(o_scale) + + " wo_a_scale=" + _dsv4_tensor_summary(wo_a_scale), + ) z = torch.empty( (num_tokens, self.n_local_groups, self.o_lora_rank), @@ -329,8 +425,14 @@ def forward( "bhr,hdr->bhd", list(self._einsum_recipe), ) + dump_tensor(self.layer_name + ".wo_a_z", z) - return self.wo_b(z.flatten(1)) + out = self.wo_b(z.flatten(1)) + if isinstance(out, tuple): + dump_tensor(self.layer_name + ".wo_b", out[0]) + else: + dump_tensor(self.layer_name + ".wo_b", out) + return out def attention_impl( self, @@ -350,7 +452,10 @@ def attention_impl( self.kv_norm.weight.data, self.eps, ) + dump_tensor(self.layer_name + ".qr_norm", qr) + dump_tensor(self.layer_name + ".kv_norm", kv) q = self.wq_b(qr).view(-1, self.n_local_heads, self.head_dim) + dump_tensor(self.layer_name + ".wq_b_q", q) # Overlap kv_insert with whichever of indexer/compressor is present. # Indexer implies compressor; when both exist, compressor rides on the @@ -415,6 +520,129 @@ def kv_insert_and_compress() -> None: # MLA attention writes into the pre-allocated `out` buffer # ([num_tokens, padded_heads, head_dim]). 
self.mla_attn(q, kv, positions, output=out) + dump_tensor(self.layer_name + ".attention_impl_out", out) + + def _rocm_qnorm_rope_kv_insert( + self, + q: torch.Tensor, + kv: torch.Tensor, + swa_kv_cache: torch.Tensor, + swa_metadata: object, + positions: torch.Tensor, + ) -> None: + """Pure-PyTorch ROCm fallback for the CUDA fused op. + + Q side: per-head RMSNorm (no weight) + GPT-J RoPE (in-place on q). + KV side: GPT-J RoPE on last rope_head_dim dims, then pack into the + 584-byte DeepSeek V4 SWA cache format and scatter-write. + + SWA cache layout per token (584 bytes): + [0:448] 448 × float8_e4m3 NoPE values + [448:576] 64 × bfloat16 RoPE values (128 bytes, NOT quantized) + [576:583] 7 × ue8m0 scale bytes (one per 64-value NoPE block) + [583:584] 1 byte padding + """ + rope_head_dim = self.rope_head_dim # 64 + nope_head_dim = self.nope_head_dim # 448 + half_r = rope_head_dim // 2 + dev = q.device + + cos_sin = self.rotary_emb.cos_sin_cache[positions] # [T, rope_head_dim] + cos = cos_sin[:, :half_r] # [T, half_r=32] + sin = cos_sin[:, half_r:] # [T, half_r=32] + + # ---- Q: per-head RMSNorm (no learnable weight, in-place) ---- + q_f = q.float() # [T, n_heads, head_dim] + rms = torch.rsqrt(q_f.pow(2).mean(-1, keepdim=True) + self.eps) + q.copy_((q_f * rms).to(q.dtype)) + + # ---- Q: GPT-J RoPE on last rope_head_dim dims of each head ---- + q_r_f = q[:, :, nope_head_dim:].float() # [T, n_heads, 64] + q_r0 = q_r_f[:, :, 0::2] # [T, n_heads, 32] + q_r1 = q_r_f[:, :, 1::2] + cos_q = cos.unsqueeze(1) # [T, 1, 32] + sin_q = sin.unsqueeze(1) + q_rope_new = torch.stack( + [q_r0 * cos_q - q_r1 * sin_q, q_r0 * sin_q + q_r1 * cos_q], dim=-1 + ).flatten(-2) # [T, n_heads, 64] + q[:, :, nope_head_dim:].copy_(q_rope_new.to(q.dtype)) + + # ---- KV: GPT-J RoPE on last rope_head_dim dims of kv [T, 512] ---- + kv_nope_f = kv[:, :nope_head_dim].float() # [T, 448] + kv_rope_f = kv[:, nope_head_dim:].float() # [T, 64] + kv_r0 = kv_rope_f[:, 0::2] # [T, 32] + kv_r1 = kv_rope_f[:, 1::2] 
+ kv_rope_new = torch.stack( + [kv_r0 * cos - kv_r1 * sin, kv_r0 * sin + kv_r1 * cos], dim=-1 + ).flatten(-2) # [T, 64] + + if not _dsv4_debug_enabled("VLLM_DSV4_ROCM_KV_INSERT_LEGACY"): + # ROCm correctness path: use the shared Triton insert helper so the + # cache layout matches dequantize_and_gather_k_cache and the ROCm + # FlashMLA decode fallback: block token-data region followed by the + # block scale region. The legacy manual packer stored scale bytes + # inline per token and causes wrong dequantization. + slot_mapping = swa_metadata.slot_mapping # type: ignore[attr-defined] + block_size = swa_kv_cache.shape[1] + k_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) + k_for_cache = torch.cat([kv_nope_f, kv_rope_new], dim=-1).to(torch.bfloat16) + quantize_and_insert_k_cache( + k_for_cache, + k_cache_2d, + slot_mapping, + block_size=block_size, + is_ue8m0=True, + ) + return + + # ---- Legacy manual packer (debug-only fallback) ---- + # Layout: [448 fp8_e4m3 NoPE][128 bf16 RoPE][7 ue8m0 scales][1 pad] + # + # NoPE: block quantize with block_size=64 → 7 blocks → 7 UE8M0 scales. + # UE8M0 scale for block: byte = clamp(floor(log2(max_abs)) + 127, 0, 255) + # quantized = round(nope_val / 2^(scale_byte-127)) clamped to fp8_e4m3. + T = kv.shape[0] + fp8_e4m3_max = 448.0 # float8_e4m3fnuz max on ROCm is 240; use e4m3 max + # DeepSeek V4 uses float8_e4m3 (CUDA fp8, max=448) for the NoPE cache. + # On ROCm we approximate with float8_e4m3fnuz (max=240) — close enough. 
+ fp8_max = 240.0 + n_blocks = nope_head_dim // 64 # = 7 + + nope_blocks = kv_nope_f.view(T, n_blocks, 64) # [T, 7, 64] + max_abs = nope_blocks.abs().amax(dim=-1).clamp(min=1e-38) # [T, 7] + # UE8M0: byte = floor(log2(max_abs / fp8_max)) + 127 + log2_scale = torch.floor(torch.log2(max_abs / fp8_max)) + 127.0 + ue8m0_bytes = log2_scale.clamp(0, 255).to(torch.uint8) # [T, 7] + # Actual scale = 2^(byte-127) + actual_scale = torch.exp2(ue8m0_bytes.float() - 127.0) # [T, 7] + # Quantize NoPE: divide by scale, clamp, cast to fp8 viewed as uint8 + nope_quant = (nope_blocks / actual_scale.unsqueeze(-1)).clamp( + -fp8_max, fp8_max + ).to(torch.float8_e4m3fnuz).view(torch.uint8) # [T, 7, 64] + nope_bytes = nope_quant.reshape(T, nope_head_dim) # [T, 448] + + # RoPE part: store as bfloat16 (128 bytes per token) + rope_bytes = kv_rope_new.to(torch.bfloat16).view( + torch.uint8 + ).reshape(T, rope_head_dim * 2) # [T, 128] + + # Scale bytes: 7 ue8m0 + 1 padding + pad_byte = torch.zeros(T, 1, dtype=torch.uint8, device=dev) + scale_bytes = torch.cat([ue8m0_bytes, pad_byte], dim=-1) # [T, 8] + + # Assemble the 584-byte entry + entry = torch.cat([nope_bytes, rope_bytes, scale_bytes], dim=-1) # [T, 584] + + # ---- Scatter-write to paged SWA cache ---- + # swa_kv_cache: [num_blocks, block_size, 584] uint8 + slot_mapping = swa_metadata.slot_mapping # type: ignore[attr-defined] + block_size = swa_kv_cache.shape[1] + num_slots = slot_mapping.shape[0] + if num_slots < T: + entry = entry[:num_slots] + block_idx = slot_mapping // block_size # [S] + inblock_offset = slot_mapping % block_size # [S] + swa_kv_cache[block_idx, inblock_offset, :] = entry def _fused_qnorm_rope_kv_insert( self, @@ -437,6 +665,15 @@ def _fused_qnorm_rope_kv_insert( swa_kv_cache = self.swa_cache_layer.kv_cache swa_kv_cache_2d = swa_kv_cache.view(swa_kv_cache.shape[0], -1) + from vllm.platforms import current_platform + if current_platform.is_rocm(): + # The CUDA fused op does not exist on ROCm; use the pure-PyTorch + 
# fallback that handles Q-norm, RoPE, FP8 quant, and cache insert. + self._rocm_qnorm_rope_kv_insert( + q, kv, swa_kv_cache, swa_metadata, positions + ) + return + # Horizontally fused: # Q side: q_head_norm (per-head RMSNorm, no weight) + GPT-J RoPE # KV side: GPT-J RoPE + UE8M0 FP8 quant + paged cache insert @@ -485,6 +722,63 @@ def deepseek_v4_attention_fake( ) +def _dsv4_fp8_einsum_torch_fallback( + a: torch.Tensor, + a_scale: torch.Tensor, + b: torch.Tensor, + b_scale: torch.Tensor, + out: torch.Tensor, + equation: str, + recipe: tuple[int, int, int], +) -> None: + if equation != "bhr,hdr->bhd": + raise NotImplementedError( + f"DeepSeek-V4 FP8 einsum fallback only supports bhr,hdr->bhd, got {equation}" + ) + m_block, n_block, k_block = recipe + h_groups = a.shape[-2] + d_out = out.shape[-1] + r_contract = a.shape[-1] + + if b.dim() == 2: + b_3d = b.view(h_groups, d_out, r_contract) + elif b.dim() == 3: + b_3d = b + else: + raise RuntimeError(f"Unexpected wo_a weight dim: {b.dim()}") + + n_d_scale = (d_out + n_block - 1) // n_block if n_block > 1 else d_out + n_r_scale = (r_contract + k_block - 1) // k_block if k_block > 1 else r_contract + b_scale_f = _dsv4_decode_scale(b_scale) + if b_scale_f.dim() == 2: + b_scale_f = b_scale_f.view(h_groups, n_d_scale, n_r_scale) + + a_scale_f = _dsv4_decode_scale(a_scale).contiguous() + if k_block > 1: + a_scale_f = a_scale_f.repeat_interleave(k_block, dim=-1) + a_scale_f = a_scale_f[..., :r_contract] + if m_block > 1: + a_scale_f = a_scale_f.repeat_interleave(m_block, dim=0)[: a.shape[0]] + a_bf16 = (a.to(torch.float32) * a_scale_f).to(torch.bfloat16) + + if k_block > 1: + b_scale_f = b_scale_f.repeat_interleave(k_block, dim=-1) + if n_block > 1: + b_scale_f = b_scale_f.repeat_interleave(n_block, dim=-2) + b_scale_f = b_scale_f[..., :d_out, :r_contract] + b_bf16 = (b_3d.to(torch.float32) * b_scale_f).to(torch.bfloat16) + + if _dsv4_debug_enabled("VLLM_DSV4_ROCM_WOA_BMM"): + z_hbd = torch.bmm( + a_bf16.transpose(0, 
1).contiguous(), + b_bf16.transpose(1, 2).contiguous(), + ) + out.copy_(z_hbd.transpose(0, 1).to(out.dtype)) + return + + out.copy_(torch.einsum(equation, a_bf16, b_bf16).to(out.dtype)) + + def deepseek_v4_fp8_einsum( a: torch.Tensor, a_scale: torch.Tensor, @@ -494,6 +788,13 @@ def deepseek_v4_fp8_einsum( equation: str, recipe: list[int], ) -> None: + from vllm.platforms import current_platform + + if current_platform.is_rocm(): + _dsv4_fp8_einsum_torch_fallback( + a, a_scale, b, b_scale, out, equation, tuple(recipe) + ) + return fp8_einsum(equation, (a, a_scale), (b, b_scale), out, recipe=tuple(recipe)) @@ -775,6 +1076,8 @@ def _forward_decode( "allocate one for this layer type." ) + dump_tensor(self.prefix + ".decode_q", q) + dump_tensor(self.prefix + ".decode_swa_lens", swa_lens) out, _ = flash_mla_with_kvcache( q=q, k_cache=swa_cache, @@ -786,12 +1089,17 @@ def _forward_decode( indices=swa_indices, topk_length=swa_lens, softmax_scale=self.scale, - attn_sink=self.attn_sink, + attn_sink=( + torch.full_like(self.attn_sink, -float("inf")) + if _dsv4_debug_enabled("VLLM_DEBUG_DSV4_DISABLE_ATTN_SINK") + else self.attn_sink + ), extra_k_cache=kv_cache if not swa_only else None, extra_indices_in_kvcache=topk_indices, extra_topk_length=topk_lens, out=output.unsqueeze(1), ) + dump_tensor(self.prefix + ".decode_output", output) def _forward_prefill( self, @@ -851,6 +1159,11 @@ def _forward_prefill( kv = workspace_manager.get_simultaneous( ((PREFILL_CHUNK_SIZE, M, q.shape[-1]), torch.bfloat16), )[0] + if _dsv4_debug_enabled("VLLM_DSV4_ZERO_PREFILL_KV"): + # Debug correctness fallback: FlashMLA sparse should only read + # gathered positions, but zero-initializing catches accidental reads + # from workspace gaps and makes first-divergence triage deterministic. 
+ kv.zero_() for chunk_idx in range(num_chunks): chunk_start = chunk_idx * PREFILL_CHUNK_SIZE chunk_end = min(chunk_start + PREFILL_CHUNK_SIZE, num_prefills) @@ -868,6 +1181,7 @@ def _forward_prefill( block_size=attn_metadata.block_size // self.compress_ratio, offset=0, ) + dump_tensor(self.prefix + ".prefill_after_compressed_gather", kv[:chunk_size]) # Gather SWA KV swa_block_table = swa_metadata.block_table[num_decodes:] @@ -880,6 +1194,7 @@ def _forward_prefill( block_size=swa_metadata.block_size, offset=N, ) + dump_tensor(self.prefix + ".prefill_after_swa_gather", kv[:chunk_size]) # Combine the topk indices and SWA indices for gathered KV cache query_start = ( @@ -903,15 +1218,22 @@ def _forward_prefill( N, ) + dump_tensor(self.prefix + ".prefill_q_chunk", q[query_start:query_end]) + dump_tensor(self.prefix + ".prefill_kv_chunk", kv[:chunk_size]) output_chunk, _, _ = flash_mla_sparse_fwd( q=q[query_start:query_end], kv=kv.view(-1, 1, q.shape[-1]), indices=combined_indices.unsqueeze(1), sm_scale=self.scale, - attn_sink=self.attn_sink, + attn_sink=( + torch.full_like(self.attn_sink, -float("inf")) + if _dsv4_debug_enabled("VLLM_DEBUG_DSV4_DISABLE_ATTN_SINK") + else self.attn_sink + ), topk_length=combined_lens, out=output[query_start:query_end], ) + dump_tensor(self.prefix + ".prefill_output_chunk", output[query_start:query_end]) class DeepseekV4IndexerCache(torch.nn.Module, AttentionLayerBase): @@ -1060,10 +1382,24 @@ def forward( positions: torch.Tensor, rotary_emb: nn.Module, ) -> torch.Tensor: + if _dsv4_debug_enabled("VLLM_DEBUG_DSV4_DISABLE_SPARSE_INDEXER"): + assert self.topk_indices_buffer is not None + num_tokens = hidden_states.shape[0] + fill_width = min(self.topk_tokens, max(1, num_tokens)) + vals = torch.arange(fill_width, device=hidden_states.device, dtype=self.topk_indices_buffer.dtype) + self.topk_indices_buffer[:num_tokens, :fill_width] = vals.unsqueeze(0) + if fill_width < self.topk_tokens: + self.topk_indices_buffer[:num_tokens, 
fill_width:self.topk_tokens] = vals[-1] + if _dsv4_debug_enabled("VLLM_DEBUG_DSV4_LOG_STATS"): + _dsv4_debug_log(self.prefix + ".indexer_bypass", f"num_tokens={num_tokens} fill_width={fill_width}") + return self.topk_indices_buffer q, _ = self.wq_b(qr) q = q.view(-1, self.n_head, self.head_dim) + dump_tensor(self.prefix + ".indexer_q", q) k = self.compressor(hidden_states, positions, rotary_emb) + dump_tensor(self.prefix + ".indexer_k", k) weights, _ = self.weights_proj(hidden_states) + dump_tensor(self.prefix + ".indexer_weights_raw", weights) q_quant, weights = fused_indexer_q_rope_quant( positions, q, @@ -1073,4 +1409,12 @@ def forward( self.n_head**-0.5, use_fp4=self.use_fp4_kv, ) - return self.indexer_op(hidden_states, q_quant, k, weights) + if isinstance(q_quant, tuple): + dump_tensor(self.prefix + ".indexer_q_quant_values", q_quant[0]) + dump_tensor(self.prefix + ".indexer_q_quant_scale", q_quant[1]) + else: + dump_tensor(self.prefix + ".indexer_q_quant", q_quant) + dump_tensor(self.prefix + ".indexer_weights", weights) + out = self.indexer_op(hidden_states, q_quant, k, weights) + dump_tensor(self.prefix + ".indexer_topk", out) + return out diff --git a/vllm/model_executor/layers/deepseek_v4_debug.py b/vllm/model_executor/layers/deepseek_v4_debug.py new file mode 100644 index 000000000000..c2179bfc40f5 --- /dev/null +++ b/vllm/model_executor/layers/deepseek_v4_debug.py @@ -0,0 +1,153 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Debug-only dump helpers for DeepSeek-V4 accuracy triage. + +All functionality is disabled unless VLLM_DSV4_DUMP_ROOT is set. 
+""" + +from __future__ import annotations + +import json +import os +import re +import time +from pathlib import Path +from typing import Any + +import torch + +_COUNTS: dict[str, int] = {} +_LAYER_RE = re.compile(r"(?:^|\.)layers\.(\d+)(?:\.|$)") + + +def _truthy(value: str | None) -> bool: + return value not in (None, "", "0", "false", "False", "no", "No") + + +def enabled() -> bool: + return _truthy(os.environ.get("VLLM_DSV4_DUMP_ROOT")) + + +def side() -> str: + return os.environ.get("VLLM_DSV4_DUMP_SIDE", "fp8") + + +def _rank() -> str: + return os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0" + + +def layer_from_name(name: str) -> int | None: + m = _LAYER_RE.search(name) + return int(m.group(1)) if m else None + + +def _selected_layer(name: str, layer_idx: int | None) -> bool: + spec = os.environ.get("VLLM_DSV4_DUMP_LAYERS", "") + if not spec: + return True + if layer_idx is None: + return True + vals: set[int] = set() + for item in spec.split(","): + item = item.strip() + if not item: + continue + if "-" in item: + lo, hi = item.split("-", 1) + vals.update(range(int(lo), int(hi) + 1)) + else: + vals.add(int(item)) + return layer_idx in vals + + +def _safe_float(x: torch.Tensor, fn: str) -> float | None: + try: + if fn == "mean": + return float(x.mean().item()) + if fn == "std": + return float(x.std(unbiased=False).item()) + if fn == "amax": + return float(x.abs().amax().item()) + if fn == "amin": + return float(x.amin().item()) + if fn == "max": + return float(x.amax().item()) + except Exception: + return None + return None + + +def _summary(tensor: torch.Tensor) -> dict[str, Any]: + t = tensor.detach() + out: dict[str, Any] = { + "shape": list(t.shape), + "dtype": str(t.dtype), + "device": str(t.device), + "numel": int(t.numel()), + } + if t.numel() == 0: + out.update({"finite": True, "nan_count": 0, "inf_count": 0}) + return out + try: + tf = t.float() + except Exception: + tf = t.to(torch.float32) + try: + finite = torch.isfinite(tf) + 
out["finite"] = bool(finite.all().item()) + out["nan_count"] = int(torch.isnan(tf).sum().item()) + out["inf_count"] = int(torch.isinf(tf).sum().item()) + except Exception: + out["finite"] = None + out["nan_count"] = None + out["inf_count"] = None + out["mean"] = _safe_float(tf, "mean") + out["std"] = _safe_float(tf, "std") + out["amax_abs"] = _safe_float(tf, "amax") + out["min"] = _safe_float(tf, "amin") + out["max"] = _safe_float(tf, "max") + if tf.ndim >= 1 and tf.shape[-1] > 0 and tf.numel() <= 2_000_000: + try: + row = tf.reshape(-1, tf.shape[-1])[-1] + k = min(8, row.numel()) + vals, idx = torch.topk(row, k=k) + out["last_row_top_ids"] = [int(x) for x in idx.cpu().tolist()] + out["last_row_top_vals"] = [float(x) for x in vals.cpu().tolist()] + except Exception: + pass + return out + + +def dump_tensor(name: str, tensor: torch.Tensor | None, *, layer_idx: int | None = None, + note: str | None = None, max_writes: int | None = None) -> None: + if not enabled() or tensor is None: + return + if layer_idx is None: + layer_idx = layer_from_name(name) + if not _selected_layer(name, layer_idx): + return + key = f"{side()}:{_rank()}:{name}" + if max_writes is None: + max_writes = int(os.environ.get("VLLM_DSV4_DUMP_MAX_WRITES", "16")) + count = _COUNTS.get(key, 0) + if count >= max_writes: + return + _COUNTS[key] = count + 1 + root = Path(os.environ["VLLM_DSV4_DUMP_ROOT"]) / side() + root.mkdir(parents=True, exist_ok=True) + rec: dict[str, Any] = { + "ts": time.time(), + "side": side(), + "rank": _rank(), + "name": name, + "layer_idx": layer_idx, + "write_idx": count, + "note": note, + } + rec.update(_summary(tensor)) + with (root / f"rank_{_rank()}_summary.jsonl").open("a") as f: + f.write(json.dumps(rec, ensure_ascii=False, allow_nan=True) + "\n") + if _truthy(os.environ.get("VLLM_DSV4_DUMP_FULL_TENSOR")): + import numpy as np + safe = re.sub(r"[^A-Za-z0-9_.-]+", "_", name) + arr = tensor.detach().float().cpu().numpy() + np.savez_compressed(root / 
f"rank_{_rank()}_{count:04d}_{safe}.npz", tensor=arr) diff --git a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py index f476d980d555..800808210f20 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py +++ b/vllm/model_executor/layers/fused_moe/oracle/mxfp4.py @@ -220,9 +220,8 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]: _AVAILABLE_BACKENDS = [ Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8, Mxfp4MoeBackend.DEEPGEMM_MXFP4, - # TRITON_UNFUSED has bug with MTP support - # TODO re-enable after kernel is fixed - # TRITON_UNFUSED + Mxfp4MoeBackend.TRITON, + Mxfp4MoeBackend.TRITON_UNFUSED, Mxfp4MoeBackend.MARLIN, Mxfp4MoeBackend.BATCHED_MARLIN, ] @@ -836,14 +835,24 @@ def _interleave_mxfp4_cutlass_sm90(w): w2_weight.data = w2_weight.data.view(torch.float4_e2m1fn_x2) # Shuffle weights and scales for AITER CK kernel layout - w13_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True) + w13_weight = torch.nn.Parameter( + rocm_aiter_ops.shuffle_weight_a16w4(w13_weight, 16, True).view( + torch.float4_e2m1fn_x2 + ), + requires_grad=False, + ) shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4( w13_weight_scale.view(-1, w13_weight_scale.shape[-1]), num_experts, True, ) - w2_weight.data = rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False) + w2_weight = torch.nn.Parameter( + rocm_aiter_ops.shuffle_weight_a16w4(w2_weight, 16, False).view( + torch.float4_e2m1fn_x2 + ), + requires_grad=False, + ) shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4( w2_weight_scale.view(-1, w2_weight_scale.shape[-1]), num_experts, @@ -1159,10 +1168,79 @@ def shuffle_weight(w: torch.Tensor) -> torch.Tensor: w13_bias, w2_bias, ) + elif mxfp4_backend == Mxfp4MoeBackend.AITER: + from vllm._aiter_ops import rocm_aiter_ops + + w13_weight = w13_weight.data + w2_weight = w2_weight.data + w13_weight_scale = w13_weight_scale.data + w2_weight_scale = w2_weight_scale.data + if w13_bias is not None: 
+ w13_bias = w13_bias.data.to(torch.float32) + if w2_bias is not None: + w2_bias = w2_bias.data.to(torch.float32) + + e, n, k = w13_weight.shape + + # De-interleave w13 rows: gate/up pairs -> contiguous gate, up blocks. + w13_weight = ( + w13_weight.view(torch.uint8) + .view(e, n // 2, 2, k) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, k) + ) + w13_weight_scale = ( + w13_weight_scale.view(e, n // 2, 2, -1) + .permute(0, 2, 1, 3) + .contiguous() + .view(e, n, -1) + ) + + # AITER CK kernels key off torch.float4_e2m1fn_x2, not raw uint8. + w13_weight = torch.nn.Parameter( + rocm_aiter_ops.shuffle_weight_a16w4( + w13_weight.view(torch.float4_e2m1fn_x2), 16, True + ).view(torch.float4_e2m1fn_x2), + requires_grad=False, + ) + w2_weight = torch.nn.Parameter( + rocm_aiter_ops.shuffle_weight_a16w4( + w2_weight.view(torch.float4_e2m1fn_x2), 16, False + ).view(torch.float4_e2m1fn_x2), + requires_grad=False, + ) + shuffled_w13_scale = rocm_aiter_ops.shuffle_scale_a16w4( + w13_weight_scale.view(-1, w13_weight_scale.shape[-1]), + num_experts, + True, + ) + shuffled_w2_scale = rocm_aiter_ops.shuffle_scale_a16w4( + w2_weight_scale.view(-1, w2_weight_scale.shape[-1]), + num_experts, + False, + ) + + if w13_bias is not None: + w13_bias = ( + w13_bias.view(-1, n // 2, 2) + .permute(0, 2, 1) + .contiguous() + .view(-1, n) + ) + + return ( + w13_weight, + w2_weight, + shuffled_w13_scale, + shuffled_w2_scale, + w13_bias, + w2_bias, + ) else: raise ValueError( f"Unsupported mxfp4_backend for Mxfp4MoEMethod: {mxfp4_backend}. " - f"Expected TRTLLM or Triton backend." + f"Expected TRTLLM, Triton, AITER, or Marlin backend." 
) diff --git a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py index 84eaad7f65e6..34038f5eadb4 100644 --- a/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py +++ b/vllm/model_executor/layers/fused_moe/router/fused_topk_bias_router.py @@ -57,6 +57,56 @@ def vllm_topk_sigmoid( return topk_weights, topk_indices +def _topk_softplus_sqrt_torch( + topk_weights: torch.Tensor, + topk_indices: torch.Tensor, + token_expert_indices: torch.Tensor, + gating_output: torch.Tensor, + renormalize: bool, + routed_scaling_factor: float, + e_score_correction_bias: torch.Tensor | None, + input_tokens: torch.Tensor | None, + hash_indices_table: torch.Tensor | None, +) -> None: + x_f32 = gating_output.to(torch.float32) + weights_base = torch.sqrt(F.softplus(x_f32, beta=1.0, threshold=20.0)) + topk = topk_weights.shape[-1] + + if input_tokens is not None and hash_indices_table is not None: + selected_experts = hash_indices_table[input_tokens.to(torch.long)] + selected_weights = torch.gather(weights_base, -1, selected_experts.to(torch.long)) + if renormalize: + denom = selected_weights.sum(dim=-1, keepdim=True) + denom = torch.where(denom > 0, denom, torch.ones_like(denom)) + selected_weights = selected_weights / denom + selected_weights = selected_weights * routed_scaling_factor + topk_weights.copy_(selected_weights.to(topk_weights.dtype)) + topk_indices.copy_(selected_experts.to(topk_indices.dtype)) + return + + ranking = weights_base + if e_score_correction_bias is not None: + ranking = ranking + e_score_correction_bias.to(torch.float32) + _, topk_ids = torch.topk(ranking, topk, dim=-1) + out_weights = torch.gather(weights_base, -1, topk_ids) + if renormalize: + denom = out_weights.sum(dim=-1, keepdim=True) + denom = torch.where(denom > 0, denom, torch.ones_like(denom)) + out_weights = out_weights / denom + out_weights = out_weights * routed_scaling_factor + 
topk_weights.copy_(out_weights.to(topk_weights.dtype)) + topk_indices.copy_(topk_ids.to(topk_indices.dtype)) + + arange_t = torch.arange( + gating_output.shape[0], device=gating_output.device, + dtype=token_expert_indices.dtype, + ).unsqueeze(-1) + arange_k = torch.arange( + topk, device=gating_output.device, dtype=token_expert_indices.dtype, + ).unsqueeze(0) + token_expert_indices.copy_(arange_k * gating_output.shape[0] + arange_t) + + def vllm_topk_softplus_sqrt( topk_weights: torch.Tensor, topk_indices: torch.Tensor, @@ -68,17 +118,32 @@ def vllm_topk_softplus_sqrt( hash_indices_table: torch.Tensor | None = None, routed_scaling_factor: float = 1.0, ) -> tuple[torch.Tensor, ...]: - ops.topk_hash_softplus_sqrt( - topk_weights, - topk_indices, - token_expert_indices, - gating_output, - renormalize, - routed_scaling_factor, - e_score_correction_bias, - input_tokens, - hash_indices_table, - ) + from vllm.platforms import current_platform + + if current_platform.is_rocm(): + _topk_softplus_sqrt_torch( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + routed_scaling_factor, + e_score_correction_bias, + input_tokens, + hash_indices_table, + ) + else: + ops.topk_hash_softplus_sqrt( + topk_weights, + topk_indices, + token_expert_indices, + gating_output, + renormalize, + routed_scaling_factor, + e_score_correction_bias, + input_tokens, + hash_indices_table, + ) return topk_weights, topk_indices diff --git a/vllm/model_executor/layers/mhc.py b/vllm/model_executor/layers/mhc.py index 1521a6b601bf..ee7b875cbd75 100644 --- a/vllm/model_executor/layers/mhc.py +++ b/vllm/model_executor/layers/mhc.py @@ -1,184 +1,12 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math -from functools import cache -from typing import TYPE_CHECKING - import torch - +import torch.nn.functional as F from vllm.platforms import current_platform -from vllm.utils.import_utils import 
has_tilelang -from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import direct_register_custom_op -# tilelang is only available on CUDA platforms -if TYPE_CHECKING or current_platform.is_cuda_alike(): - if not has_tilelang(): - raise ImportError( - "tilelang is required for mhc but is not installed. Install it with " - "`pip install tilelang`." - ) - import tilelang - import tilelang.language as T -else: - tilelang = None # type: ignore[assignment] - T = None # type: ignore[assignment] - - -@cache -def compute_num_split(block_k: int, k: int | None, grid_size: int) -> int: - device_props = torch.cuda.get_device_properties(0) - n_sms = device_props.multi_processor_count - split_k = n_sms // grid_size - if k is not None: - # avoid split_k for small k - num_block_k = cdiv(k, block_k) - split_k = min(split_k, num_block_k // 4) - split_k = max(split_k, 1) - return split_k - - -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, -) -def mhc_pre_big_fuse_tilelang( - gemm_out_mul, - gemm_out_sqrsum, - hc_scale, - hc_base, - residual, - post_mix, - comb_mix, - layer_input, - hidden_size: int, - rms_eps: float, - hc_pre_eps: float, - hc_sinkhorn_eps: float, - hc_post_mult_value: float, - sinkhorn_repeat: int, - n_splits: int = 16, - hc_mult: int = 4, -): - """Deeply fused kernels, everything other than gemm & sqrsum in mHC pre block.""" - num_tokens = T.dynamic("num_tokens") - hc_mult3 = hc_mult * (2 + hc_mult) - hidden_block = math.gcd(512, hidden_size) - - gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32] # type: ignore[no-redef, valid-type] - gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] # type: ignore[no-redef, valid-type] - hc_scale: T.Tensor[[3], T.float32] # type: ignore[no-redef, valid-type] - hc_base: T.Tensor[[hc_mult3], T.float32] # type: ignore[no-redef, 
valid-type] - residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type] - # outputs - post_mix: T.Tensor[[num_tokens, hc_mult], T.float32] # type: ignore[no-redef, valid-type] - comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32] # type: ignore[no-redef, valid-type] - layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] # type: ignore[no-redef, valid-type] - - with T.Kernel(num_tokens, threads=96) as i: - T.pdl_sync() - ################################################################## - # _pre_norm_fn_fwd_norm - rms = T.alloc_fragment(1, T.float32) - mixes = T.alloc_fragment(hc_mult3, T.float32) - T.clear(mixes) - rms[0] = 0 - for i_split in T.serial(n_splits): - rms[0] += gemm_out_sqrsum[i_split, i] - rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps) - for j in T.Parallel(hc_mult3): - mixes[j] = 0 - for i_split in T.serial(n_splits): - mixes[j] += gemm_out_mul[i_split, i, j] - mixes[j] *= rms[0] - mixes_shared = T.alloc_shared(hc_mult3, T.float32) - T.copy(mixes, mixes_shared) - - if T.get_thread_binding() < 32: - ################################################################## - # _pre_split_mixes_fwd (post & comb) - cm = T.alloc_fragment((hc_mult, hc_mult), T.float32) - for j in T.Parallel(hc_mult): - post_mix[i, j] = ( - T.sigmoid( - mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult] - ) - * hc_post_mult_value - ) - for j, k in T.Parallel(hc_mult, hc_mult): - cm[j, k] = ( - mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2] - + hc_base[j * hc_mult + k + hc_mult * 2] - ) - - ################################################################## - # _sinkhorn_fwd - row_sum = T.alloc_fragment(hc_mult, T.float32) - col_sum = T.alloc_fragment(hc_mult, T.float32) +_aiter_mhc = None - # comb = comb.softmax(-1) + eps - row_max = T.alloc_fragment(hc_mult, T.float32) - T.reduce_max(cm, row_max, dim=1) - for j, k in T.Parallel(hc_mult, hc_mult): - cm[j, k] = T.exp(cm[j, k] - 
row_max[j]) - T.reduce_sum(cm, row_sum, dim=1) - for j, k in T.Parallel(hc_mult, hc_mult): - cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps - # comb = comb / (comb.sum(-2) + eps) - T.reduce_sum(cm, col_sum, dim=0) - for j, k in T.Parallel(hc_mult, hc_mult): - cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) - - for _ in T.serial(sinkhorn_repeat - 1): - # comb = comb / (comb.sum(-1) + eps) - T.reduce_sum(cm, row_sum, dim=1) - for j, k in T.Parallel(hc_mult, hc_mult): - cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps) - - # comb = comb / (comb.sum(-2) + eps) - T.reduce_sum(cm, col_sum, dim=0) - for j, k in T.Parallel(hc_mult, hc_mult): - cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) - - # save comb_mix to global memory - for j, k in T.Parallel(hc_mult, hc_mult): - comb_mix[i, j * hc_mult + k] = cm[j, k] - else: - ################################################################## - # _pre_split_mixes_fwd (pre) - pre_mix_shared = T.alloc_shared(hc_mult, T.float32) - for j in T.Parallel(hc_mult): - pre_mix_shared[j] = ( - T.sigmoid( - mixes_shared[j] * hc_scale[0] + hc_base[j], - ) - + hc_pre_eps - ) - ################################################################### - # _pre_apply_mix_fwd - for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2): - xs = T.alloc_shared((hc_mult, hidden_block), T.float32) - xl = T.alloc_fragment((hc_mult, hidden_block), T.float32) - T.copy(residual[i, 0, i0_h * hidden_block], xs) - T.copy(xs, xl) - - ol = T.alloc_fragment(hidden_block, T.float32) - T.clear(ol) - - for i_hc in T.serial(hc_mult): - pre = pre_mix_shared[i_hc] - for i1_h in T.Parallel(hidden_block): - ol[i1_h] += pre * xl[i_hc, i1_h] - - T.copy(ol, layer_input[i, i0_h * hidden_block]) - T.pdl_trigger() - - -def mhc_pre( +def _mhc_pre_torch( residual: torch.Tensor, fn: torch.Tensor, hc_scale: torch.Tensor, @@ -190,125 +18,44 @@ def mhc_pre( sinkhorn_repeat: int, n_splits: int = 1, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - 
Forward pass for mHC pre block. - - Args: - residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16 - fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32 - hc_scale: shape (3,), dtype torch.float32 - hc_base: shape (hc_mult3,), dtype torch.float32 - rms_eps: RMS normalization epsilon - hc_pre_eps: pre-mix epsilon - hc_sinkhorn_eps: sinkhorn epsilon - hc_post_mult_value: post-mix multiplier value - sinkhorn_repeat: number of sinkhorn iterations - n_splits: split-k factor; - - Returns: - post_mix: shape (..., hc_mult), dtype torch.float32 - comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32 - layer_input: shape (..., hidden_size), dtype torch.bfloat16 - """ - - # Validate shapes - assert residual.dtype == torch.bfloat16 - assert fn.dtype == torch.float32 - assert hc_scale.dtype == torch.float32 - assert hc_base.dtype == torch.float32 - hc_mult = residual.shape[-2] hidden_size = residual.shape[-1] hc_mult2 = hc_mult * hc_mult hc_mult3 = hc_mult * 2 + hc_mult2 - - hc_hidden_size = hc_mult * hidden_size - assert fn.shape[0] == hc_mult3 - assert fn.shape[1] == hc_hidden_size - assert hc_scale.shape == (3,) - assert hc_base.shape == (hc_mult3,) - outer_shape = residual.shape[:-2] - residual_flat = residual.view(-1, hc_mult, hidden_size) - num_tokens = residual_flat.shape[0] - fn_flat = fn + x_flat = residual.reshape(-1, hc_mult * hidden_size).float() + num_tokens = x_flat.shape[0] - # these number are from deepgemm kernel impl - block_k = 64 - block_m = 64 - n_splits = compute_num_split(block_k, hc_hidden_size, cdiv(num_tokens, block_m)) + rsqrt_val = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + rms_eps) + mixes = F.linear(x_flat, fn) * rsqrt_val - post_mix = torch.empty( - num_tokens, - hc_mult, - dtype=torch.float32, - device=residual.device, - ) - comb_mix = torch.empty( - num_tokens, - hc_mult2, - dtype=torch.float32, - device=residual.device, - ) - layer_input = torch.empty( - num_tokens, - hidden_size, - dtype=torch.bfloat16, 
- device=residual.device, - ) + pre_logits = mixes[:, :hc_mult] + post_logits = mixes[:, hc_mult:hc_mult * 2] + comb_logits = mixes[:, hc_mult * 2:] - gemm_out_mul = torch.empty( - n_splits, - num_tokens, - hc_mult3, - dtype=torch.float32, - device=residual.device, - ) - gemm_out_sqrsum = torch.empty( - n_splits, - num_tokens, - dtype=torch.float32, - device=residual.device, - ) + pre_mix = torch.sigmoid(pre_logits * hc_scale[0] + hc_base[:hc_mult]) + hc_pre_eps + post_mix = torch.sigmoid(post_logits * hc_scale[1] + hc_base[hc_mult:hc_mult * 2]) * hc_post_mult_value + comb = comb_logits * hc_scale[2] + hc_base[hc_mult * 2:] + comb = comb.view(num_tokens, hc_mult, hc_mult) - from vllm.utils.deep_gemm import tf32_hc_prenorm_gemm + comb = comb.softmax(-1) + hc_sinkhorn_eps + comb = comb / (comb.sum(-2, keepdim=True) + hc_sinkhorn_eps) + for _ in range(sinkhorn_repeat - 1): + comb = comb / (comb.sum(-1, keepdim=True) + hc_sinkhorn_eps) + comb = comb / (comb.sum(-2, keepdim=True) + hc_sinkhorn_eps) - tf32_hc_prenorm_gemm( - residual_flat.view(num_tokens, hc_mult * hidden_size), - fn_flat, - gemm_out_mul, - gemm_out_sqrsum, - n_splits, - ) - - mhc_pre_big_fuse_tilelang( - gemm_out_mul, - gemm_out_sqrsum, - hc_scale, - hc_base, - residual_flat, - post_mix, - comb_mix, - layer_input, - hidden_size, - rms_eps, - hc_pre_eps, - hc_sinkhorn_eps, - hc_post_mult_value, - sinkhorn_repeat, - n_splits, - hc_mult, - ) + res_view = residual.reshape(num_tokens, hc_mult, hidden_size).float() + layer_input = (pre_mix.unsqueeze(-1) * res_view).sum(dim=-2).to(residual.dtype) post_mix = post_mix.view(*outer_shape, hc_mult, 1) - comb_mix = comb_mix.view(*outer_shape, hc_mult, hc_mult) + comb_mix = comb.view(*outer_shape, hc_mult, hc_mult) layer_input = layer_input.view(*outer_shape, hidden_size) return post_mix, comb_mix, layer_input -def _mhc_pre_fake( +def mhc_pre( residual: torch.Tensor, fn: torch.Tensor, hc_scale: torch.Tensor, @@ -320,119 +67,48 @@ def _mhc_pre_fake( sinkhorn_repeat: 
int, n_splits: int = 1, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if _aiter_mhc is not None: + return _aiter_mhc.mhc_pre( + residual, fn, hc_scale, hc_base, + rms_eps, hc_pre_eps, hc_sinkhorn_eps, + hc_post_mult_value, sinkhorn_repeat, + ) + return _mhc_pre_torch( + residual, fn, hc_scale, hc_base, + rms_eps, hc_pre_eps, hc_sinkhorn_eps, + hc_post_mult_value, sinkhorn_repeat, n_splits, + ) + + +def _mhc_pre_fake( + residual, fn, hc_scale, hc_base, rms_eps, hc_pre_eps, + hc_sinkhorn_eps, hc_post_mult_value, sinkhorn_repeat, n_splits=1, +): hc_mult = residual.shape[-2] hidden_size = residual.shape[-1] outer_shape = residual.shape[:-2] - - # Create empty tensors with correct shapes for meta device / shape inference - post_mix = torch.empty( - *outer_shape, - hc_mult, - 1, - dtype=torch.float32, - device=residual.device, - ) - comb_mix = torch.empty( - *outer_shape, - hc_mult, - hc_mult, - dtype=torch.float32, - device=residual.device, - ) - layer_input = torch.empty( - *outer_shape, - hidden_size, - dtype=torch.bfloat16, - device=residual.device, - ) - + post_mix = torch.empty(*outer_shape, hc_mult, 1, dtype=torch.float32, device=residual.device) + comb_mix = torch.empty(*outer_shape, hc_mult, hc_mult, dtype=torch.float32, device=residual.device) + layer_input = torch.empty(*outer_shape, hidden_size, dtype=torch.bfloat16, device=residual.device) return post_mix, comb_mix, layer_input -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10, - }, -) -def mhc_post_tilelang( - a, - b, - c, - d, - x, - hc: int, - hidden: int, - n_thr: int = 128, - h_blk: int = 1024, -) -> tilelang.JITKernel: - # rename for shorter code - n = T.dynamic("num_tokens") - h = hidden - - h_blk = math.gcd(hidden, h_blk) - a: T.Tensor((n, hc, hc), T.float32) # type: ignore[no-redef, valid-type] - b: T.Tensor((n, hc, h), T.bfloat16) # 
type: ignore[no-redef, valid-type] - c: T.Tensor((n, hc), T.float32) # type: ignore[no-redef, valid-type] - d: T.Tensor((n, h), T.bfloat16) # type: ignore[no-redef, valid-type] - x: T.Tensor((n, hc, h), T.bfloat16) # type: ignore[no-redef, valid-type] - with T.Kernel(n, threads=n_thr) as i_n: - x_shared = T.alloc_shared((hc, h_blk), T.bfloat16) - b_shared = T.alloc_shared((hc, h_blk), T.bfloat16) - d_shared = T.alloc_shared(h_blk, T.bfloat16) - - x_local = T.alloc_fragment((hc, h_blk), T.float32) - b_local = T.alloc_fragment((hc, h_blk), T.float32) - d_local = T.alloc_fragment(h_blk, T.float32) - - a_local = T.alloc_fragment((hc, hc), T.float32) - c_local = T.alloc_fragment(hc, T.float32) - T.pdl_sync() - T.copy(a[i_n, 0, 0], a_local) - T.copy(c[i_n, 0], c_local) - - for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2): - T.copy(b[i_n, 0, i0_h * h_blk], b_shared) - T.copy(d[i_n, i0_h * h_blk], d_shared) - - T.copy(b_shared, b_local) - T.copy(d_shared, d_local) - for i_hco, i1_h in T.Parallel(hc, h_blk): - x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h] - for i_hci in T.serial(hc): - x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h] - T.copy(x_local, x_shared) - - T.copy(x_shared, x[i_n, 0, i0_h * h_blk]) - T.pdl_trigger() - - def mhc_post( x: torch.Tensor, residual: torch.Tensor, post_layer_mix: torch.Tensor, comb_res_mix: torch.Tensor, ) -> torch.Tensor: - out = torch.empty_like(residual) - mhc_post_tilelang( - comb_res_mix, - residual, - post_layer_mix.squeeze(-1), - x, - out, - residual.shape[-2], - residual.shape[-1], - ) - return out + if _aiter_mhc is not None: + out = torch.empty_like(residual) + _aiter_mhc.mhc_post(out, x, residual, post_layer_mix, comb_res_mix) + return out + out = torch.einsum("...ij,...jh->...ih", comb_res_mix, residual.float()) + out = out + post_layer_mix * x.unsqueeze(-2).float() + return out.to(residual.dtype) -def _mhc_post_fake( - x: torch.Tensor, - residual: torch.Tensor, - post_layer_mix: 
torch.Tensor, - comb_res_mix: torch.Tensor, -) -> torch.Tensor: +def _mhc_post_fake(x, residual, post_layer_mix, comb_res_mix): return torch.empty_like(residual) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9613b11d35e2..b54d5ee015f5 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -966,15 +966,13 @@ def requant_weight_ue8m0_inplace( def _upcast_e8m0_to_fp32(scale: torch.Tensor) -> torch.Tensor: - """Upcast E8M0 (exponent-only) scale to float32. + """Decode E8M0 (exponent-only) scale tensors to float32. - E8M0 stores only the 8-bit biased exponent (bias=127). To convert - to float32 we place those 8 bits into the exponent field of an - IEEE-754 float32 (bits 23-30) with sign=0 and mantissa=0. + E8M0 stores an unsigned exponent with IEEE-754 bias 127. Keep the + conversion in one helper so CUDA DeepGEMM and ROCm fallback paths use + identical scale semantics for checkpoints that store UE8M0 scales. """ - exp_bits = scale.view(torch.uint8).to(torch.int32) - fp32_bits = exp_bits << 23 - return fp32_bits.view(torch.float32) + return torch.exp2(scale.view(torch.uint8).to(torch.float32) - 127) def deepgemm_post_process_fp8_weight_block( @@ -1284,6 +1282,10 @@ def process_fp8_weight_block_strategy( weight=weight, weight_scale=weight_scale ) + if weight_scale.dtype == torch.float8_e8m0fnu and not is_deep_gemm_e8m0_used(): + # ROCm fallback kernels do not accept UE8M0 scale tensors directly. 
+ weight_scale = _upcast_e8m0_to_fp32(weight_scale) + weight = _maybe_pad_fp8_weight(weight) return weight, weight_scale diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index b309bf14d991..0dba796362eb 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -125,7 +125,15 @@ def normalize_e4m3fn_to_e4m3fnuz( # the e4m3fn value, so we should double the scaling factor to # get the same dequantized value. # https://onnx.ai/onnx/technical/float8.html - weight_scale = weight_scale * 2.0 + if weight_scale.dtype in (torch.float8_e8m0fnu,): + weight_scale = weight_scale.view(torch.uint8).to(torch.float32) + weight_scale = torch.exp2(weight_scale - 127) * 2.0 + else: + weight_scale = weight_scale * 2.0 if input_scale is not None: - input_scale = input_scale * 2.0 + if input_scale.dtype in (torch.float8_e8m0fnu,): + input_scale = input_scale.view(torch.uint8).to(torch.float32) + input_scale = torch.exp2(input_scale - 127) * 2.0 + else: + input_scale = input_scale * 2.0 return weight, weight_scale, input_scale diff --git a/vllm/model_executor/layers/sparse_attn_indexer.py b/vllm/model_executor/layers/sparse_attn_indexer.py index ca82f2feb7ef..113880ae9c4d 100644 --- a/vllm/model_executor/layers/sparse_attn_indexer.py +++ b/vllm/model_executor/layers/sparse_attn_indexer.py @@ -32,6 +32,11 @@ elif current_platform.is_xpu(): from vllm._xpu_ops import xpu_ops +# Registers vllm::rocm_sparse_attn_indexer_no_insert for the DeepSeek-V4 +# layout where the compressor pre-inserts K and the indexer receives k=None. 
+if current_platform.is_rocm(): + import vllm.v1.attention.ops.rocm_sparse_attn_indexer # noqa: F401 + logger = init_logger(__name__) RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024 @@ -499,13 +504,25 @@ def forward_hip( k: torch.Tensor, weights: torch.Tensor, ): - assert not self.skip_k_cache_insert, ( - "AMD platform doesn't support skip cache insert yet" - ) assert not self.use_fp4_cache, "AMD platform doesn't support fp4 cache yet" assert isinstance(q_quant, torch.Tensor), ( "AMD sparse_attn_indexer expects a single FP8 q_quant tensor" ) + if self.skip_k_cache_insert: + return torch.ops.vllm.rocm_sparse_attn_indexer_no_insert( + hidden_states, + _encode_layer_name(self.k_cache.prefix), + self.k_cache.kv_cache, + q_quant, + weights, + self.quant_block_size, + self.scale_fmt, + self.topk_tokens, + self.head_dim, + self.max_model_len, + self.max_total_seq_len, + self.topk_indices_buffer, + ) if rocm_aiter_ops.is_enabled(): return torch.ops.vllm.rocm_aiter_sparse_attn_indexer( hidden_states, @@ -525,5 +542,6 @@ def forward_hip( else: raise RuntimeError( "Sparse attention indexer ROCm custom op requires ROCm " - "Aiter ops to be enabled." + "Aiter ops to be enabled (or skip_k_cache_insert=True for " + "the V4 no-insert layout)." 
) diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index e26b511de4ce..4ee04e47b473 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -303,6 +303,14 @@ def cublas_gemm_bf16_bf16_fp32( x: torch.Tensor, weight: torch.Tensor, ): + if current_platform.is_rocm(): + try: + from aiter.ops.triton.gemm.basic.gemm_a16w16 import gemm_a16w16 + return gemm_a16w16(x.to(torch.bfloat16), weight.to(torch.bfloat16), + dtype=torch.bfloat16).to(torch.float32) + except ImportError: + return torch.mm(x.to(torch.bfloat16), + weight.t().to(torch.bfloat16)).to(torch.float32) return ops.router_gemm_bf16_fp32(x, weight) diff --git a/vllm/model_executor/models/deepseek_v4.py b/vllm/model_executor/models/deepseek_v4.py index 7733252804b7..49405e2e80cf 100644 --- a/vllm/model_executor/models/deepseek_v4.py +++ b/vllm/model_executor/models/deepseek_v4.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import typing +import os from collections.abc import Callable, Iterable from itertools import islice @@ -18,6 +19,7 @@ ) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp +from vllm.model_executor.layers.deepseek_v4_debug import dump_tensor from vllm.model_executor.layers.deepseek_v4_attention import ( DeepseekV4Indexer, DeepseekV4MLAModules, @@ -57,6 +59,36 @@ from vllm.utils.multi_stream_utils import AuxStreamType from vllm.utils.torch_utils import direct_register_custom_op + +_DSV4_DEBUG_COUNTS: dict[str, int] = {} + + +def _dsv4_debug_enabled(name: str) -> bool: + value = os.environ.get(name, "") + return value not in ("", "0", "false", "False", "no", "No") + + +def _dsv4_debug_log(name: str, msg: str, limit: int = 8) -> None: + count = _DSV4_DEBUG_COUNTS.get(name, 0) + if count < limit: + print(f"[DSV4_DEBUG:{name}] {msg}", flush=True) + 
_DSV4_DEBUG_COUNTS[name] = count + 1 + + +def _dsv4_tensor_summary(t: torch.Tensor) -> str: + if t.numel() == 0: + return f"shape={tuple(t.shape)} dtype={t.dtype} empty" + sample = t.detach().float() if t.is_floating_point() else t.detach().to(torch.float32) + try: + finite = torch.isfinite(t.detach()).all().item() if t.is_floating_point() else True + except NotImplementedError: + finite = "unsupported" + return ( + f"shape={tuple(t.shape)} dtype={t.dtype} " + f"mean={sample.mean().item():.4e} std={sample.std(unbiased=False).item():.4e} " + f"amax={sample.abs().amax().item():.4e} finite={finite}" + ) + from .utils import ( AutoWeightsLoader, WeightsMapper, @@ -801,6 +833,7 @@ def _init_fused_moe_experts( def forward( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None ) -> torch.Tensor: + dump_tensor(f"{self.prefix}.input", hidden_states) if self.gate.tid2eid is not None: if input_ids is None: raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.") @@ -810,6 +843,7 @@ def forward( org_shape = hidden_states.shape router_logits, _ = self.gate(hidden_states) + dump_tensor(f"{self.prefix}.router_logits", router_logits) topk_weights, topk_ids = fused_topk_bias( hidden_states=hidden_states, gating_output=router_logits, @@ -827,17 +861,22 @@ def forward( activation_clamp = ( float(self.swiglu_limit) if self.swiglu_limit is not None else None ) + dump_tensor(f"{self.prefix}.topk_weights", topk_weights) + dump_tensor(f"{self.prefix}.topk_ids", topk_ids) final_hidden_states = self.experts( hidden_states, topk_weights, topk_ids, activation_clamp=activation_clamp, ) + dump_tensor(f"{self.prefix}.experts_out", final_hidden_states) if self.shared_experts is not None: shared_output = self.shared_experts(hidden_states) + dump_tensor(f"{self.prefix}.shared_experts_out", shared_output) final_hidden_states += shared_output + dump_tensor(f"{self.prefix}.output", final_hidden_states.view(org_shape)) return final_hidden_states.view(org_shape) def 
_forward_fused_moe( @@ -853,12 +892,14 @@ def _forward_fused_moe( ) else: router_logits, _ = self.gate(hidden_states) + dump_tensor(f"{self.prefix}.router_logits", router_logits) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits, input_ids=input_ids, ) + dump_tensor(f"{self.prefix}.output", final_hidden_states.view(org_shape)) return final_hidden_states.view(org_shape) def finalize_mega_moe_weights(self) -> None: @@ -880,6 +921,7 @@ def __init__( layer_id = extract_layer_index(prefix) self.layer_id = layer_id + self.prefix = prefix self.hidden_size = config.hidden_size self.n_heads = config.num_attention_heads tp_size = get_tensor_model_parallel_world_size() @@ -1032,7 +1074,10 @@ def forward( hidden_states: torch.Tensor, llama_4_scaling: torch.Tensor | None, ): - return self.mla_attn(positions, hidden_states, llama_4_scaling) + dump_tensor(f"{self.prefix}.input", hidden_states, layer_idx=self.layer_id) + out = self.mla_attn(positions, hidden_states, llama_4_scaling) + dump_tensor(f"{self.prefix}.output", out, layer_idx=self.layer_id) + return out class DeepseekV4DecoderLayer(nn.Module): @@ -1046,6 +1091,8 @@ def __init__( super().__init__() config = vllm_config.model_config.hf_config self.hidden_size = config.hidden_size + self.prefix = prefix + self.layer_id = extract_layer_index(prefix) self.rms_norm_eps = config.rms_norm_eps self.attn = DeepseekV4Attention( @@ -1149,21 +1196,30 @@ def forward( positions: torch.Tensor, input_ids: torch.Tensor | None, ) -> torch.Tensor: + dump_tensor(f"{self.prefix}.input", x, layer_idx=self.layer_id) residual = x x, post, comb = self.hc_pre( x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base ) + dump_tensor(f"{self.prefix}.attn_hc_pre", x, layer_idx=self.layer_id) x = self.attn_norm(x) + dump_tensor(f"{self.prefix}.attn_norm", x, layer_idx=self.layer_id) x = self.attn(positions, x, None) + dump_tensor(f"{self.prefix}.attn_out", x, layer_idx=self.layer_id) x = self.hc_post(x, 
residual, post, comb) + dump_tensor(f"{self.prefix}.attn_hc_post", x, layer_idx=self.layer_id) residual = x x, post, comb = self.hc_pre( x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base ) + dump_tensor(f"{self.prefix}.ffn_hc_pre", x, layer_idx=self.layer_id) x = self.ffn_norm(x) + dump_tensor(f"{self.prefix}.ffn_norm", x, layer_idx=self.layer_id) x = self.ffn(x, input_ids) + dump_tensor(f"{self.prefix}.ffn_out", x, layer_idx=self.layer_id) x = self.hc_post(x, residual, post, comb) + dump_tensor(f"{self.prefix}.output", x, layer_idx=self.layer_id) return x @@ -1257,7 +1313,9 @@ def forward( inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: hidden_states = self.embed_input_ids(input_ids) + dump_tensor("model.embed_tokens", hidden_states) hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1) + dump_tensor("model.embed_tokens_repeated", hidden_states) for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer( @@ -1270,15 +1328,34 @@ def forward( num_tokens = hidden_states.shape[0] self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1)) - hidden_states = hc_head( - hidden_states, - self.hc_head_fn, - self.hc_head_scale, - self.hc_head_base, - self.rms_norm_eps, - self.hc_eps, - ) + dump_tensor("model.pre_hc_head", hidden_states) + if _dsv4_debug_enabled("VLLM_DEBUG_DSV4_LOG_STATS"): + _dsv4_debug_log("pre_hc_head", _dsv4_tensor_summary(hidden_states)) + if _dsv4_debug_enabled("VLLM_DEBUG_DSV4_HC_HEAD_BYPASS"): + hidden_states = hidden_states[:, 0, :] + elif _dsv4_debug_enabled("VLLM_DEBUG_DSV4_HC_HEAD_BF16"): + hidden_states = hc_head_eager( + hidden_states, + self.hc_head_fn, + self.hc_head_scale, + self.hc_head_base, + self.rms_norm_eps, + self.hc_eps, + ) + else: + hidden_states = hc_head( + hidden_states, + self.hc_head_fn, + self.hc_head_scale, + self.hc_head_base, + self.rms_norm_eps, + self.hc_eps, + ) + dump_tensor("model.post_hc_head", hidden_states) + if 
_dsv4_debug_enabled("VLLM_DEBUG_DSV4_LOG_STATS"): + _dsv4_debug_log("post_hc_head", _dsv4_tensor_summary(hidden_states)) hidden_states = self.norm(hidden_states) + dump_tensor("model.final_norm", hidden_states) return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -1391,6 +1468,24 @@ def finalize_mega_moe_weights(self) -> None: layer.ffn.finalize_mega_moe_weights() +def hc_head_eager( + hidden_states: torch.Tensor, + hc_fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_norm_eps: float, + hc_eps: float, +) -> torch.Tensor: + x = hidden_states + shape, dtype = x.size(), x.dtype + x = x.flatten(1).float() + rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + rms_norm_eps) + mixes = F.linear(x, hc_fn) * rsqrt + pre = torch.sigmoid(mixes * hc_scale + hc_base) + hc_eps + y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=1) + return y.to(dtype) + + @torch.compile(backend=current_platform.simple_compile_backend) def hc_head( hidden_states: torch.Tensor, @@ -1461,7 +1556,16 @@ def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: + dump_tensor("model.lm_head_input", hidden_states) logits = self.logits_processor(self.lm_head, hidden_states) + dump_tensor("model.logits", logits) + if logits is not None and _dsv4_debug_enabled("VLLM_DEBUG_DSV4_LOG_STATS"): + topv, topi = torch.topk(logits[0].float(), k=min(5, logits.shape[-1])) + _dsv4_debug_log( + "logits", + _dsv4_tensor_summary(logits) + + f" top_ids={topi.detach().cpu().tolist()} top_vals={topv.detach().cpu().tolist()}", + ) return logits def forward( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 527733386202..d653f79c85bd 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -418,6 +418,7 @@ class RocmPlatform(Platform): "torchao", "bitsandbytes", "modelopt_fp4", + "deepseek_v4_fp8", ] @classmethod diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 
6b89f5c33203..094e63818ad3 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -295,9 +295,90 @@ def fp8_gemm_nt(*args, **kwargs): return _fp8_gemm_nt_impl(*args, disable_ue8m0_cast=not use_ue8m0, **kwargs) +def _dequant_fp8_block(x, scale): + import torch + orig_shape = x.shape + n = x.numel() + if scale.dtype == torch.float8_e8m0fnu: + s_f = torch.exp2(scale.view(torch.uint8).to(torch.float32) - 127.0) + else: + s_f = scale.to(torch.float32) + s_flat = s_f.reshape(-1) + n_blocks = s_flat.numel() + if n_blocks == 1: + return x.to(torch.float32) * s_flat[0] + if n % n_blocks != 0: + return x.to(torch.float32) * s_flat.mean() + block_size = n // n_blocks + s_expanded = s_flat.repeat_interleave(block_size) + x_flat = x.reshape(-1).to(torch.float32) + return (x_flat * s_expanded).reshape(orig_shape) + + +def _reshape_to_subscripts(tensor, subscripts, dim_map): + """Reshape tensor to match the number of subscript dimensions. + Uses dim_map to infer known dims; the single unknown dim is derived + from tensor.numel() / product(known_dims). + """ + target_ndim = len(subscripts) + if tensor.ndim == target_ndim: + return tensor + known = [dim_map.get(c) for c in subscripts] + known_product = 1 + unknown_count = 0 + for v in known: + if v is not None: + known_product *= v + else: + unknown_count += 1 + if unknown_count == 0: + return tensor.reshape(known) + if unknown_count == 1: + unknown_val = tensor.numel() // known_product + shape = [v if v is not None else unknown_val for v in known] + return tensor.reshape(shape) + # Cannot infer — return as-is and let einsum raise a descriptive error + return tensor + + +def _rocm_fp8_einsum_fallback(equation, a_tuple, b_tuple, out, recipe=None): + import torch + a, a_scale = a_tuple + b, b_scale = b_tuple + a_f = _dequant_fp8_block(a, a_scale) + b_f = _dequant_fp8_block(b, b_scale) + + # Parse equation "...,...->..." and fix ndim mismatches. + # Weights may be stored 2-D (e.g. 
[h*d, r]) while the equation + # expects 3-D (e.g. [h, d, r]). Build a dim_map from whichever + # operand already has the right number of dims, then reshape the other. + lhs, _ = equation.split('->') + a_subs, b_subs = lhs.split(',') + dim_map = {} + if a_f.ndim == len(a_subs): + for c, s in zip(a_subs, a_f.shape): + dim_map[c] = s + if b_f.ndim == len(b_subs): + for c, s in zip(b_subs, b_f.shape): + dim_map[c] = s + + a_f = _reshape_to_subscripts(a_f, a_subs, dim_map) + # Rebuild dim_map after a_f may have been reshaped + if a_f.ndim == len(a_subs): + for c, s in zip(a_subs, a_f.shape): + dim_map[c] = s + b_f = _reshape_to_subscripts(b_f, b_subs, dim_map) + + result = torch.einsum(equation, a_f, b_f) + out.copy_(result.to(out.dtype)) + + def fp8_einsum(*args, **kwargs): _lazy_init() if _fp8_einsum_impl is None: + from vllm.platforms import current_platform + if current_platform.is_rocm(): + return _rocm_fp8_einsum_fallback(*args, **kwargs) return _missing(*args, **kwargs) return _fp8_einsum_impl(*args, **kwargs) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index 0254a46752c6..fd84abfb4eb7 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton # MXFP4: 32 elements per block, packed 2 nibbles per byte, ue8m0 block scale. 
@@ -101,6 +102,7 @@ def _fused_indexer_q_rope_quant_kernel( index_weights_head_scale, index_weights_out_ptr, index_weights_out_stride, + IS_FNUZ: tl.constexpr = False, ): # Layout matches the unfused reference (DeepseekV4ScalingRotaryEmbedding # + per_token_group_quant_fp8): GPT-J interleaved RoPE applied to the @@ -151,16 +153,16 @@ def _fused_indexer_q_rope_quant_kernel( if INDEX_Q_NOPE_DIM > 0: tl.store( fp8_base_ptr + nope_offset, - tl.div_rn(x_nope, index_q_scale).to(tl.float8e4nv), + tl.div_rn(x_nope, index_q_scale).to(tl.float8e4b8 if IS_FNUZ else tl.float8e4nv), ) fp8_rot_base = fp8_base_ptr + INDEX_Q_NOPE_DIM tl.store( fp8_rot_base + half_offset * 2, - tl.div_rn(r_even, index_q_scale).to(tl.float8e4nv), + tl.div_rn(r_even, index_q_scale).to(tl.float8e4b8 if IS_FNUZ else tl.float8e4nv), ) tl.store( fp8_rot_base + half_offset * 2 + 1, - tl.div_rn(r_odd, index_q_scale).to(tl.float8e4nv), + tl.div_rn(r_odd, index_q_scale).to(tl.float8e4b8 if IS_FNUZ else tl.float8e4nv), ) # FP8 weight-fold contract: @@ -211,6 +213,7 @@ def _fused_indexer_q_rope_mxfp4_kernel( index_weights_head_scale, index_weights_out_ptr, index_weights_out_stride, + IS_FNUZ: tl.constexpr = False, ): INDEX_Q_ROT_DIM: tl.constexpr = 2 * INDEX_Q_HALF_ROT_DIM INDEX_Q_NOPE_DIM: tl.constexpr = INDEX_Q_HEAD_DIM - INDEX_Q_ROT_DIM @@ -391,7 +394,7 @@ def fused_indexer_q_rope_quant( index_q_scale.view(torch.int32).squeeze(-1), ), index_weights_out - index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn) + index_q_fp8 = torch.empty_like(index_q, dtype=current_platform.fp8_dtype()) _fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)]( positions, index_q, @@ -410,6 +413,7 @@ def fused_indexer_q_rope_quant( index_weights_head_scale, index_weights_out, index_weights_out.stride(0), + current_platform.fp8_dtype() == torch.float8_e4m3fnuz, num_warps=1, # TODO: Tune this ) return index_q_fp8, index_weights_out diff --git 
a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py index 97c9538889a1..ea94b137981b 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py @@ -9,6 +9,7 @@ import torch +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton @@ -35,6 +36,7 @@ def _fused_inv_rope_fp8_quant_per_head( ROPE_START: tl.constexpr, HALF_ROPE: tl.constexpr, TMA_ALIGNED_SCALES: tl.constexpr, + IS_FNUZ: tl.constexpr = False, ): # int64: stride multiply overflows int32 past num_tokens=32768 (IMA). pid_token = tl.program_id(0).to(tl.int64) @@ -103,7 +105,7 @@ def _fused_inv_rope_fp8_quant_per_head( ), (HEAD_DIM,), ) - x_quant = tl.clamp(x / scales_exp, -fp8_max, fp8_max).to(tl.float8e4nv) + x_quant = tl.clamp(x / scales_exp, -fp8_max, fp8_max).to(tl.float8e4b8 if IS_FNUZ else tl.float8e4nv) fp8_base = ( fp8_ptr @@ -177,7 +179,7 @@ def fused_inv_rope_fp8_quant( num_scale_blocks = d // quant_group_size chunks_per_head = head_dim // quant_group_size - fp8_dtype = torch.float8_e4m3fn + fp8_dtype = current_platform.fp8_dtype() fp8_max = torch.finfo(fp8_dtype).max fp8_buf = torch.empty( @@ -223,8 +225,8 @@ def fused_inv_rope_fp8_quant( ROPE_START=nope_dim % quant_group_size, HALF_ROPE=rope_dim // 2, TMA_ALIGNED_SCALES=tma_aligned_scales, + IS_FNUZ=current_platform.fp8_dtype() == torch.float8_e4m3fnuz, num_stages=1, - launch_pdl=False, ) grid = (tma_aligned_T, n_groups * heads_per_group) diff --git a/vllm/v1/attention/ops/flashmla.py b/vllm/v1/attention/ops/flashmla.py index df04f5bf2289..c06ae0a6ab63 100644 --- a/vllm/v1/attention/ops/flashmla.py +++ b/vllm/v1/attention/ops/flashmla.py @@ -101,9 +101,24 @@ class FlashMLASchedMeta: # type: ignore[no-redef] flash_attn_varlen_func = _raise_flashmla_unavailable # type: ignore[assignment] flash_attn_varlen_kvpacked_func = _raise_flashmla_unavailable 
# type: ignore[assignment] flash_attn_varlen_qkvpacked_func = _raise_flashmla_unavailable # type: ignore[assignment] - flash_mla_sparse_fwd = _raise_flashmla_unavailable # type: ignore[assignment] - flash_mla_with_kvcache = _raise_flashmla_unavailable # type: ignore[assignment] - get_mla_metadata = _raise_flashmla_unavailable # type: ignore[assignment] + + if current_platform.is_rocm(): + # ROCm DeepSeek-V4 sparse FlashMLA fallbacks. These mirror the + # fp8_ds_mla block cache layout used by quantize_and_insert_k_cache. + from vllm.v1.attention.ops.rocm_flash_mla_sparse import ( + flash_mla_sparse_fwd_rocm as flash_mla_sparse_fwd, + ) + from vllm.v1.attention.ops.rocm_flash_mla_sparse import ( + flash_mla_with_kvcache_rocm as flash_mla_with_kvcache, + ) + from vllm.v1.attention.ops.rocm_flash_mla_sparse import ( + get_mla_metadata_rocm as get_mla_metadata, + ) + + else: + flash_mla_sparse_fwd = _raise_flashmla_unavailable # type: ignore[assignment] + flash_mla_with_kvcache = _raise_flashmla_unavailable # type: ignore[assignment] + get_mla_metadata = _raise_flashmla_unavailable # type: ignore[assignment] def get_mla_metadata_dense_fp8( diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 81cc489db0d8..57eacc4a4919 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -243,9 +243,11 @@ def fp8_paged_mqa_logits_torch( device=q.device, dtype=torch.float32, ) - context_lens = context_lens.tolist() + # context_lens can be 1D (B,) or 2D (B, next_n) for MTP decode. + # The last entry per row is always the full context length L_b. 
+ context_lens = context_lens.reshape(batch_size, -1)[:, -1].tolist() for i in range(batch_size): - context_len = context_lens[i] + context_len = int(context_lens[i]) q_offsets = torch.arange(context_len - next_n, context_len, device="cuda") weight_slice = ( weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous() @@ -322,15 +324,10 @@ def rocm_fp8_paged_mqa_logits( Logits tensor of shape [B * next_n, max_model_len], dtype `torch.float32`. """ - from vllm._aiter_ops import rocm_aiter_ops - - aiter_paged_mqa_logits_module = None - if rocm_aiter_ops.is_enabled(): - aiter_paged_mqa_logits_module = paged_mqa_logits_module() - - if aiter_paged_mqa_logits_module is not None: + _paged_mod = paged_mqa_logits_module() + if _paged_mod is not None: deepgemm_fp8_paged_mqa_logits_stage1 = ( - aiter_paged_mqa_logits_module.deepgemm_fp8_paged_mqa_logits_stage1 + _paged_mod.deepgemm_fp8_paged_mqa_logits_stage1 ) batch_size, next_n, heads, _ = q_fp8.shape out_qk = torch.full( @@ -384,21 +381,30 @@ def fp8_mqa_logits_torch( Logits tensor of shape [M, N], dtype `torch.float32`. """ k_fp8, scale = kv + M = q.shape[0] seq_len_kv = k_fp8.shape[0] - k = k_fp8.to(torch.bfloat16) - q = q.to(torch.bfloat16) + # Process in query-token chunks to avoid OOM from [H, M, N] intermediate. + # With H=18, M=10240, N=10240 the full score tensor is ~7 GB; chunk=64 → ~45 MB. 
+ k_f = k_fp8.to(torch.float32) # [N, D] (dequant via scale below) + q_f = q.to(torch.float32) # [M, H, D] + scale_f = scale.reshape(-1) # [N] per-key scale - mask_lo = ( - torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None] - ) - mask_hi = ( - torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None] - ) - mask = mask_lo & mask_hi + device = q.device + arange_kv = torch.arange(seq_len_kv, device=device) - score = torch.einsum("mhd,nd->hmn", q, k).float() * scale - logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float("-inf")) + logits = torch.full([M, seq_len_kv], float("-inf"), device=device, dtype=torch.float32) + + CHUNK_M = 64 # tune down if still OOM + for m0 in range(0, M, CHUNK_M): + m1 = min(m0 + CHUNK_M, M) + q_c = q_f[m0:m1] # [mc, H, D] + score_c = torch.einsum("mhd,nd->hmn", q_c, k_f) * scale_f # [H, mc, N] + w_c = weights[m0:m1].unsqueeze(-1).transpose(0, 1) # [H, mc, 1] + logits_c = (score_c.relu() * w_c).sum(dim=0) # [mc, N] + + mask_lo = arange_kv[None, :] >= cu_seqlen_ks[m0:m1, None] + mask_hi = arange_kv[None, :] < cu_seqlen_ke[m0:m1, None] + logits[m0:m1] = logits_c.masked_fill(~(mask_lo & mask_hi), float("-inf")) return logits @@ -445,22 +451,36 @@ def rocm_fp8_mqa_logits( Logits tensor of shape [M, N], dtype `torch.float32`. 
""" - # TODO(ganyi): Temporarily workaround, will remove the module check and reference - # path after aiter merge this kernel into main - from vllm._aiter_ops import rocm_aiter_ops - - aiter_mqa_logits_module = None - if rocm_aiter_ops.is_enabled(): - aiter_mqa_logits_module = mqa_logits_module() - - if aiter_mqa_logits_module is not None: - fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits + _mqa_mod = mqa_logits_module() + if _mqa_mod is not None: + fp8_mqa_logits = _mqa_mod.fp8_mqa_logits k_fp8, scale = kv return fp8_mqa_logits(q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke) else: return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) +def _top_k_per_row_torch(logits: torch.Tensor, + topk_indices_buf: torch.Tensor, + topk_tokens: int, + num_rows: int) -> None: + """Pure PyTorch top-k per row. + + logits already has -inf for invalid / masked positions, so a plain + torch.topk naturally selects valid positions first. + Writes results in-place into topk_indices_buf[:num_rows, :topk_tokens]. + """ + actual_k = min(topk_tokens, logits.shape[1]) + rows = min(num_rows, topk_indices_buf.shape[0]) + _, idx = torch.topk(logits[:rows], actual_k, dim=1, + largest=True, sorted=False) + topk_indices_buf[:rows, :actual_k] = idx + if actual_k < topk_tokens: + # pad remaining slots with the first valid index (score = -inf → no effect) + topk_indices_buf[:rows, actual_k:] = idx[:, :1].expand( + rows, topk_tokens - actual_k) + + def rocm_aiter_sparse_attn_indexer_fake( hidden_states: torch.Tensor, k_cache_prefix: LayerNameType, @@ -540,13 +560,20 @@ def rocm_aiter_sparse_attn_indexer( num_tokens = slot_mapping.shape[0] k = k[:num_tokens] - ops.indexer_k_quant_and_cache( - k, - kv_cache, - slot_mapping, - quant_block_size, - scale_fmt, - ) + # When skip_k_cache_insert=True the compressor already wrote quantized K + # to kv_cache. 
The caller gathers the 544-byte quantized entry to satisfy + # the non-None k requirement, but we must NOT re-quantize it via + # indexer_k_quant_and_cache (head_dim=544 is not divisible by + # quant_block_size=64). Detect this by comparing k's last dim to head_dim. + k_already_cached = (k.shape[-1] != head_dim) + if not k_already_cached: + ops.indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) topk_indices_buffer[: hidden_states.shape[0]] = -1 if has_prefill: @@ -580,20 +607,10 @@ def rocm_aiter_sparse_attn_indexer( chunk.cu_seqlen_ke, ) num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[ chunk.token_start : chunk.token_end, :topk_tokens ] - torch.ops._C.top_k_per_row_prefill( - logits, - chunk.cu_seqlen_ks, - chunk.cu_seqlen_ke, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + _top_k_per_row_torch(logits, topk_indices, topk_tokens, num_rows) if has_decode: decode_metadata = layer_attn_metadata.decode @@ -631,18 +648,8 @@ def rocm_aiter_sparse_attn_indexer( ) num_rows = logits.shape[0] - assert topk_tokens == 2048, "top_k_per_row assumes size 2048" topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens] - torch.ops._C.top_k_per_row_decode( - logits, - next_n, - decode_metadata.seq_lens, - topk_indices, - num_rows, - logits.stride(0), - logits.stride(1), - topk_tokens, - ) + _top_k_per_row_torch(logits, topk_indices, topk_tokens, num_rows) if decode_metadata.requires_padding: # if padded, we need to unpack diff --git a/vllm/v1/attention/ops/rocm_flash_mla_sparse.py b/vllm/v1/attention/ops/rocm_flash_mla_sparse.py new file mode 100644 index 000000000000..c9255c3806cd --- /dev/null +++ b/vllm/v1/attention/ops/rocm_flash_mla_sparse.py @@ -0,0 +1,677 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""ROCm fallback for DeepSeek-V4's 
FlashMLA sparse attention kernels. + +The official FlashMLA kernels (``flash_mla_sparse_fwd`` for prefill and the +V4-extended ``flash_mla_with_kvcache`` for decode) are NVIDIA-only — they live +in the ``vllm._flashmla_C`` extension which is not built on ROCm. The wrapper in +``vllm/v1/attention/ops/flashmla.py`` raises ``RuntimeError`` for both calls on +non-CUDA platforms, which crashes DeepSeek-V4 inference at the first generation +step. + +This module provides ROCm-friendly equivalents: + +* ``flash_mla_sparse_fwd_rocm`` — sparse attention over a *bf16* KV pool. The + V4 prefill path pre-dequantizes the FP8 cache via + :func:`vllm.v1.attention.ops.deepseek_v4_ops.dequantize_and_gather_k_cache` + (Triton, works on ROCm), then feeds bf16 ``kv`` into FlashMLA. We can run the + same sparse softmax+gemm in chunked online-softmax form on top of the + dequantized KV without needing the FP8-aware kernel. + +* ``flash_mla_with_kvcache_rocm`` — decode path. Here FlashMLA reads the + FP8 ``swa_cache`` (and optionally a global compressed ``extra_k_cache``) + directly via ``is_fp8_kvcache=True``. We dequantize the requested slots on + the fly with a small Triton kernel (mirroring + ``_dequantize_and_gather_k_kernel`` but indexed by arbitrary global slot ids + instead of a block table), then run the same chunked sparse attention. + +* ``get_mla_metadata_rocm`` — returns an empty ``FlashMLASchedMeta`` stub so + the V4 SWA metadata builder can populate ``tile_sched_*`` fields without + crashing. The metadata is unused by our fallback path. + +Both attention paths use *online softmax* with a bounded ``chunk_topk`` over +the candidate axis so peak intermediate memory stays manageable even with +many query tokens × thousands of selected positions. 
+ +Numerics notes +-------------- +* The softmax includes the per-head ``attn_sink`` logit as an extra column + whose value is dropped before the ``attn @ V`` reduction (matches FlashMLA + semantics: sink mass affects the partition function only). +* Invalid ``indices == -1`` entries are masked with ``-inf`` so they never + contribute, regardless of what we (safely) dequantize at slot 0. +* Rows where every candidate is invalid AND ``attn_sink == -inf`` produce a + zero output (we trap the all-``-inf`` case to avoid NaNs from ``exp(0)/0``). +""" +from __future__ import annotations + +import os + +import torch + +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON, tl, triton + +logger = init_logger(__name__) + +# --------------------------------------------------------------------------- +# Cache layout constants — must mirror +# vllm/v1/attention/ops/deepseek_v4_ops/cache_utils.py. +# --------------------------------------------------------------------------- +_FP8_DIM = 448 +_BF16_DIM = 64 +_SCALE_DIM = 8 +_QUANT_BLOCK_SIZE = 64 +_TOKEN_DATA_SIZE = _FP8_DIM + _BF16_DIM * 2 # 576 +_HEAD_DIM = _FP8_DIM + _BF16_DIM # 512 +_N_QUANT_BLOCKS = 7 # 7 real (448 // 64), 1 pad slot at index 7 + +# Chunk size for online-softmax over the candidate axis. 128 keeps memory +# small (~64 MiB for T_q=512, head_dim=512, bf16) while letting the matmul +# inside torch see enough work to be efficient. 
+_DEFAULT_CHUNK_TOPK = 128 + + +def _env_enabled(name: str) -> bool: + value = os.environ.get(name, "") + return value not in ("", "0", "false", "False", "no", "No") + + +def _batched_query_key(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + if _env_enabled("VLLM_DSV4_ROCM_SPARSE_MLA_BMM"): + return torch.bmm(q, k.transpose(1, 2).contiguous()) + return torch.einsum("thd,tcd->thc", q, k) + + +def _batched_scores_value(scores: torch.Tensor, values: torch.Tensor) -> torch.Tensor: + if _env_enabled("VLLM_DSV4_ROCM_SPARSE_MLA_BMM"): + return torch.bmm(scores, values) + return torch.einsum("thc,tcd->thd", scores, values) + + +# --------------------------------------------------------------------------- +# FP8 slot dequantization (decode path). +# --------------------------------------------------------------------------- +if HAS_TRITON and current_platform.is_cuda_alike(): + + @triton.jit + def _gather_dequant_slots_kernel( + out_ptr, # (N, head_dim) bf16 + out_stride_n, + indices_ptr, # (N,) int32, -1 = invalid (still safely dequant slot 0) + k_cache_ptr, # uint8 byte buffer + block_stride, # bytes per block + cache_block_size: tl.constexpr, + fp8_dim: tl.constexpr, + bf16_dim: tl.constexpr, + scale_dim: tl.constexpr, + quant_block: tl.constexpr, + token_data_size: tl.constexpr, + head_dim: tl.constexpr, + n_quant_blocks: tl.constexpr, + N, + ): + pid = tl.program_id(0) + if pid >= N: + return + + raw_slot = tl.load(indices_ptr + pid) + # Always dequant slot >= 0 to keep the kernel branch-free; the + # caller masks invalid indices in the attention softmax. 
+ slot = tl.maximum(raw_slot, 0) + + out_row_ptr = out_ptr + pid * out_stride_n + + block_idx = (slot // cache_block_size).to(tl.int64) + pos_in_block = slot % cache_block_size + + cache_block_ptr = k_cache_ptr + block_idx * block_stride + token_data_ptr = cache_block_ptr + pos_in_block * token_data_size + token_scale_ptr = ( + cache_block_ptr + + cache_block_size * token_data_size + + pos_in_block * scale_dim + ) + token_fp8_ptr = token_data_ptr + token_bf16_ptr = token_data_ptr + fp8_dim + + # Dequantize the 448 FP8 dims in 7 blocks of 64. + for qblock_idx in tl.static_range(n_quant_blocks): + qblock_start = qblock_idx * quant_block + if qblock_start < fp8_dim: + offsets = qblock_start + tl.arange(0, quant_block) + mask = offsets < fp8_dim + x_uint8 = tl.load(token_fp8_ptr + offsets, mask=mask, other=0) + x_fp8 = x_uint8.to(tl.float8e4nv, bitcast=True) + x_float = x_fp8.to(tl.float32) + encoded_scale = tl.load(token_scale_ptr + qblock_idx) + exponent = encoded_scale.to(tl.float32) - 127.0 + scale = tl.exp2(exponent) + x_dequant = x_float * scale + tl.store( + out_row_ptr + offsets, + x_dequant.to(tl.bfloat16), + mask=mask, + ) + + # Copy the trailing 64 bf16 dims unchanged. + bf16_out_ptr = out_row_ptr + fp8_dim + bf16_cache_bf16_ptr = token_bf16_ptr.to(tl.pointer_type(tl.bfloat16)) + for j in tl.static_range(bf16_dim // 16): + chunk_offsets = j * 16 + tl.arange(0, 16) + bf16_vals = tl.load(bf16_cache_bf16_ptr + chunk_offsets) + tl.store(bf16_out_ptr + chunk_offsets, bf16_vals) +else: + _gather_dequant_slots_kernel = None # type: ignore[assignment] + + +def _gather_dequant_slots_triton( + indices: torch.Tensor, # (N,) int32 — global slot ids, -1 allowed + k_cache: torch.Tensor, # uint8 (num_blocks, ...) 
— byte buffer + out: torch.Tensor, # (N, head_dim) bf16 output buffer +) -> None: + """Triton gather + UE8M0 FP8 dequant for arbitrary global slot ids.""" + assert _gather_dequant_slots_kernel is not None + assert k_cache.dtype == torch.uint8, ( + f"k_cache must be uint8 byte buffer, got {k_cache.dtype}" + ) + assert out.dtype == torch.bfloat16 + assert out.shape == (indices.shape[0], _HEAD_DIM) + assert indices.is_contiguous() + assert out.is_contiguous() + + block_stride = k_cache.stride(0) + n = indices.shape[0] + if n == 0: + return + + # Block size in *tokens*. The cache is shaped (num_blocks, block_size, 584) + # in the metadata, so dim 1 is the token count per block. + if k_cache.dim() >= 2: + cache_block_size = k_cache.shape[1] + else: + # 1D byte buffer; assume 64 (the default DeepSeek block size). + cache_block_size = 64 + + _gather_dequant_slots_kernel[(n,)]( + out, + out.stride(0), + indices, + k_cache, + block_stride, + cache_block_size=cache_block_size, + fp8_dim=_FP8_DIM, + bf16_dim=_BF16_DIM, + scale_dim=_SCALE_DIM, + quant_block=_QUANT_BLOCK_SIZE, + token_data_size=_TOKEN_DATA_SIZE, + head_dim=_HEAD_DIM, + n_quant_blocks=_N_QUANT_BLOCKS, + N=n, + ) + + +def _gather_dequant_slots_torch( + indices: torch.Tensor, + k_cache: torch.Tensor, + out: torch.Tensor, +) -> None: + """Pure-torch reference for ``_gather_dequant_slots_triton``. + + Slow but correct — useful for environments without a Triton runtime and + for unit-style sanity checks. Implements the same UE8M0 FP8 dequant + bf16 + copy as the Triton kernel. 
+ """ + assert k_cache.dtype == torch.uint8 + assert out.dtype == torch.bfloat16 + n = indices.shape[0] + if n == 0: + return + + block_stride = k_cache.stride(0) + cache_block_size = k_cache.shape[1] if k_cache.dim() >= 2 else 64 + flat_cache = k_cache.view(torch.uint8).contiguous().view(-1) + + safe = indices.clamp(min=0).to(torch.int64) + block_idx = safe // cache_block_size + pos_in_block = safe % cache_block_size + + # Per-token base byte offsets for the data and scale regions. + base = block_idx * block_stride + data_base = base + pos_in_block * _TOKEN_DATA_SIZE # (N,) + scale_base = ( + base + cache_block_size * _TOKEN_DATA_SIZE + pos_in_block * _SCALE_DIM + ) # (N,) + + # ---- FP8 NoPE (448 dims) ---- + fp8_offsets = data_base.unsqueeze(-1) + torch.arange( + _FP8_DIM, device=indices.device, dtype=torch.int64 + ) + fp8_bytes = flat_cache[fp8_offsets.flatten()].view(n, _FP8_DIM) + fp8_vals = fp8_bytes.view(torch.float8_e4m3fn).to(torch.float32) + + # 7 UE8M0 scales, 1 byte each. + scale_offsets = scale_base.unsqueeze(-1) + torch.arange( + _N_QUANT_BLOCKS, device=indices.device, dtype=torch.int64 + ) + scale_bytes = flat_cache[scale_offsets.flatten()].view(n, _N_QUANT_BLOCKS) + exponents = scale_bytes.to(torch.float32) - 127.0 + scales = torch.exp2(exponents) # (N, 7) + # Repeat each scale across its 64-element block. 
+ scales_per_dim = scales.repeat_interleave(_QUANT_BLOCK_SIZE, dim=-1) + nope = (fp8_vals * scales_per_dim).to(torch.bfloat16) + + # ---- BF16 RoPE (64 dims) ---- + bf16_byte_offsets = ( + data_base + _FP8_DIM + ).unsqueeze(-1) + torch.arange( + _BF16_DIM * 2, device=indices.device, dtype=torch.int64 + ) + bf16_bytes = flat_cache[bf16_byte_offsets.flatten()].view(n, _BF16_DIM * 2) + rope = bf16_bytes.view(torch.bfloat16).view(n, _BF16_DIM) + + out.copy_(torch.cat([nope, rope], dim=-1)) + + +def _gather_dequant_slots( + indices: torch.Tensor, + k_cache: torch.Tensor, + out: torch.Tensor, +) -> None: + """Dispatch to Triton when available, otherwise pure torch.""" + if _gather_dequant_slots_kernel is not None and indices.is_cuda: + _gather_dequant_slots_triton(indices, k_cache, out) + else: + _gather_dequant_slots_torch(indices, k_cache, out) + + +# --------------------------------------------------------------------------- +# Sparse attention with online softmax (chunked over the candidate axis). +# --------------------------------------------------------------------------- +def _online_softmax_init( + t_q: int, + num_heads: int, + head_dim_v: int, + attn_sink: torch.Tensor | None, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Seed the (m, l, O) running state with the per-head ``attn_sink`` logit. 
+ + The sink contributes mass exp(sink) to the partition function but no V + contribution, so we initialize: + m = sink (or -inf if no sink) + l = exp(sink - m) = 1 (or 0 if sink == -inf) + O = 0 + """ + if attn_sink is not None: + sink = attn_sink.to(torch.float32).view(1, num_heads).expand(t_q, num_heads) + m = sink.contiguous() + else: + m = torch.full((t_q, num_heads), float("-inf"), dtype=torch.float32, device=device) + + finite_sink = torch.isfinite(m) + l = torch.where(finite_sink, torch.ones_like(m), torch.zeros_like(m)) + O = torch.zeros((t_q, num_heads, head_dim_v), dtype=torch.float32, device=device) + return m, l, O + + +def _online_softmax_update( + m: torch.Tensor, # (T_q, H) running max + l: torch.Tensor, # (T_q, H) running denominator + O: torch.Tensor, # (T_q, H, head_dim_v) running output (fp32) + scores: torch.Tensor, # (T_q, H, c) new logits (fp32, -inf for invalid) + V_chunk: torch.Tensor, # (T_q, c, head_dim_v) bf16/fp32 V values +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """One online-softmax step. + + Numerical care: when a row's running max stays ``-inf`` (no candidate yet + finite) we keep ``O = 0`` and ``l = 0`` and just track the new max so the + next chunk can rebase from it. + """ + chunk_max = scores.amax(dim=-1) # (T_q, H) + new_m = torch.maximum(m, chunk_max) # (T_q, H) + + # Avoid -inf - -inf = nan when both old and new max are still -inf. + finite_old = torch.isfinite(m) & torch.isfinite(new_m) # (T_q, H) + scale_old = torch.where( + finite_old, + torch.exp(m - torch.where(finite_old, new_m, m)), + torch.zeros_like(m), + ) # (T_q, H) + + # Per-element diff: -inf - finite = -inf; finite - -inf would blow up so + # only subtract when new_m is finite. Keep the 2D mask for building + # ``safe_new_m`` (same shape as ``new_m``); unsqueeze separately for the + # 3D mask used against ``scores``. 
+ finite_new_2d = torch.isfinite(new_m) # (T_q, H) + safe_new_m = torch.where( + finite_new_2d, new_m, torch.zeros_like(new_m) + ).unsqueeze(-1) # (T_q, H, 1) + finite_new_3d = finite_new_2d.unsqueeze(-1) # (T_q, H, 1) + e_scores = torch.where( + finite_new_3d & torch.isfinite(scores), + torch.exp(scores - safe_new_m), + torch.zeros_like(scores), + ) # (T_q, H, c) + + l_new = l * scale_old + e_scores.sum(dim=-1) # (T_q, H) + # O_new = scale_old * O + e_scores @ V_chunk + O_new = O * scale_old.unsqueeze(-1) + _batched_scores_value( + e_scores, V_chunk.to(torch.float32) + ) # (T_q, H, head_dim_v) + return new_m, l_new, O_new + + +def _sparse_attn_chunked( + q: torch.Tensor, # (T_q, H, head_dim) bf16/fp32 + indices: torch.Tensor, # (T_q, max_topk) int32, -1 for invalid + K_provider, # callable: (idx_chunk: (T_q, c) int32) -> (T_q, c, head_dim) bf16 + sm_scale: float, + attn_sink: torch.Tensor | None, + head_dim_v: int, + chunk_topk: int = _DEFAULT_CHUNK_TOPK, +) -> torch.Tensor: + """Generic sparse attention with online softmax. + + ``K_provider`` is a callable that returns the dequantized K (bf16) for a + chunk of candidate indices. This lets the same attention loop drive both + the prefill path (already-dequantized bf16 KV pool, simple ``K_full[idx]`` + gather) and the decode path (per-slot Triton FP8 dequant). 
+ """ + t_q, num_heads, _ = q.shape + max_topk = indices.shape[-1] + device = q.device + + m, l, O = _online_softmax_init(t_q, num_heads, head_dim_v, attn_sink, device) + q_f = q.to(torch.float32) + + for cs in range(0, max_topk, chunk_topk): + ce = min(cs + chunk_topk, max_topk) + idx_chunk = indices[:, cs:ce].contiguous() # (T_q, c) + valid = idx_chunk >= 0 # (T_q, c) + if not valid.any(): + continue + + K_chunk = K_provider(idx_chunk) # (T_q, c, head_dim) bf16 + + scores = _batched_query_key( + q_f, K_chunk.to(torch.float32) + ) * sm_scale # (T_q, H, c) + scores = scores.masked_fill(~valid.unsqueeze(1), float("-inf")) + + V_chunk = K_chunk[..., :head_dim_v] + m, l, O = _online_softmax_update(m, l, O, scores, V_chunk) + + # Finalize: divide by total partition function. + finite_l = l > 0 + out_f = torch.where( + finite_l.unsqueeze(-1), O / l.clamp_min(1e-30).unsqueeze(-1), torch.zeros_like(O) + ) + return out_f + + +# --------------------------------------------------------------------------- +# Prefill: K is already dequantized to bf16 by the caller. +# --------------------------------------------------------------------------- +def flash_mla_sparse_fwd_rocm( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + attn_sink: torch.Tensor | None = None, + topk_length: torch.Tensor | None = None, + out: torch.Tensor | None = None, + head_dim_v: int | None = None, + chunk_topk: int = _DEFAULT_CHUNK_TOPK, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """ROCm fallback for ``flash_mla_sparse_fwd``. + + Args: + q: ``(s_q, h_q, d)`` bf16 query. + kv: ``(s_kv, 1, d)`` bf16 KV pool (already dequantized + gathered). + indices: ``(s_q, 1, topk)`` int32 with -1 sentinel for invalid slots. + sm_scale: softmax scale factor. + attn_sink: optional ``(h_q,)`` per-head sink logit (fp32). + topk_length: kept for API parity; we use the -1 sentinel for masking. + out: optional ``(s_q, h_q, d_v_or_d)`` bf16 output buffer. 
+ head_dim_v: V head dim (default = ``out.shape[-1]`` or ``d``). + + Returns ``(out, max_logits, lse)`` matching the upstream signature; the + optional aux outputs are ``None`` since the caller only reads ``out``. + """ + assert kv.dim() == 3 and kv.shape[1] == 1, ( + f"kv must be (s_kv, 1, d), got {kv.shape}" + ) + assert indices.dim() == 3 and indices.shape[1] == 1, ( + f"indices must be (s_q, 1, topk), got {indices.shape}" + ) + + t_q, num_heads, head_dim = q.shape + if head_dim_v is None: + head_dim_v = out.shape[-1] if out is not None else head_dim + head_dim_v = min(head_dim_v, head_dim) + + K = kv.squeeze(1) # (N_kv, d) + idx_2d = indices.squeeze(1) # (T_q, max_topk) + if topk_length is not None: + lens = topk_length.to(torch.long).view(-1, 1) + arange = torch.arange(idx_2d.shape[-1], device=idx_2d.device).view(1, -1) + idx_2d = idx_2d.masked_fill(arange >= lens, -1) + + def K_provider(idx_chunk: torch.Tensor) -> torch.Tensor: + safe = idx_chunk.clamp(min=0).to(torch.int64) + return K[safe] + + out_f = _sparse_attn_chunked( + q=q, + indices=idx_2d, + K_provider=K_provider, + sm_scale=sm_scale, + attn_sink=attn_sink, + head_dim_v=head_dim_v, + chunk_topk=chunk_topk, + ) + + if out is None: + out = torch.empty(t_q, num_heads, head_dim_v, dtype=q.dtype, device=q.device) + out[..., :head_dim_v].copy_(out_f.to(out.dtype)) + if out.shape[-1] > head_dim_v: + out[..., head_dim_v:].zero_() + return out, None, None + + +# --------------------------------------------------------------------------- +# Decode: K cache is FP8-packed; dequantize requested slots on the fly. 
def _gather_chunk_to_bf16(
    idx_chunk: torch.Tensor,  # (T_q, c) int32
    k_cache: torch.Tensor,  # uint8 byte buffer
) -> torch.Tensor:
    """Dequantize ``(T_q, c)`` cache slots into a ``(T_q, c, head_dim)`` bf16
    tensor.

    Args:
        idx_chunk: 2-D tensor of cache slot ids. Callers in this file mark
            masked-out slots with ``-1``.
        k_cache: packed KV cache handed through to ``_gather_dequant_slots``.

    Returns:
        A bf16 tensor of shape ``(T_q, c, _HEAD_DIM)``.
    """
    t_q, c = idx_chunk.shape
    flat_idx = idx_chunk.reshape(-1).to(torch.int32).contiguous()
    # Zero-initialised rather than ``torch.empty``: the caller masks the
    # *scores* of ``-1`` slots but still feeds their value rows into the
    # P @ V product, where ``0 * NaN`` would be NaN.
    # NOTE(review): assumes ``_gather_dequant_slots`` may leave rows for
    # negative slot ids untouched -- confirm against the helper.
    flat_out = torch.zeros(
        (flat_idx.shape[0], _HEAD_DIM),
        dtype=torch.bfloat16,
        device=idx_chunk.device,
    )
    _gather_dequant_slots(flat_idx, k_cache, flat_out)
    return flat_out.view(t_q, c, _HEAD_DIM)


def _mask_indices_beyond_length(
    idx: torch.Tensor,  # (batch, topk)
    lengths: torch.Tensor | None,  # one valid-entry count per batch row
) -> torch.Tensor:
    """Replace entries at column ``>= lengths[b]`` with ``-1``.

    Returns ``idx`` unchanged when ``lengths`` is ``None``.
    """
    if lengths is None:
        return idx
    lens = lengths.to(torch.long).view(-1, 1)
    columns = torch.arange(idx.shape[-1], device=idx.device).view(1, -1)
    return idx.masked_fill(columns >= lens, -1)


def flash_mla_with_kvcache_rocm(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor | None = None,
    head_dim_v: int = _HEAD_DIM,
    tile_scheduler_metadata: object | None = None,
    cache_seqlens: torch.Tensor | None = None,
    is_fp8_kvcache: bool = True,
    indices: torch.Tensor | None = None,
    topk_length: torch.Tensor | None = None,
    softmax_scale: float | None = None,
    attn_sink: torch.Tensor | None = None,
    extra_k_cache: torch.Tensor | None = None,
    extra_indices_in_kvcache: torch.Tensor | None = None,
    extra_topk_length: torch.Tensor | None = None,
    out: torch.Tensor | None = None,
    causal: bool = False,
    chunk_topk: int = _DEFAULT_CHUNK_TOPK,
    **_unused_kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """ROCm fallback for V4-extended ``flash_mla_with_kvcache``.

    Decodes one query token per batch position by sparse attention over up to
    two FP8-packed caches:

    * ``k_cache`` + ``indices`` / ``topk_length`` (SWA)
    * ``extra_k_cache`` + ``extra_indices_in_kvcache`` / ``extra_topk_length``
      (global compressed cache, optional)

    The two index sets are walked pool by pool in chunks of ``chunk_topk``
    slots, all feeding one shared online-softmax state that is initialised
    with the per-head ``attn_sink``.

    Args:
        q: ``(batch, 1, num_heads, head_dim)`` query.
        k_cache: packed SWA cache.
        head_dim_v: width of the value slice taken from each dequantized key;
            clamped to ``head_dim``.
        is_fp8_kvcache: must be ``True``.
        indices: ``(batch, 1, max_swa_topk)`` slot ids into ``k_cache``.
        topk_length: optional per-row count of valid entries in ``indices``.
        softmax_scale: defaults to ``head_dim ** -0.5``.
        attn_sink: forwarded to ``_online_softmax_init``.
        extra_k_cache: optional second packed cache.
        extra_indices_in_kvcache: ``(batch, 1, max_extra_topk)`` slot ids into
            ``extra_k_cache``; required when ``extra_k_cache`` is given.
        extra_topk_length: optional per-row count of valid extra entries.
        out: optional preallocated output; allocated as
            ``(batch, 1, num_heads, head_dim_v)`` in ``q.dtype`` when omitted.
        chunk_topk: number of slots gathered per online-softmax step.
        block_table, tile_scheduler_metadata, cache_seqlens, causal,
        **_unused_kwargs: accepted for API compatibility and ignored.

    Returns:
        ``(out, None)``. The second element stands in for the softmax LSE that
        upstream returns.

    Raises:
        ValueError: if ``chunk_topk`` is not positive.
    """
    del tile_scheduler_metadata, cache_seqlens, block_table, causal

    assert is_fp8_kvcache, (
        "rocm flash_mla_with_kvcache fallback requires is_fp8_kvcache=True "
        "(DeepSeek-V4 always quantizes KV cache to UE8M0 FP8)"
    )
    assert indices is not None, "SWA indices must be provided for V4 decode"
    assert q.dim() == 4 and q.shape[1] == 1, (
        f"q must be (batch, 1, num_heads, head_dim), got {q.shape}"
    )
    assert indices.dim() == 3 and indices.shape[1] == 1, (
        f"indices must be (batch, 1, max_swa_topk), got {indices.shape}"
    )
    # A negative step would make both chunk loops empty and silently return
    # an all-zero output, so reject it up front.
    if chunk_topk <= 0:
        raise ValueError(f"chunk_topk must be positive, got {chunk_topk}")

    batch_size, _, num_heads, head_dim = q.shape
    if softmax_scale is None:
        softmax_scale = head_dim ** -0.5
    head_dim_v = min(head_dim_v, head_dim)

    q_2d = q.squeeze(1)  # (batch, H, head_dim)
    # (batch, max_swa_topk); entries past topk_length become -1.
    swa_idx = _mask_indices_beyond_length(indices.squeeze(1), topk_length)

    if extra_k_cache is not None:
        assert extra_indices_in_kvcache is not None
        assert (
            extra_indices_in_kvcache.dim() == 3
            and extra_indices_in_kvcache.shape[1] == 1
        ), (
            "extra_indices_in_kvcache must be (batch, 1, max_extra_topk), "
            f"got {extra_indices_in_kvcache.shape}"
        )
        # (batch, max_extra_topk)
        extra_idx = _mask_indices_beyond_length(
            extra_indices_in_kvcache.squeeze(1), extra_topk_length
        )
    else:
        extra_idx = None

    swa_topk = swa_idx.shape[-1]
    extra_topk = extra_idx.shape[-1] if extra_idx is not None else 0

    # Each pool is chunked against its own cache, so no combined index tensor
    # is needed; only the softmax state is shared between the pools.
    device = q.device
    running_max, denom, acc = _online_softmax_init(
        batch_size, num_heads, head_dim_v, attn_sink, device
    )
    q_f = q_2d.to(torch.float32)

    def step(idx_chunk: torch.Tensor, cache: torch.Tensor) -> None:
        """Fold one ``(batch, c)`` chunk of slots from ``cache`` into the
        running softmax state."""
        nonlocal running_max, denom, acc
        valid = idx_chunk >= 0
        # Skip chunks with no live slot at all (e.g. padding past every row's
        # topk_length); the update is never called with all -inf scores.
        if not valid.any():
            return
        K_chunk = _gather_chunk_to_bf16(idx_chunk, cache)
        scores = _batched_query_key(
            q_f, K_chunk.to(torch.float32)
        ) * softmax_scale
        scores = scores.masked_fill(~valid.unsqueeze(1), float("-inf"))
        # Values are the leading head_dim_v channels of the dequantized keys.
        V_chunk = K_chunk[..., :head_dim_v]
        running_max, denom, acc = _online_softmax_update(
            running_max, denom, acc, scores, V_chunk
        )

    # Pool 0: SWA cache.
    for cs in range(0, swa_topk, chunk_topk):
        ce = min(cs + chunk_topk, swa_topk)
        step(swa_idx[:, cs:ce].contiguous(), k_cache)

    # Pool 1: extra (global compressed) cache.
    if extra_idx is not None:
        for cs in range(0, extra_topk, chunk_topk):
            ce = min(cs + chunk_topk, extra_topk)
            step(extra_idx[:, cs:ce].contiguous(), extra_k_cache)

    # Rows whose denominator is still zero produce zeros instead of 0/0.
    finite_l = denom > 0
    out_f = torch.where(
        finite_l.unsqueeze(-1),
        acc / denom.clamp_min(1e-30).unsqueeze(-1),
        torch.zeros_like(acc),
    )

    if out is None:
        out = torch.empty(
            (batch_size, 1, num_heads, head_dim_v),
            dtype=q.dtype,
            device=q.device,
        )
    out_view = out.squeeze(1)
    out_view[..., :head_dim_v].copy_(out_f.to(out.dtype))
    # A caller-provided buffer may be wider than head_dim_v; clear the tail.
    if out_view.shape[-1] > head_dim_v:
        out_view[..., head_dim_v:].zero_()

    # Upstream returns (out, softmax_lse). LSE isn't consumed by the V4 caller.
    return out, None


# ---------------------------------------------------------------------------
# Stubs for FlashMLA's planner-side helpers.
# ---------------------------------------------------------------------------
class _FlashMLASchedMetaStub:
    """Placeholder ``FlashMLASchedMeta`` for ROCm.

    The real CUDA struct holds tile-scheduler tensors that are populated by
    the in-kernel planner on first use. Our fallback ignores it but the V4
    metadata builder still allocates one per layer type.
    """

    # Class-level defaults only; this class defines no __init__.
    have_initialized: bool = False
    tile_scheduler_metadata: torch.Tensor | None = None
    num_splits: torch.Tensor | None = None


def get_mla_metadata_rocm(*_args, **_kwargs) -> tuple[_FlashMLASchedMetaStub, None]:
    """ROCm stub for FlashMLA's ``get_mla_metadata``.

    All positional and keyword arguments are accepted and ignored.

    Returns a fresh empty scheduler-metadata struct so the V4
    ``DeepseekSparseSWAMetadataBuilder.build_tile_scheduler`` can populate
    its per-layer-type cache without crashing on platforms without FlashMLA.
    The second tuple element is always ``None``.
    """
    return _FlashMLASchedMetaStub(), None


__all__ = [
    "flash_mla_sparse_fwd_rocm",
    "flash_mla_with_kvcache_rocm",
    "get_mla_metadata_rocm",
]
+""" +from __future__ import annotations + +import torch + +import vllm.envs as envs +from vllm.forward_context import get_forward_context +from vllm.platforms import current_platform +from vllm.triton_utils import HAS_TRITON, tl, triton +from vllm.utils.torch_utils import ( + LayerNameType, + _resolve_layer_name, + direct_register_custom_op, +) +from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata +from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton +from vllm.v1.worker.workspace import current_workspace_manager + +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops + + +# Reuse the gather-workspace helper from the CUDA module so the workspace +# layout (and therefore the size estimate during profile_run) is shared. +def _gather_workspace_shapes_fp8( + total_seq_lens: int, + head_dim: int, + fp8_dtype: torch.dtype, +) -> tuple[ + tuple[tuple[int, int], torch.dtype], tuple[tuple[int, int], torch.dtype] +]: + """FP8 path layout used by ``cp_gather_indexer_k_quant_cache``: a flat + ``(T, head_dim)`` FP8 values buffer and a ``(T, 4)`` uint8 buffer that + aliases ``(T, 1)`` float32 dequant scales (one scale per token block). + Mirrors the FP8 branch of ``_gather_workspace_shapes`` in + ``sparse_attn_indexer.py``. + """ + return ( + ((total_seq_lens, head_dim), fp8_dtype), + ((total_seq_lens, 4), torch.uint8), + ) + + +# --------------------------------------------------------------------------- +# Triton MQA-logits kernel (prefill / chunked path). +# +# Computes `logits[m, n] = scale[n] * sum_h weights[m, h] * relu(q[m,h,:] . k[n,:])` +# without materializing the (H, M, N) intermediate. Streams over heads so the +# only per-program memory is (BLOCK_N,) accumulator + (D,) Q + (BLOCK_N, D) K. 
# ---------------------------------------------------------------------------
# Triton MQA-logits kernel (prefill / chunked path).
#
# logits[m, n] = k_scale[n] * sum_h weights[m, h] * relu(q[m, h, :] . k[n, :])
# One program handles one query row and one BLOCK_N-wide strip of K rows.
# ---------------------------------------------------------------------------
if HAS_TRITON:

    @triton.jit
    def _mqa_logits_prefill_kernel(
        q_ptr,  # (M, H, D) fp8
        weights_ptr,  # (M, H) fp32
        k_ptr,  # (N, D) fp8
        k_scale_ptr,  # (N,) fp32
        cu_seqlen_ks_ptr,  # (M,) int32
        cu_seqlen_ke_ptr,  # (M,) int32
        logits_ptr,  # (M, N) fp32 (output)
        stride_qm,
        stride_qh,
        stride_qd,
        stride_wm,
        stride_wh,
        stride_kn,
        stride_kd,
        stride_lm,
        stride_ln,
        M,
        N,
        H: tl.constexpr,
        D: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        # Grid is (M, cdiv(N, BLOCK_N)): axis 0 picks the query row, axis 1
        # the strip of K rows.
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)

        n_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        in_bounds = n_offsets < N

        # This query row may only see K rows in [ks, ke).
        ks = tl.load(cu_seqlen_ks_ptr + pid_m)
        ke = tl.load(cu_seqlen_ke_ptr + pid_m)
        valid = in_bounds & (n_offsets >= ks) & (n_offsets < ke)

        d_offsets = tl.arange(0, D)

        # Load K block once and reuse across heads: (BLOCK_N, D) fp32.
        k_block = tl.load(
            k_ptr
            + n_offsets[:, None] * stride_kn
            + d_offsets[None, :] * stride_kd,
            mask=valid[:, None],
            other=0.0,
        ).to(tl.float32)

        accum = tl.zeros([BLOCK_N], dtype=tl.float32)

        for h in range(H):
            q = tl.load(
                q_ptr
                + pid_m * stride_qm
                + h * stride_qh
                + d_offsets * stride_qd,
            ).to(tl.float32)
            w = tl.load(weights_ptr + pid_m * stride_wm + h * stride_wh).to(
                tl.float32
            )

            # relu is applied per head, before the weighted sum over heads.
            score = tl.sum(k_block * q[None, :], axis=1)
            accum += w * tl.maximum(score, 0.0)

        # The per-row K scale is applied once, after the head reduction.
        k_scale = tl.load(k_scale_ptr + n_offsets, mask=valid, other=0.0)
        logits = accum * k_scale
        logits = tl.where(valid, logits, float("-inf"))

        # Every in-bounds column is written, so positions outside [ks, ke)
        # hold -inf rather than stale memory.
        tl.store(
            logits_ptr + pid_m * stride_lm + n_offsets * stride_ln,
            logits,
            mask=in_bounds,
        )


def _mqa_logits_triton(
    q_fp8: torch.Tensor,  # (M, H, D)
    k_fp8: torch.Tensor,  # (N, D)
    k_scale: torch.Tensor,  # (N,) fp32
    weights: torch.Tensor,  # (M, H) fp32
    cu_seqlen_ks: torch.Tensor,  # (M,) int32
    cu_seqlen_ke: torch.Tensor,  # (M,) int32
) -> torch.Tensor:
    """Launch ``_mqa_logits_prefill_kernel`` and return ``(M, N)`` fp32 logits.

    Positions outside ``[cu_seqlen_ks[m], cu_seqlen_ke[m])`` are ``-inf``.
    """
    M, H, D = q_fp8.shape
    N = k_fp8.shape[0]
    assert k_fp8.shape[1] == D
    assert weights.shape == (M, H)
    assert k_scale.shape == (N,)
    # The kernel builds its D-offsets with tl.arange(0, D), which needs a
    # power-of-two extent.
    assert D > 0 and D & (D - 1) == 0, f"head dim must be a power of two, got {D}"

    logits = torch.empty((M, N), dtype=torch.float32, device=q_fp8.device)
    BLOCK_N = 64

    grid = (M, triton.cdiv(N, BLOCK_N))
    _mqa_logits_prefill_kernel[grid](
        q_fp8,
        weights,
        k_fp8,
        k_scale,
        cu_seqlen_ks,
        cu_seqlen_ke,
        logits,
        q_fp8.stride(0),
        q_fp8.stride(1),
        q_fp8.stride(2),
        weights.stride(0),
        weights.stride(1),
        k_fp8.stride(0),
        k_fp8.stride(1),
        logits.stride(0),
        logits.stride(1),
        M,
        N,
        H=H,
        D=D,
        BLOCK_N=BLOCK_N,
    )
    return logits


def _mqa_logits_torch(
    q_fp8: torch.Tensor,  # (M, H, D)
    k_fp8: torch.Tensor,  # (N, D)
    k_scale: torch.Tensor,  # (N,) fp32
    weights: torch.Tensor,  # (M, H) fp32
    cu_seqlen_ks: torch.Tensor,  # (M,) int32
    cu_seqlen_ke: torch.Tensor,  # (M,) int32
) -> torch.Tensor:
    """Pure-torch equivalent of the Triton MQA-logits kernel.

    Materializes an ``(H, M, N)`` fp32 intermediate, so it is only suitable
    for small inputs. Returns ``(M, N)`` fp32 logits with ``-inf`` outside
    ``[cu_seqlen_ks[m], cu_seqlen_ke[m])``.
    """
    N = k_fp8.shape[0]
    q = q_fp8.to(torch.bfloat16)
    k = k_fp8.to(torch.bfloat16)

    arange_n = torch.arange(N, device=q.device)
    mask = (arange_n[None, :] >= cu_seqlen_ks[:, None]) & (
        arange_n[None, :] < cu_seqlen_ke[:, None]
    )

    # (H, M, N) fp32; relu must be applied per-head BEFORE the weighted sum.
    score = torch.einsum("mhd,nd->hmn", q, k).float() * k_scale
    # weights (M, H) -> (H, M, 1) so it broadcasts over the N axis.
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))
    return logits


def _mqa_logits(
    q_fp8: torch.Tensor,
    k_fp8: torch.Tensor,
    k_scale: torch.Tensor,
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Dispatch to the Triton kernel when ``HAS_TRITON`` is set; otherwise
    fall back to the torch reference."""
    if HAS_TRITON:
        return _mqa_logits_triton(
            q_fp8, k_fp8, k_scale, weights, cu_seqlen_ks, cu_seqlen_ke
        )
    return _mqa_logits_torch(
        q_fp8, k_fp8, k_scale, weights, cu_seqlen_ks, cu_seqlen_ke
    )


def _mqa_logits_paged_torch(
    q_fp8: torch.Tensor,  # (B, next_n, H, D)
    kv_cache_4d: torch.Tensor,  # (num_blocks, block_size, 1, D + scale_pad)
    weights: torch.Tensor,  # (B*next_n, H) fp32
    context_lens: torch.Tensor,  # (B,) int32 (or (B, next_n))
    block_tables: torch.Tensor,  # (B, max_blocks) int32
    max_model_len: int,
    head_dim: int,
) -> torch.Tensor:
    """Per-batch torch implementation of the paged MQA-logits compute.

    Walks each batch element's block table, dequantizes one KV block at a
    time and accumulates per-head relu-weighted logits, so only one block's
    ``(H, next_n, block_size)`` intermediate is alive at once.

    Returns:
        ``(B * next_n, max_model_len)`` fp32 logits. Columns at or beyond a
        sequence's context length stay ``-inf``.
    """
    from vllm.utils.math_utils import cdiv

    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, _ = q_fp8.shape

    # Last dim is split at head_dim: FP8 payload, then the scale bytes.
    # NOTE(review): this assumes a per-token [D fp8 | 4-byte fp32 scale]
    # layout -- confirm against the kernel that writes the indexer cache.
    kv_values = kv_cache_4d[..., :head_dim]  # uint8
    kv_scale = kv_cache_4d[..., head_dim:]  # uint8 (4 bytes per slot)

    _, block_size, _, _ = kv_values.size()

    logits = torch.full(
        (batch_size * next_n, max_model_len),
        float("-inf"),
        device=q_fp8.device,
        dtype=torch.float32,
    )

    # Normalize context_lens to (B,).
    # NOTE(review): a 2-D input is reduced to its first column; confirm that
    # is the intended length when next_n > 1.
    if context_lens.dim() == 2:
        context_lens_b = context_lens[:, 0]
    else:
        context_lens_b = context_lens
    ctx_lens = context_lens_b.tolist()

    q_bf16 = q_fp8.to(torch.bfloat16)
    weights_f32 = weights.to(torch.float32)

    for i in range(batch_size):
        ctx_len = ctx_lens[i]
        if ctx_len <= 0:
            continue
        row_start = i * next_n
        row_end = row_start + next_n
        # (H, next_n) weights for this batch element's query rows.
        weight_slice = (
            weights_f32[row_start:row_end, :].transpose(0, 1).contiguous()
        )
        qx = q_bf16[i]  # (next_n, H, D)

        # Fetch this sequence's physical block ids in one host transfer
        # rather than one ``.item()`` sync per block.
        num_ctx_blocks = cdiv(ctx_len, block_size)
        phys_blocks = block_tables[i, :num_ctx_blocks].tolist()
        assert len(phys_blocks) == num_ctx_blocks, (
            f"block table row {i} has {len(phys_blocks)} entries but "
            f"context length {ctx_len} needs {num_ctx_blocks} blocks"
        )

        for block_rk, phys_block in enumerate(phys_blocks):
            # K block: (block_size, D) bf16 = fp8 dequant * fp32 scale.
            k_fp8_block = (
                kv_values[phys_block, :, 0, :]
                .view(fp8_dtype)
                .to(torch.bfloat16)
            )
            k_scale_block = (
                kv_scale[phys_block, :, 0, :].contiguous().view(torch.float32)
            )  # (block_size, 1)
            k_block_bf16 = k_fp8_block * k_scale_block.to(torch.bfloat16)

            # (H, next_n, block_size) scores in fp32.
            score = torch.einsum("nhd,sd->hns", qx, k_block_bf16).float()

            # Per-head relu before weighting, then reduce over heads.
            score = score.relu() * weight_slice.unsqueeze(-1)
            block_logits = score.sum(dim=0)  # (next_n, block_size)

            # Only the final block can be partial: block_rk < num_ctx_blocks
            # means n_start < ctx_len.
            n_start = block_rk * block_size
            n_end = min(n_start + block_size, ctx_len)
            logits[row_start:row_end, n_start:n_end] = block_logits[
                :, : n_end - n_start
            ]

    return logits


# ---------------------------------------------------------------------------
# Custom op: orchestration that mirrors the CUDA sparse_attn_indexer body
# but assumes ``skip_k_cache_insert=True`` (the V4 layout) and uses only
# ROCm-available helpers.
# ---------------------------------------------------------------------------
def rocm_sparse_attn_indexer_no_insert(
    hidden_states: torch.Tensor,
    k_cache_prefix: LayerNameType,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """Compute per-token top-k KV indices without inserting K into the cache.

    Prefill chunks gather FP8 K and its scales from ``kv_cache`` into the
    shared workspace and score them with ``_mqa_logits``. Decode tokens are
    scored straight from the paged cache with ``_mqa_logits_paged_torch``.
    The selected indices are written into ``topk_indices_buffer``.

    ``quant_block_size`` and ``scale_fmt`` are not used by this body.

    Returns:
        ``topk_indices_buffer``. When the forward context's ``attn_metadata``
        is not a dict, only the workspace and a dummy logits buffer are
        allocated, and a fresh ``(num_tokens, topk_tokens)`` int32 tensor is
        returned if no buffer was given.
    """
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()
    k_cache_prefix = _resolve_layer_name(k_cache_prefix)

    # Profile-run path: no real attn_metadata; just reserve workspace and
    # the dummy logits buffer for the memory profiler.
    if not isinstance(attn_metadata, dict):
        values_spec, scales_spec = _gather_workspace_shapes_fp8(
            total_seq_lens, head_dim, fp8_dtype
        )
        current_workspace_manager().get_simultaneous(values_spec, scales_spec)
        # The env var is in MiB; the uint8 dummy below is sized in bytes.
        max_logits_bytes = (
            envs.VLLM_SPARSE_INDEXER_MAX_LOGITS_MB * 1024 * 1024
        )
        _ = torch.empty(
            max_logits_bytes,
            dtype=torch.uint8,
            device=hidden_states.device,
        )
        if topk_indices_buffer is None:
            return torch.empty(
                (hidden_states.shape[0], topk_tokens),
                dtype=torch.int32,
                device=hidden_states.device,
            )
        return topk_indices_buffer

    layer_attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata)
    assert topk_indices_buffer is not None

    has_decode = layer_attn_metadata.num_decodes > 0
    has_prefill = layer_attn_metadata.num_prefills > 0
    num_decode_tokens = layer_attn_metadata.num_decode_tokens

    # NOTE: K-cache insert is INTENTIONALLY skipped here. DeepSeek-V4's
    # compressor (DeepseekCompressor.forward) writes the compressed K to the
    # indexer's KV cache via its fused triton kernel before this op is called,
    # and the call site passes k=None.

    # Unfilled top-k slots stay -1.
    topk_indices_buffer[: hidden_states.shape[0]] = -1

    if has_prefill:
        prefill_metadata = layer_attn_metadata.prefill
        assert prefill_metadata is not None

        # The workspace request is identical for every chunk, so acquire the
        # buffers once and slice them per chunk.
        workspace_manager = current_workspace_manager()
        values_spec, scales_spec = _gather_workspace_shapes_fp8(
            total_seq_lens, head_dim, fp8_dtype
        )
        k_quant_full, k_scale_full = workspace_manager.get_simultaneous(
            values_spec, scales_spec
        )

        for chunk in prefill_metadata.chunks:
            k_quant = k_quant_full[: chunk.total_seq_lens]
            k_scale = k_scale_full[: chunk.total_seq_lens]

            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_quant,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )

            q_slice = q_fp8[chunk.token_start : chunk.token_end]
            w_slice = weights[chunk.token_start : chunk.token_end]
            # (T, 4) uint8 -> (T, 1) float32 -> (T,)
            k_scale_f32 = k_scale.view(torch.float32).squeeze(-1)

            logits = _mqa_logits(
                q_slice,
                k_quant,
                k_scale_f32,
                w_slice,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )

            num_rows = logits.shape[0]
            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )

    if has_decode:
        decode_metadata = layer_attn_metadata.decode
        assert decode_metadata is not None

        # The kv_cache stored shape is (num_blocks, block_size, head_dim+pad);
        # paged-mqa-logits expects an extra "n_head" singleton dim.
        kv_cache_4d = kv_cache.unsqueeze(-2)

        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )

        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

        # Slow-but-correct paged compute. Future Triton kernel TODO: walk the
        # block_table on-device to avoid the per-batch python loop and the
        # per-block (H, next_n, block_size) intermediate.
        logits = _mqa_logits_paged_torch(
            padded_q_fp8_decode_tokens,
            kv_cache_4d,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            max_model_len,
            head_dim,
        )

        num_rows = logits.shape[0]
        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )

        if decode_metadata.requires_padding:
            # Drop the padded rows again and store the compacted result.
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[
                : topk_indices.shape[0], : topk_indices.shape[-1]
            ] = topk_indices

    return topk_indices_buffer


def rocm_sparse_attn_indexer_no_insert_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: LayerNameType,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """Fake implementation registered alongside the real op.

    Allocates throwaway tensors shaped like the gather workspace
    (``total_seq_lens`` rows of ``head_dim`` FP8 bytes plus 4 scale bytes)
    and returns ``topk_indices_buffer``, or a fresh
    ``(num_tokens, topk_tokens)`` int32 tensor when no buffer is given.
    """
    fp8_dtype = current_platform.fp8_dtype()
    _flattened_kv = torch.empty(
        [total_seq_lens, head_dim + 4],
        device=q_fp8.device,
        dtype=torch.uint8,
    )
    _ = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
    _ = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
    if topk_indices_buffer is None:
        return torch.empty(
            (hidden_states.shape[0], topk_tokens),
            dtype=torch.int32,
            device=q_fp8.device,
        )
    return topk_indices_buffer


# Register as a vllm custom op so vllm's compile / dispatch infrastructure
# treats it the same as the existing sparse_attn_indexer ops.
direct_register_custom_op(
    op_name="rocm_sparse_attn_indexer_no_insert",
    op_func=rocm_sparse_attn_indexer_no_insert,
    mutates_args=["topk_indices_buffer"],
    fake_impl=rocm_sparse_attn_indexer_no_insert_fake,
    dispatch_key=current_platform.dispatch_key,
)