Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
import os
from contextlib import contextmanager
from pathlib import Path

import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
Expand Down Expand Up @@ -33,16 +38,159 @@

logger = init_logger(__name__)

_MOE_SHAPE_DUMP_COUNT = 0
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking: Module-level global counter without explicit synchronization.

Why it matters: While Python's GIL protects simple integer increment operations, this counter could exhibit unexpected behavior in multi-process scenarios (each process gets its own copy) or if the module is reloaded. For the intended profiling use case, this is acceptable.

Suggested fix: If cross-process coordination is ever needed, consider using a file-based counter or multiprocessing.Value. For now, a comment documenting the expected single-process-per-GPU pattern would suffice.

_MOE_SHAPE_DUMP_WARNED = False
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Boolean flag without atomic protection.

Why it matters: In a multi-threaded scenario, multiple threads could theoretically pass the if not _MOE_SHAPE_DUMP_WARNED check simultaneously before either sets it to True, resulting in duplicate warnings. The GIL makes this unlikely in practice.

Suggested fix: Consider using a module-level lock or accepting that duplicate warnings are harmless for this debug feature.

_ogs_opt_flags = None


def _env_int(name: str, default: int) -> int:
value = os.environ.get(name)
if value is None or value == "":
return default
return int(value)


def _dsv4_flash_rocm_ogs_constraints(
*,
m: int,
k: int,
n: int,
e: int,
topk: int,
activation: MoEActivation,
) -> dict[str, int] | None:
if not current_platform.is_rocm():
return None
if os.environ.get("VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_TUNED", "1") == "0":
return None

# These are the high-throughput DeepSeek-V4-Flash routed-expert shapes on
# MI300X. The default OGS tile is 128x256x128; measured serving-shaped
# microbenchmarks are faster with a smaller M tile on CDNA3, including the
# prefill/ramp shapes seen in the fixed 512/512 benchmark.
if (
m >= 512
and k == 4096
and n == 4096
and e == 128
and topk == 6
and activation == MoEActivation.SILU
):
default_block_m = 32 if m < 1024 else 64
constraints = {
"block_m": _env_int(
"VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_BLOCK_M", default_block_m
),
"block_n": _env_int("VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_BLOCK_N", 128),
"block_k": _env_int("VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_BLOCK_K", 128),
}
if m >= 1024:
constraints["epilogue_subtile"] = _env_int(
"VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_EPILOGUE_SUBTILE", 16
)
return constraints
return None


@contextmanager
def _temporary_ogs_constraints(constraints: dict[str, int] | None):
if not constraints or _ogs_opt_flags is None:
yield
return

previous = getattr(_ogs_opt_flags, "_opt_flags_constraints", {}).copy()
try:
_ogs_opt_flags.reset_opt_flags_constraints()
if previous:
_ogs_opt_flags.update_opt_flags_constraints(previous)
_ogs_opt_flags.update_opt_flags_constraints(constraints)
yield
finally:
_ogs_opt_flags.reset_opt_flags_constraints()
if previous:
_ogs_opt_flags.update_opt_flags_constraints(previous)


def _maybe_dump_dsv4_moe_shape(
*,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk: int,
activation: MoEActivation,
global_num_experts: int,
) -> None:
dump_dir = os.environ.get("DSV4_MOE_SHAPE_DUMP_DIR")
if not dump_dir:
return

# Host copies inside graph capture are illegal on ROCm and would also
# perturb the graph. Shape collection is an eager/profiling-only mode.
if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
return

global _MOE_SHAPE_DUMP_COUNT
limit = int(os.environ.get("DSV4_MOE_SHAPE_DUMP_LIMIT", "0") or "0")
if limit > 0 and _MOE_SHAPE_DUMP_COUNT >= limit:
return

stride = max(1, int(os.environ.get("DSV4_MOE_SHAPE_DUMP_STRIDE", "1") or "1"))
_MOE_SHAPE_DUMP_COUNT += 1
if (_MOE_SHAPE_DUMP_COUNT - 1) % stride != 0:
return

min_m = int(os.environ.get("DSV4_MOE_SHAPE_DUMP_MIN_M", "0") or "0")
M, K = hidden_states.shape
if M < min_m:
return

