diff --git a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py index 98d637940bca..ab1e2f1181e4 100644 --- a/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/gpt_oss_triton_kernels_moe.py @@ -1,6 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import json +import os +from contextlib import contextmanager +from pathlib import Path + import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk @@ -33,16 +38,159 @@ logger = init_logger(__name__) +_MOE_SHAPE_DUMP_COUNT = 0 +_MOE_SHAPE_DUMP_WARNED = False +_ogs_opt_flags = None + + +def _env_int(name: str, default: int) -> int: + value = os.environ.get(name) + if value is None or value == "": + return default + return int(value) + + +def _dsv4_flash_rocm_ogs_constraints( + *, + m: int, + k: int, + n: int, + e: int, + topk: int, + activation: MoEActivation, +) -> dict[str, int] | None: + if not current_platform.is_rocm(): + return None + if os.environ.get("VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_TUNED", "1") == "0": + return None + + # These are the high-throughput DeepSeek-V4-Flash routed-expert shapes on # MI300X. The default OGS tile is 128x256x128; measured serving-shaped # microbenchmarks are faster with a smaller M tile on CDNA3, including the # prefill/ramp shapes seen in the fixed 512/512 benchmark. + if ( m >= 512 + and k == 4096 + and n == 4096 + and e == 128 + and topk == 6 + and activation == MoEActivation.SILU + ): default_block_m = 32 if m < 1024 else 64 + constraints = { + "block_m": _env_int( "VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_BLOCK_M", default_block_m + ), + "block_n": _env_int("VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_BLOCK_N", 128), + "block_k": _env_int("VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_BLOCK_K", 128), + } if m >= 1024: constraints["epilogue_subtile"] = _env_int( "VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_EPILOGUE_SUBTILE", 16 ) + return constraints + return None + + +@contextmanager +def _temporary_ogs_constraints(constraints: dict[str, int] | None): + if not constraints or _ogs_opt_flags is None: + yield + return + + previous = getattr(_ogs_opt_flags, "_opt_flags_constraints", {}).copy() + try: + _ogs_opt_flags.reset_opt_flags_constraints() + if previous: + _ogs_opt_flags.update_opt_flags_constraints(previous) + _ogs_opt_flags.update_opt_flags_constraints(constraints) + yield + finally: + _ogs_opt_flags.reset_opt_flags_constraints() + if previous: + _ogs_opt_flags.update_opt_flags_constraints(previous) + + +def _maybe_dump_dsv4_moe_shape( + *, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + topk: int, + activation: MoEActivation, + global_num_experts: int, +) -> None: + dump_dir = os.environ.get("DSV4_MOE_SHAPE_DUMP_DIR") + if not dump_dir: + return + + # Host copies inside graph capture are illegal on ROCm and would also + # perturb the graph. Shape collection is an eager/profiling-only mode. + if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing(): + return + + global _MOE_SHAPE_DUMP_COUNT + limit = int(os.environ.get("DSV4_MOE_SHAPE_DUMP_LIMIT", "0") or "0") + if limit > 0 and _MOE_SHAPE_DUMP_COUNT >= limit: + return + + stride = max(1, int(os.environ.get("DSV4_MOE_SHAPE_DUMP_STRIDE", "1") or "1")) + _MOE_SHAPE_DUMP_COUNT += 1 + if (_MOE_SHAPE_DUMP_COUNT - 1) % stride != 0: + return + + min_m = int(os.environ.get("DSV4_MOE_SHAPE_DUMP_MIN_M", "0") or "0") + M, K = hidden_states.shape + if M < min_m: + return + + try: + local_num_experts = int(w1.shape[0]) + valid_topk = topk_ids[topk_ids >= 0].reshape(-1) + hist = torch.bincount( + valid_topk.to(torch.int64), minlength=local_num_experts + )[:local_num_experts].cpu() + nonzero = hist[hist > 0] + if nonzero.numel() == 0: + p90_nonzero = 0 + hist_max = 0 + else: + p90_nonzero = int( + torch.quantile(nonzero.float(), 0.9).round().item() + ) + hist_max = int(nonzero.max().item()) + + rec = { + "pid": os.getpid(), + "rank": os.environ.get("RANK"), + "local_rank": os.environ.get("LOCAL_RANK"), + "count": _MOE_SHAPE_DUMP_COUNT, + "activation": activation.name, + "M": int(M), + "K": int(K), + "topk": int(topk), + "global_num_experts": int(global_num_experts), + "local_num_experts": local_num_experts, + "w1_shape": list(w1.shape), + "w2_shape": list(w2.shape), + "hist_sum": int(hist.sum().item()), + "hist_nonzero": int(nonzero.numel()), + "hist_max": hist_max, + "p90_nonzero": p90_nonzero, + "hist": [int(x) for x in hist.tolist()], + } + path = Path(dump_dir) + path.mkdir(parents=True, exist_ok=True) + filename = f"moe_shapes_rank{rec['rank'] or 'x'}_pid{os.getpid()}.jsonl" + with (path / filename).open("a") as f: + f.write(json.dumps(rec, separators=(",", ":")) + "\n") + except Exception as e: + global _MOE_SHAPE_DUMP_WARNED + if not _MOE_SHAPE_DUMP_WARNED: + _MOE_SHAPE_DUMP_WARNED = True + logger.warning("Failed to dump DeepSeek V4 MoE shape: %s", e) + def _triton_kernel_moe_supports_current_device() -> bool: # Shared device gate for the OAI Triton MoE expert classes. @@ -245,6 +393,7 @@ def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix): if has_triton_kernels(): try: import triton_kernels.swiglu + import triton_kernels.matmul_ogs_details.opt_flags as _ogs_opt_flags from triton_kernels.matmul_ogs import ( FnSpecs, FusedActivation, @@ -884,6 +1033,16 @@ def apply( if global_num_experts == -1: global_num_experts = E + _maybe_dump_dsv4_moe_shape( + hidden_states=hidden_states, + w1=w1, + w2=w2, + topk_ids=topk_ids, + topk=topk, + activation=activation, + global_num_experts=global_num_experts, + ) + # Note that the output tensor might be in workspace13 intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N)) intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K)) @@ -892,18 +1051,94 @@ def apply( gammas = routing_data.gate_scal if routing_data else None + ogs_constraints = _dsv4_flash_rocm_ogs_constraints( + m=M, k=K, n=N, e=E, topk=topk, activation=activation ) + with _temporary_ogs_constraints(ogs_constraints): + matmul_ogs( + hidden_states, + w1, + quant_config.w1_bias, + routing_data, + gather_indx=gather_indx, + precision_config=quant_config.w1_precision, + gammas=gammas if apply_router_weight_on_input else None, + fused_activation=None, + y=intermediate_cache1, + ) sorted_token_ids_lora = None expert_ids_lora = None num_tokens_post_padded_lora = None token_lora_mapping = None lora_context = self._lora_context + if lora_context is None: + # W1 writes in expert-sorted order. The old no-LoRA path gathered + # back to token-topk order for activation, then gathered back to + # expert-sorted order for W2; those two gathers cancel. + self.activation( + activation, + intermediate_cache2, + intermediate_cache1.view(-1, N), + ) + with _temporary_ogs_constraints(ogs_constraints): + matmul_ogs( + intermediate_cache2, + w2, + quant_config.w2_bias, + routing_data, + scatter_indx=scatter_indx, + precision_config=quant_config.w2_precision, + gammas=None if apply_router_weight_on_input else gammas, + y=output, + ) + return + + # w13 LoRA: gather the activation input from expert-sorted + # intermediate_cache1, then add the LoRA delta in-place on that copy + # before passing it to activation — exactly mirroring the old + # decorator approach which modified the gathered tensor in-place. + act_input = intermediate_cache1.view(-1, N)[gather_indx.dst_indx] + ( + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + ) = self.apply_w13_lora( + lora_context, + y=act_input, + x=hidden_states, + topk_ids=global_topk_ids, + topk_weights=topk_weights, + expert_map=expert_map, + w1=w1, + w2=w2, + num_tokens=M, + top_k_num=topk, + ) + + self.activation( + activation, + intermediate_cache2, + act_input, + ) + # matmul_ogs grouped reduction fuses sum across multiple experts: # y[dst_indx // n_expts_act, :] += x # Set n_expts_act to 1 to unfuse the sum so we can do it manually via moe_sum. routing_data.n_expts_act = 1 + with _temporary_ogs_constraints(ogs_constraints): + matmul_ogs( + intermediate_cache2[gather_indx.src_indx], + w2, + quant_config.w2_bias, + routing_data, + scatter_indx=scatter_indx, + precision_config=quant_config.w2_precision, + gammas=None if apply_router_weight_on_input else gammas, + y=intermediate_cache3, + ) # w2 LoRA: after matmul_ogs with scatter_indx, intermediate_cache3 is # in token-topk order, matching the (M, topk, K) layout add_lora_w2 expects. diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py index 2870ec9a15c0..274c4bb53fdd 100644 --- a/vllm/v1/attention/backends/mla/indexer.py +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -173,6 +173,7 @@ class DeepseekV32IndexerPrefillChunkMetadata: cu_seq_lens: torch.Tensor token_to_seq: torch.Tensor total_seq_lens: int + max_seq_len: int token_start: int token_end: int num_reqs: int @@ -192,6 +193,7 @@ class DeepSeekV32IndexerDecodeMetadata: # - native MTP path: 2D (B, next_n) where [b,j] = L_b - next_n + j + 1 # Both fp8_fp4_paged_mqa_logits and the topk kernels accept both shapes. seq_lens: torch.Tensor + max_seq_len: int decode_lens: torch.Tensor requires_padding: bool schedule_metadata: torch.Tensor @@ -553,6 +555,7 @@ def build( decode_metadata = None if num_decodes > 0: + assert common_attn_metadata.seq_lens_cpu_upper_bound is not None torch.diff( common_attn_metadata.query_start_loc[: num_decodes + 1], out=self.decode_lens_buffer[:num_decodes], @@ -563,6 +566,7 @@ def build( ) seq_lens = common_attn_metadata.seq_lens[:num_decodes] + seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound[:num_decodes] block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...] max_decode_len = int(decode_lens_cpu.max().item()) @@ -587,6 +591,7 @@ def build( # For DeepseekV4 (compress_ratio > 1), the indexer KV cache stores # compressed tokens. Convert uncompressed seq_lens to compressed. if self.compress_ratio > 1: + seq_lens_cpu = seq_lens_cpu // self.compress_ratio # True iff seq_lens aliases decode_seq_lens_buffer (flatten or # native wrote it); False iff it aliases common_attn_metadata. seq_lens_is_local_view = (use_native and next_n > 1) or ( @@ -619,6 +624,7 @@ def build( decode_metadata = DeepSeekV32IndexerDecodeMetadata( block_table=block_table, seq_lens=seq_lens, + max_seq_len=int(seq_lens_cpu.max().item()), decode_lens=decode_lens, requires_padding=requires_padding, schedule_metadata=self.scheduler_metadata_buffer, @@ -655,6 +661,7 @@ def build_prefill_chunk_metadata( total_seq_lens = compressed_seq_lens_cpu[start_idx:end_idx].sum().item() if total_seq_lens == 0: return None + max_seq_len = int(compressed_seq_lens_cpu[start_idx:end_idx].max().item()) num_reqs = end_idx - start_idx device = block_table.device @@ -710,6 +717,7 @@ def build_prefill_chunk_metadata( cu_seq_lens=cu_seq_lens, token_to_seq=token_to_seq, total_seq_lens=total_seq_lens, + max_seq_len=max_seq_len, block_table=block_table[start_idx:end_idx], token_start=token_start, token_end=token_end, diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index ed2147f7fe5d..0ca5f72b62f4 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -2,7 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools import importlib +import json import math +import os import sys from importlib.util import find_spec @@ -24,10 +26,63 @@ _ON_GFX950 = False +def _env_int(name: str, default: int) -> int: + value = os.getenv(name) + if value is None: + return default + try: + return int(value) + except ValueError: + return default + + +def _env_int_or_none(name: str) -> int | None: + value = os.getenv(name) + if value is None: + return None + try: + return int(value) + except ValueError: + return None + + def _sparse_indexer_debug_enabled() -> bool: return os.getenv("DSV4_SPARSE_INDEXER_DEBUG", "0") == "1" +def _sparse_prefill_mem_metrics_enabled() -> bool: + return os.getenv("DSV4_SPARSE_PREFILL_MEM_METRICS", "0") == "1" + + +def _sparse_prefill_mem_snapshot(device: torch.device) -> dict[str, int]: + free, total = torch.cuda.mem_get_info(device) + return { + "free": int(free), + "total": int(total), + "allocated": int(torch.cuda.memory_allocated(device)), + "reserved": int(torch.cuda.memory_reserved(device)), + } + + +def _log_sparse_prefill_mem_metrics( + event: str, + device: torch.device, + **kwargs: int | str | float, +) -> None: + if not _sparse_prefill_mem_metrics_enabled(): + return + record: dict[str, int | str | float] = { + "event": event, + **_sparse_prefill_mem_snapshot(device), + **kwargs, + } + print( + "DSV4_SPARSE_PREFILL_MEM " + json.dumps(record, sort_keys=True), + file=sys.stderr, + flush=True, + ) + + def _log_sparse_indexer_debug(message: str, device: torch.device) -> None: if not _sparse_indexer_debug_enabled(): return @@ -40,6 +95,53 @@ def _log_sparse_indexer_debug(message: str, device: torch.device) -> None: ) +def _select_sparse_decode_config( + num_queries: int, + head_dim: int, + extra_indices: torch.Tensor, +) -> tuple[int, int, int]: + block_h_override = _env_int_or_none("DSV4_SPARSE_ATTN_DECODE_BLOCK_H") + block_k_override = _env_int_or_none("DSV4_SPARSE_ATTN_DECODE_BLOCK_K") + num_warps_override = _env_int_or_none("DSV4_SPARSE_ATTN_DECODE_NUM_WARPS") + + block_h = 16 + block_k = 16 if head_dim >= 256 else 32 + num_warps = 4 + + extra_per_query = ( + extra_indices.numel() // num_queries if num_queries > 0 else 0 + ) + if extra_per_query <= 8: + if num_queries >= 256: + block_h, block_k = 64, 16 + if extra_per_query > 0: + num_warps = 8 + elif num_queries >= 80: + block_h, block_k = 32, 16 + elif num_queries == 32: + block_h, block_k = 16, 16 + else: + block_h, block_k = 16, 32 + else: + if num_queries < 32: + block_h, block_k = 4, 64 + elif num_queries >= 256: + block_h, block_k = 64, 16 + num_warps = 8 + elif num_queries >= 80: + block_h, block_k = 32, 16 + else: + block_h, block_k = 16, 32 + + if block_h_override is not None: + block_h = block_h_override + if block_k_override is not None: + block_k = block_k_override + if num_warps_override is not None: + num_warps = num_warps_override + return block_h, block_k, num_warps + + @triton.jit def _indexer_k_quant_and_cache_kernel( k_ptr, # [num_tokens, head_dim] @@ -627,6 +729,84 @@ def _sparse_prefill_logits_chunk_size() -> int: return max(1, _env_int("DSV4_SPARSE_PREFILL_LOGITS_CHUNK_SIZE", 512)) +@triton.jit +def _fill_full_window_topk_prefill_kernel( + topk_out_ptr, + topk_out_stride, + cu_seqlen_ks_ptr, + cu_seqlen_ke_ptr, + topk_tokens: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row_id = tl.program_id(0) + block_id = tl.program_id(1) + offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + row_start = tl.load(cu_seqlen_ks_ptr + row_id) + row_end = tl.load(cu_seqlen_ke_ptr + row_id) + row_len = row_end - row_start + values = tl.where(offsets < row_len, offsets, -1) + tl.store( + topk_out_ptr + row_id * topk_out_stride + offsets, + values, + mask=offsets < topk_tokens, + ) + + +def _fill_full_window_topk_prefill( + topk_out: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk_tokens: int, +) -> None: + block = 256 + grid = (topk_out.shape[0], triton.cdiv(topk_tokens, block)) + _fill_full_window_topk_prefill_kernel[grid]( + topk_out, + topk_out.stride(0), + cu_seqlen_ks, + cu_seqlen_ke, + topk_tokens, + BLOCK_SIZE=block, + ) + + +@triton.jit +def _fill_full_window_topk_decode_kernel( + topk_out_ptr, + topk_out_stride, + seq_lens_ptr, + topk_tokens: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row_id = tl.program_id(0) + block_id = tl.program_id(1) + offsets = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + seq_len = tl.load(seq_lens_ptr + row_id) + values = tl.where(offsets < seq_len, offsets, -1) + tl.store( + topk_out_ptr + row_id * topk_out_stride + offsets, + values, + mask=offsets < topk_tokens, + ) + + +def _fill_full_window_topk_decode( + topk_out: torch.Tensor, + seq_lens: torch.Tensor, + topk_tokens: int, +) -> None: + block = 256 + seq_lens = seq_lens.reshape(-1) + grid = (topk_out.shape[0], triton.cdiv(topk_tokens, block)) + _fill_full_window_topk_decode_kernel[grid]( + topk_out, + topk_out.stride(0), + seq_lens, + topk_tokens, + BLOCK_SIZE=block, + ) + + def _topk_indices_prefill( logits: torch.Tensor, topk_tokens: int, @@ -785,6 +965,15 @@ def rocm_aiter_sparse_attn_indexer( prefill_metadata = layer_attn_metadata.prefill assert prefill_metadata is not None prefill_chunk_size = _sparse_prefill_chunk_size() + for chunk_idx, chunk in enumerate(prefill_metadata.chunks): + _log_sparse_prefill_mem_metrics( + "before_gather", + device, + chunk=chunk_idx, + M=chunk.token_end - chunk.token_start, + N=chunk.total_seq_lens, + topk=topk_tokens, + ) k_fp8 = torch.empty( [chunk.total_seq_lens, head_dim], device=device, @@ -812,6 +1001,14 @@ def rocm_aiter_sparse_attn_indexer( chunk.cu_seq_lens, token_to_seq=chunk.token_to_seq, ) + _log_sparse_prefill_mem_metrics( + "after_gather", + device, + chunk=chunk_idx, + M=chunk.token_end - chunk.token_start, + N=chunk.total_seq_lens, + topk=topk_tokens, + ) chunk_tokens = chunk.token_end - chunk.token_start chunk_max_seq_len = getattr(chunk, "max_seq_len", chunk.total_seq_lens) @@ -830,7 +1027,70 @@ def rocm_aiter_sparse_attn_indexer( topk_indices = topk_indices_buffer[ token_start:token_end, :topk_tokens ] + _log_sparse_prefill_mem_metrics( + "before_topk_select", + device, + chunk=chunk_idx, + row_start=row_start, + rows=row_end - row_start, + N=chunk.total_seq_lens, + topk=topk_tokens, ) + if chunk_max_seq_len <= topk_tokens: + _fill_full_window_topk_prefill( + topk_indices, + cu_seqlen_ks, + cu_seqlen_ke, + topk_tokens, + ) + _log_sparse_prefill_mem_metrics( + "after_full_window_fill", + device, + chunk=chunk_idx, + row_start=row_start, + rows=row_end - row_start, + N=chunk.total_seq_lens, + topk=topk_tokens, + ) + else: + logits = rocm_fp8_mqa_logits( + q_fp8[token_start:token_end], + (k_fp8, k_scale.view(torch.float32)), + weights[token_start:token_end], + cu_seqlen_ks, + cu_seqlen_ke, + ) + _log_sparse_indexer_debug( + "after_logits " + f"M={row_end - row_start} " + f"N={chunk.total_seq_lens}", + device, + ) + _log_sparse_prefill_mem_metrics( + "after_logits", + device, + chunk=chunk_idx, + row_start=row_start, + rows=row_end - row_start, + N=chunk.total_seq_lens, + topk=topk_tokens, + ) + _topk_indices_prefill( + logits, + topk_tokens, + topk_indices, + cu_seqlen_ks, + cu_seqlen_ke, + ) + _log_sparse_prefill_mem_metrics( + "after_topk", + device, + chunk=chunk_idx, + row_start=row_start, + rows=row_end - row_start, + N=chunk.total_seq_lens, + topk=topk_tokens, + ) _log_sparse_indexer_debug( "after_topk " f"M={row_end - row_start} " @@ -862,6 +1122,29 @@ def rocm_aiter_sparse_attn_indexer( next_n = padded_q_fp8_decode_tokens.shape[1] assert batch_size == decode_metadata.seq_lens.shape[0] num_padded_tokens = batch_size * next_n + topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens] + + max_seq_len = getattr(decode_metadata, "max_seq_len", None) + if max_seq_len is None: + max_seq_len = int(decode_metadata.seq_lens.max().item()) + + if max_seq_len <= topk_tokens: + _fill_full_window_topk_decode( + topk_indices, + decode_metadata.seq_lens, + topk_tokens, + ) + if decode_metadata.requires_padding: + # if padded, we need to unpack + # the topk indices removing padded tokens + topk_indices = unpack_seq_triton( + topk_indices.reshape(batch_size, next_n, topk_indices.shape[-1]), + decode_lens, + ) + topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = ( + topk_indices + ) + return topk_indices_buffer logits = rocm_fp8_paged_mqa_logits( padded_q_fp8_decode_tokens, @@ -986,10 +1269,41 @@ def rocm_inv_rope_einsum( hidden_dim = o_ref.shape[-1] if hasattr(wo_a, "weight_scale_inv"): + cache_key = ( + n_local_groups, o_lora_rank, hidden_dim, + wo_a.weight.data_ptr(), + wo_a.weight_scale_inv.data_ptr(), ) + cached_key = getattr(wo_a, "_vllm_rocm_bf16_weight_key", None) + wo_a_weight = getattr(wo_a, "_vllm_rocm_bf16_weight_cache", None) + if cached_key != cache_key or wo_a_weight is None: + weight_fp32 = wo_a.weight.view( + n_local_groups, o_lora_rank, hidden_dim + ).to(torch.float32) + wo_a_scale = _expand_2d_block_scales( + wo_a.weight_scale_inv.view( + n_local_groups, -1, wo_a.weight_scale_inv.shape[-1] + ), + o_lora_rank, + hidden_dim, + ) + wo_a_weight = (weight_fp32 * wo_a_scale).to(torch.bfloat16) + wo_a._vllm_rocm_bf16_weight_cache = wo_a_weight + wo_a._vllm_rocm_bf16_weight_key = cache_key else: + cache_key = (n_local_groups, o_lora_rank, hidden_dim, wo_a.weight.data_ptr()) + cached_key = getattr(wo_a, "_vllm_rocm_bf16_weight_key", None) + wo_a_weight = getattr(wo_a, "_vllm_rocm_bf16_weight_cache", None) + if cached_key != cache_key or wo_a_weight is None: + wo_a_weight = wo_a.weight.view( + n_local_groups, o_lora_rank, hidden_dim + ).to(torch.bfloat16) + wo_a._vllm_rocm_bf16_weight_cache = wo_a_weight + wo_a._vllm_rocm_bf16_weight_key = cache_key + + assert wo_a_weight is not None return torch.einsum("tgd,grd->tgr", o_ref, wo_a_weight) @@ -1015,6 +1329,125 @@ def _validate_dsv4_sparse_dims( ) +_DSV4_SPARSE_DECODE_SHAPE_CALLS = 0 + + +def _shape_dump_limit() -> int: + return _env_int("DSV4_SPARSE_DECODE_SHAPE_DUMP_LIMIT", 0) + + +def _shape_dump_stride() -> int: + return max(1, _env_int("DSV4_SPARSE_DECODE_SHAPE_DUMP_STRIDE", 1)) + + +def _tensor_shape(x: torch.Tensor | None) -> list[int] | None: + return None if x is None else list(x.shape) + + +def _tensor_stride(x: torch.Tensor | None) -> list[int] | None: + return None if x is None else list(x.stride()) + + +def _length_summary(x: torch.Tensor | None) -> dict[str, object] | None: + if x is None: + return None + flat = x.detach().to("cpu", dtype=torch.int64).reshape(-1) + if flat.numel() == 0: + return { + "numel": 0, + "sum": 0, + "min": 0, + "max": 0, + "mean": 0.0, + "hist": [], + } + values, counts = torch.unique(flat, sorted=True, return_counts=True) + return { + "numel": int(flat.numel()), + "sum": int(flat.sum().item()), + "min": int(flat.min().item()), + "max": int(flat.max().item()), + "mean": float(flat.float().mean().item()), + "hist": [ + [int(v.item()), int(c.item())] for v, c in zip(values, counts) + ], + } + + +def _indptr_length_summary(indptr: torch.Tensor | None) -> dict[str, object] | None: + if indptr is None: + return None + return _length_summary(indptr[1:] - indptr[:-1]) + + +def _maybe_dump_sparse_decode_shape( + *, + q: torch.Tensor, + kv_cache: torch.Tensor | None, + swa_k_cache: torch.Tensor, + swa_only: bool, + topk_indices: torch.Tensor | None, + topk_lens: torch.Tensor | None, + swa_indices: torch.Tensor, + swa_lens: torch.Tensor, + swa_ragged_indices: torch.Tensor | None, + swa_ragged_indptr: torch.Tensor | None, + topk_ragged_indices: torch.Tensor | None, + topk_ragged_indptr: torch.Tensor | None, + output: torch.Tensor, +) -> None: + dump_dir = os.getenv("DSV4_SPARSE_DECODE_SHAPE_DUMP_DIR", "") + if not dump_dir: + return + if torch.cuda.is_current_stream_capturing(): + return + + global _DSV4_SPARSE_DECODE_SHAPE_CALLS + call_idx = _DSV4_SPARSE_DECODE_SHAPE_CALLS + _DSV4_SPARSE_DECODE_SHAPE_CALLS += 1 + + limit = _shape_dump_limit() + stride = _shape_dump_stride() + if limit and call_idx >= limit: + return + if call_idx % stride != 0: + return + + os.makedirs(dump_dir, exist_ok=True) + main_ragged_lens = _indptr_length_summary(swa_ragged_indptr) + extra_ragged_lens = _indptr_length_summary(topk_ragged_indptr) + record = { + "call_idx": call_idx, + "pid": os.getpid(), + "rank": os.getenv("RANK"), + "local_rank": os.getenv("LOCAL_RANK"), + "q_shape": _tensor_shape(q), + "q_stride": _tensor_stride(q), + "q_dtype": str(q.dtype), + "output_shape": _tensor_shape(output), + "output_stride": _tensor_stride(output), + "output_dtype": str(output.dtype), + "swa_only": bool(swa_only), + "swa_cache_shape": _tensor_shape(swa_k_cache), + "swa_cache_stride": _tensor_stride(swa_k_cache), + "kv_cache_shape": _tensor_shape(kv_cache), + "kv_cache_stride": _tensor_stride(kv_cache), + "swa_indices_shape": _tensor_shape(swa_indices), + "topk_indices_shape": _tensor_shape(topk_indices), + "swa_ragged_indices_shape": _tensor_shape(swa_ragged_indices), + "topk_ragged_indices_shape": _tensor_shape(topk_ragged_indices), + "swa_lens": _length_summary(swa_lens), + "topk_lens": _length_summary(topk_lens), + "swa_ragged_lens": main_ragged_lens, + "topk_ragged_lens": extra_ragged_lens, + "effective_swa_lens": main_ragged_lens or _length_summary(swa_lens), + "effective_topk_lens": extra_ragged_lens or _length_summary(topk_lens), + } + path = os.path.join(dump_dir, f"sparse_decode_shapes_{os.getpid()}.jsonl") + with open(path, "a", encoding="utf-8") as f: + f.write(json.dumps(record, sort_keys=True) + "\n") + + @triton.jit def _pack_dense_prefix_to_ragged_kernel( indices_ptr, @@ -1500,6 +1933,7 @@ def _rocm_sparse_attn_prefill_ragged_triton( attn_sink: torch.Tensor | None, nope_head_dim: int, rope_head_dim: int, + out: torch.Tensor | None = None, ) -> torch.Tensor: assert q.ndim == 3, f"expected q=[sq,h,d], got {q.shape}" assert kv.ndim == 2, f"expected kv=[skv,d], got {kv.shape}" @@ -1529,6 +1963,11 @@ def _rocm_sparse_attn_prefill_ragged_triton( block_h = 16 block_d = triton.next_power_of_2(head_dim) block_k = 16 if head_dim >= 256 else 32 + if out is None: + out = torch.empty_like(q, dtype=torch.bfloat16) + else: + assert out.shape == q.shape, f"expected out shape {q.shape}, got {out.shape}" + assert out.dtype == torch.bfloat16, f"expected bf16 out, got {out.dtype}" _sparse_attn_prefill_ragged_kernel[(num_queries, triton.cdiv(num_heads, block_h))]( q, kv, @@ -1598,6 +2037,7 @@ def _rocm_sparse_attn_decode_ragged_triton( extra_cache: torch.Tensor | None = None, extra_indices: torch.Tensor | None = None, extra_indptr: torch.Tensor | None = None, + out: torch.Tensor | None = None, ) -> torch.Tensor: assert q.ndim == 3, f"expected q=[b,h,d], got {q.shape}" assert main_cache.ndim == 3, ( @@ -1658,6 +2098,16 @@ def _rocm_sparse_attn_decode_ragged_triton( extra_indices = torch.empty(0, device=q.device, dtype=torch.int32) extra_indptr = torch.zeros(num_queries + 1, device=q.device, dtype=torch.int32) + block_h, block_k, num_warps = _select_sparse_decode_config( + num_queries, + head_dim, + extra_indices, + ) + if out is None: + out = torch.empty_like(q, dtype=torch.bfloat16) + else: + assert out.shape == q.shape, f"expected out shape {q.shape}, got {out.shape}" + assert out.dtype == torch.bfloat16, f"expected bf16 out, got {out.dtype}" _sparse_attn_decode_ragged_kernel[(num_queries, triton.cdiv(num_heads, block_h))]( q, main_cache, @@ -1688,6 +2138,7 @@ def _rocm_sparse_attn_decode_ragged_triton( IS_FNUZ=current_platform.is_fp8_fnuz(), BLOCK_H=block_h, BLOCK_K=block_k, + num_warps=num_warps, ) return out @@ -1708,6 +2159,7 @@ def _rocm_sparse_attn_decode_triton( main_ragged_indptr: torch.Tensor | None = None, extra_ragged_indices: torch.Tensor | None = None, extra_ragged_indptr: torch.Tensor | None = None, + out: torch.Tensor | None = None, ) -> torch.Tensor: if main_ragged_indices is None or main_ragged_indptr is None: main_ragged_indices, main_ragged_indptr = build_ragged_indices_from_dense( @@ -1743,6 +2195,7 @@ def _rocm_sparse_attn_decode_triton( extra_cache=extra_cache, extra_indices=extra_ragged_indices, extra_indptr=extra_ragged_indptr, + out=out, ) @@ -1771,6 +2224,7 @@ def rocm_sparse_attn_prefill( ) if ragged_indices is not None and ragged_indptr is not None: + direct_out = output if output.dtype == torch.bfloat16 else None output_chunk = _rocm_sparse_attn_prefill_ragged_triton( q=q, kv=kv.squeeze(1), @@ -1780,6 +2234,7 @@ def rocm_sparse_attn_prefill( attn_sink=None if attn_sink is None else attn_sink[: q.shape[1]], nope_head_dim=nope_head_dim, rope_head_dim=rope_head_dim, + out=direct_out, ) else: indices_2d = indices.reshape(indices.shape[0], -1) @@ -1793,6 +2248,8 @@ def rocm_sparse_attn_prefill( rope_head_dim=rope_head_dim, topk_length=topk_length, ) + if output_chunk is not output: + output.copy_(output_chunk.to(output.dtype)) def rocm_sparse_attn_decode( @@ -1843,6 +2300,22 @@ def rocm_sparse_attn_decode( if topk_indices is not None: extra_indices = topk_indices.reshape(topk_indices.shape[0], -1) + _maybe_dump_sparse_decode_shape( + q=q, + kv_cache=kv_cache, + swa_k_cache=swa_k_cache, + swa_only=swa_only, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_indices=swa_indices, + swa_lens=swa_lens, + swa_ragged_indices=swa_ragged_indices, + swa_ragged_indptr=swa_ragged_indptr, + topk_ragged_indices=topk_ragged_indices, + topk_ragged_indptr=topk_ragged_indptr, + output=output, + ) + attn_out = _rocm_sparse_attn_decode_triton( q=q, main_cache=swa_k_cache, @@ -1859,4 +2332,7 @@ def rocm_sparse_attn_decode( main_ragged_indptr=swa_ragged_indptr, extra_ragged_indices=topk_ragged_indices, extra_ragged_indptr=topk_ragged_indptr, + out=output if output.dtype == torch.bfloat16 else None, ) + if attn_out is not output: + output.copy_(attn_out.to(output.dtype))