try:
local_num_experts = int(w1.shape[0])
valid_topk = topk_ids[topk_ids >= 0].reshape(-1)
hist = torch.bincount(
valid_topk.to(torch.int64), minlength=local_num_experts
)[:local_num_experts].cpu()
nonzero = hist[hist > 0]
if nonzero.numel() == 0:
p90_nonzero = 0
hist_max = 0
else:
p90_nonzero = int(
torch.quantile(nonzero.float(), 0.9).round().item()
)
hist_max = int(nonzero.max().item())

rec = {
"pid": os.getpid(),
"rank": os.environ.get("RANK"),
"local_rank": os.environ.get("LOCAL_RANK"),
"count": _MOE_SHAPE_DUMP_COUNT,
"activation": activation.name,
"M": int(M),
"K": int(K),
"topk": int(topk),
"global_num_experts": int(global_num_experts),
"local_num_experts": local_num_experts,
"w1_shape": list(w1.shape),
"w2_shape": list(w2.shape),
"hist_sum": int(hist.sum().item()),
"hist_nonzero": int(nonzero.numel()),
"hist_max": hist_max,
"p90_nonzero": p90_nonzero,
"hist": [int(x) for x in hist.tolist()],
}
path = Path(dump_dir)
path.mkdir(parents=True, exist_ok=True)
filename = f"moe_shapes_rank{rec['rank'] or 'x'}_pid{os.getpid()}.jsonl"
with (path / filename).open("a") as f:
f.write(json.dumps(rec, separators=(",", ":")) + "\n")
except Exception as e:
global _MOE_SHAPE_DUMP_WARNED
if not _MOE_SHAPE_DUMP_WARNED:
_MOE_SHAPE_DUMP_WARNED = True
logger.warning("Failed to dump DeepSeek V4 MoE shape: %s", e)


def _triton_kernel_moe_supports_current_device() -> bool:
# Shared device gate for the OAI Triton MoE expert classes.
Expand Down Expand Up @@ -245,6 +393,7 @@ def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix):
if has_triton_kernels():
try:
import triton_kernels.swiglu
import triton_kernels.matmul_ogs_details.opt_flags as _ogs_opt_flags
from triton_kernels.matmul_ogs import (
FnSpecs,
FusedActivation,
Expand Down Expand Up @@ -884,6 +1033,16 @@ def apply(
if global_num_experts == -1:
global_num_experts = E

_maybe_dump_dsv4_moe_shape(
hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_ids=topk_ids,
topk=topk,
activation=activation,
global_num_experts=global_num_experts,
)

# Note that the output tensor might be in workspace13
intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N))
intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K))
Expand All @@ -892,18 +1051,94 @@ def apply(

gammas = routing_data.gate_scal if routing_data else None

ogs_constraints = _dsv4_flash_rocm_ogs_constraints(
m=M, k=K, n=N, e=E, topk=topk, activation=activation
)
with _temporary_ogs_constraints(ogs_constraints):
matmul_ogs(
hidden_states,
w1,
quant_config.w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=None,
y=intermediate_cache1,
)

sorted_token_ids_lora = None
expert_ids_lora = None
num_tokens_post_padded_lora = None
token_lora_mapping = None
lora_context = self._lora_context
if lora_context is None:
# W1 writes in expert-sorted order. The old no-LoRA path gathered
# back to token-topk order for activation, then gathered back to
# expert-sorted order for W2; those two gathers cancel.
self.activation(
activation,
intermediate_cache2,
intermediate_cache1.view(-1, N),
)
with _temporary_ogs_constraints(ogs_constraints):
matmul_ogs(
intermediate_cache2,
w2,
quant_config.w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=quant_config.w2_precision,
gammas=None if apply_router_weight_on_input else gammas,
y=output,
)
return

# w13 LoRA: gather the activation input from expert-sorted
# intermediate_cache1, then add the LoRA delta in-place on that copy
# before passing it to activation — exactly mirroring the old
# decorator approach which modified the gathered tensor in-place.
act_input = intermediate_cache1.view(-1, N)[gather_indx.dst_indx]
(
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
token_lora_mapping,
) = self.apply_w13_lora(
lora_context,
y=act_input,
x=hidden_states,
topk_ids=global_topk_ids,
topk_weights=topk_weights,
expert_map=expert_map,
w1=w1,
w2=w2,
num_tokens=M,
top_k_num=topk,
)

self.activation(
activation,
intermediate_cache2,
act_input,
)

# matmul_ogs grouped reduction fuses sum across multiple experts:
# y[dst_indx // n_expts_act, :] += x
# Set n_expts_act to 1 to unfuse the sum so we can do it manually via moe_sum.
routing_data.n_expts_act = 1

with _temporary_ogs_constraints(ogs_constraints):
matmul_ogs(
intermediate_cache2[gather_indx.src_indx],
w2,
quant_config.w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=quant_config.w2_precision,
gammas=None if apply_router_weight_on_input else gammas,
y=intermediate_cache3,
)

# w2 LoRA: after matmul_ogs with scatter_indx, intermediate_cache3 is
# in token-topk order, matching the (M, topk, K) layout add_lora_w2 expects.
Expand Down
8 changes: 8 additions & 0 deletions vllm/v1/attention/backends/mla/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
cu_seq_lens: torch.Tensor
token_to_seq: torch.Tensor
total_seq_lens: int
max_seq_len: int
token_start: int
token_end: int
num_reqs: int
Expand All @@ -192,6 +193,7 @@ class DeepSeekV32IndexerDecodeMetadata:
# - native MTP path: 2D (B, next_n) where [b,j] = L_b - next_n + j + 1
# Both fp8_fp4_paged_mqa_logits and the topk kernels accept both shapes.
seq_lens: torch.Tensor
max_seq_len: int
decode_lens: torch.Tensor
requires_padding: bool
schedule_metadata: torch.Tensor
Expand Down Expand Up @@ -553,6 +555,7 @@ def build(

decode_metadata = None
if num_decodes > 0:
assert common_attn_metadata.seq_lens_cpu_upper_bound is not None
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-blocking: Assertion depends on upstream metadata builder.

Why it matters: This assertion (and the similar one at line 513) assumes seq_lens_cpu_upper_bound is always populated by the CommonAttentionMetadata builder. If a future change modifies the metadata construction path, this could fail. The assertion is appropriate here as it documents the precondition.

Suggested fix: No change required. The assertion is good defensive programming. Consider adding a comment referencing where this field is populated (e.g., # Set by CommonAttentionMetadata builder in worker.py).

torch.diff(
common_attn_metadata.query_start_loc[: num_decodes + 1],
out=self.decode_lens_buffer[:num_decodes],
Expand All @@ -563,6 +566,7 @@ def build(
)

seq_lens = common_attn_metadata.seq_lens[:num_decodes]
seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound[:num_decodes]
block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...]

max_decode_len = int(decode_lens_cpu.max().item())
Expand All @@ -587,6 +591,7 @@ def build(
# For DeepseekV4 (compress_ratio > 1), the indexer KV cache stores
# compressed tokens. Convert uncompressed seq_lens to compressed.
if self.compress_ratio > 1:
seq_lens_cpu = seq_lens_cpu // self.compress_ratio
# True iff seq_lens aliases decode_seq_lens_buffer (flatten or
# native wrote it); False iff it aliases common_attn_metadata.
seq_lens_is_local_view = (use_native and next_n > 1) or (
Expand Down Expand Up @@ -619,6 +624,7 @@ def build(
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=block_table,
seq_lens=seq_lens,
max_seq_len=int(seq_lens_cpu.max().item()),
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
Expand Down Expand Up @@ -655,6 +661,7 @@ def build_prefill_chunk_metadata(
total_seq_lens = compressed_seq_lens_cpu[start_idx:end_idx].sum().item()
if total_seq_lens == 0:
return None
max_seq_len = int(compressed_seq_lens_cpu[start_idx:end_idx].max().item())

num_reqs = end_idx - start_idx
device = block_table.device
Expand Down Expand Up @@ -710,6 +717,7 @@ def build_prefill_chunk_metadata(
cu_seq_lens=cu_seq_lens,
token_to_seq=token_to_seq,
total_seq_lens=total_seq_lens,
max_seq_len=max_seq_len,
block_table=block_table[start_idx:end_idx],
token_start=token_start,
token_end=token_end,
Expand Down
Loading
Loading