diff --git a/collector/vllm/collect_mla_module.py b/collector/vllm/collect_mla_module_v1.py similarity index 100% rename from collector/vllm/collect_mla_module.py rename to collector/vllm/collect_mla_module_v1.py diff --git a/collector/vllm/collect_mla_module_v2.py b/collector/vllm/collect_mla_module_v2.py new file mode 100644 index 00000000..5ed58343 --- /dev/null +++ b/collector/vllm/collect_mla_module_v2.py @@ -0,0 +1,1042 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +__compat__ = "vllm>=0.17.0" + +""" +MLA Module Collector for vLLM — unified MLA and DSA benchmarking. + +Profiles the complete attention module forward pass (projections + attention + +output), not just the bare attention kernel. Uses vLLM's own modeling code to +construct a single `DeepseekV2MLAAttention` module with dummy weights, then +benchmarks its forward. + +MLA vs DSA is determined by the presence of `index_topk` in the HF config. +Op names and data schema are aligned with TRT-LLM's collect_mla_module.py +so that queries can be reused across frameworks. + +Supported models and their attention types are defined in SUPPORTED_MODELS. +The collector reads a real HF config, overrides the layer-local shape fields +in-memory, and then instantiates just the attention module. 
+
+Usage:
+    # MLA context phase (DeepSeek-V3 style)
+    python collect_mla_module_v2.py --mode context --model mla
+
+    # DSA generation phase (DeepSeek-V3.2 style)
+    python collect_mla_module_v2.py --mode generation --model dsa
+
+    # All models, context phase
+    python collect_mla_module_v2.py --mode context
+
+    # Quick single-point test
+    python collect_mla_module_v2.py --mode context --model mla --quick --batch-size 4 --seq-len 2048
+"""
+
+import argparse
+import gc
+import json
+import math
+import os
+import tempfile
+import traceback
+from pathlib import Path
+
+import torch
+from vllm.config import set_current_vllm_config
+from vllm.forward_context import set_forward_context
+from vllm.platforms import current_platform
+
+# ═══════════════════════════════════════════════════════════════════════
+# Config registry patch — vLLM 0.16.0 registers the GlmMoeDsaForCausalLM
+# model class but omits the config-type mapping for "glm_moe_dsa", so
+# AutoConfig.from_pretrained() fails. The config layout is identical to
+# DeepSeek-V3 (GlmMoeDsaForCausalLM inherits DeepseekV2ForCausalLM), so
+# reusing DeepseekV3Config is safe.
+# ═══════════════════════════════════════════════════════════════════════
+from vllm.transformers_utils.config import _CONFIG_REGISTRY
+from vllm.version import __version__ as vllm_version
+
+from collector.helper import benchmark_with_power, get_sm_version, log_perf
+from collector.vllm.utils import (
+    BatchSpec,
+    create_and_prepopulate_kv_cache_mla,
+    create_common_attn_metadata,
+    create_vllm_config,
+    setup_distributed,
+    with_exit_stack,
+)
+
+if "glm_moe_dsa" not in _CONFIG_REGISTRY:
+    _CONFIG_REGISTRY["glm_moe_dsa"] = "DeepseekV3Config"
+
+
+# ═══════════════════════════════════════════════════════════════════════
+# Local model config resolution — avoid HuggingFace Hub downloads
+# ═══════════════════════════════════════════════════════════════════════
+
+# Pre-cached HF configs live in src/aiconfigurator/model_configs/ as
+# "{org}--{name}_config.json" (model path with "/" replaced by "--"). 
vLLM's ModelConfig accepts a local +# directory containing config.json, so we create a temp dir with a +# symlink when the cached file exists. +_MODEL_CONFIGS_DIR = Path(__file__).resolve().parents[2] / "src" / "aiconfigurator" / "model_configs" + +# Cache of model_name -> temp dir path (created once per process). +_local_config_cache: dict[str, str] = {} + + +def _resolve_model_path(model_name: str) -> str: + """Return a local directory path for *model_name* if a cached config exists, else return model_name as-is.""" + if model_name in _local_config_cache: + return _local_config_cache[model_name] + + config_file = _MODEL_CONFIGS_DIR / f"{model_name.replace('/', '--')}_config.json" + if not config_file.exists(): + return model_name + + # Create a temp directory with config.json so vLLM's ModelConfig + # loads from disk instead of downloading from HuggingFace Hub. + tmp_dir = tempfile.mkdtemp(prefix=f"aic_model_{model_name.replace('/', '_')}_") + os.symlink(config_file, os.path.join(tmp_dir, "config.json")) + # Strip auto_map if present. Some models (e.g. DeepSeek-V3) ship + # config.json with auto_map pointing to a custom Python config class + # (configuration_deepseek.py). HuggingFace's AutoConfig.from_pretrained() + # — called by vLLM's ModelConfig — unconditionally tries to import that + # module from the model directory where it doesn't exist; vLLM natively + # supports these architectures and only needs the JSON fields. + with open(config_file) as f: + config_data = json.load(f) + if "auto_map" in config_data: + config_data.pop("auto_map") + os.remove(os.path.join(tmp_dir, "config.json")) + with open(os.path.join(tmp_dir, "config.json"), "w") as f: + json.dump(config_data, f) + + # Also symlink hf_quant_config.json if present (used by quantized models). 
+ quant_file = _MODEL_CONFIGS_DIR / f"{model_name.replace('/', '--')}_hf_quant_config.json" + if quant_file.exists(): + os.symlink(quant_file, os.path.join(tmp_dir, "hf_quant_config.json")) + + _local_config_cache[model_name] = tmp_dir + return tmp_dir + + +# ═══════════════════════════════════════════════════════════════════════ +# Supported Models — model_path → attention type +# ═══════════════════════════════════════════════════════════════════════ + +SUPPORTED_MODELS: dict[str, str] = { + "deepseek-ai/DeepSeek-V3": "mla", + "deepseek-ai/DeepSeek-V3.2": "dsa", + "zai-org/GLM-5": "dsa", +} + + +# ═══════════════════════════════════════════════════════════════════════ +# Test Cases — aligned with TRT-LLM's collect_mla_module.py +# ═══════════════════════════════════════════════════════════════════════ + + +def _get_precision_combos(phase: str): + """Return (compute_dtype, kv_cache_dtype, gemm_type) triples for a phase. + + Each triple describes the full quantisation configuration for one + benchmark sweep. GPU capability (SM version) determines which + combos are available. + + Precision axes: + gemm_type — linear-layer GEMMs (projections inside the module) + bfloat16: always + fp8_block: SM >= 89 (Ada / Hopper / Blackwell) + nvfp4: SM >= 100 (Blackwell) + + (compute_dtype, kv_cache_dtype) — attention compute + KV cache + context: (bf16, bf16) always; (fp8, fp8) SM >= 100 + generation: (bf16, bf16) always; (bf16, fp8) SM >= 90 + """ + sm = get_sm_version() + + gemm_types = ["bfloat16"] + if sm >= 89: + gemm_types.append("fp8_block") + if sm >= 100: + gemm_types.append("nvfp4") + + attn_combos = [("bfloat16", "bfloat16")] + if phase == "context": + if sm >= 100: + attn_combos.append(("fp8", "fp8")) + else: + if sm >= 90: + attn_combos.append(("bfloat16", "fp8")) + + return [(c, kv, g) for g in gemm_types for c, kv in attn_combos] + + +def get_context_test_cases(attn_type: str): + """Context-phase test cases. 
+ + Returns list of [seq_len, batch_size, num_heads, kv_cache_dtype, + compute_dtype, gemm_type, perf_filename]. + """ + cases = [] + b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256] + s_list = [1, 16, 32, 64, 128, 256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 10240, 12288, 16384, 32768] + base_fname = f"{attn_type}_context_module_perf.txt" + for compute_dtype, kv_dtype, gemm_type in _get_precision_combos("context"): + for num_heads in [128, 64, 32, 16, 8, 4, 2, 1]: + for b in b_list: + for s in s_list: + if b * s > 131072: + continue + cases.append([s, b, num_heads, kv_dtype, compute_dtype, gemm_type, base_fname]) + return cases + + +def get_generation_test_cases(attn_type: str): + """Generation-phase test cases. + + Returns list of [kv_cache_len, batch_size, num_heads, kv_cache_dtype, + compute_dtype, gemm_type, perf_filename]. + """ + cases = [] + b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024] + s_list = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072] + base_fname = f"{attn_type}_generation_module_perf.txt" + for compute_dtype, kv_dtype, gemm_type in _get_precision_combos("generation"): + for num_heads in [128, 64, 32, 16, 8, 4, 2, 1]: + for b in b_list: + for s in s_list: + if b * s > 1024 * 4096 * 2 * 2 * 2: + continue + cases.append([s, b, num_heads, kv_dtype, compute_dtype, gemm_type, base_fname]) + return cases + + +def _build_module_test_cases(attn_type: str, mode: str): + """Build module-level test cases for a specific attention type and phase. 
+ + Output format: [seq_len, batch_size, num_heads, kv_cache_dtype, + compute_dtype, gemm_type, perf_filename, model_path, attn_type] + """ + base_cases = get_context_test_cases(attn_type) if mode == "context" else get_generation_test_cases(attn_type) + model_paths = [m for m, t in SUPPORTED_MODELS.items() if t == attn_type] + cases = [] + for model_path in model_paths: + for s, b, h, kv_dtype, compute_dtype, gemm_type, fname in base_cases: + cases.append([s, b, h, kv_dtype, compute_dtype, gemm_type, fname, model_path, attn_type]) + return cases + + +def get_mla_context_module_test_cases(): + """collect.py entrypoint for MLA context module collection.""" + return _build_module_test_cases(attn_type="mla", mode="context") + + +def get_mla_generation_module_test_cases(): + """collect.py entrypoint for MLA generation module collection.""" + return _build_module_test_cases(attn_type="mla", mode="generation") + + +def get_dsa_context_module_test_cases(): + """collect.py entrypoint for DSA context module collection.""" + return _build_module_test_cases(attn_type="dsa", mode="context") + + +def get_dsa_generation_module_test_cases(): + """collect.py entrypoint for DSA generation module collection.""" + return _build_module_test_cases(attn_type="dsa", mode="generation") + + +# ═══════════════════════════════════════════════════════════════════════ +# Module Construction +# ═══════════════════════════════════════════════════════════════════════ + + +def _create_gemm_quant_config(gemm_type: str): + """Create the vLLM QuantizationConfig for a given gemm_type. + + Returns None for bfloat16 (unquantised GEMMs). + For fp8_block / nvfp4, returns an online-quantisation config so that + dummy BF16 weights are dynamically quantised during + ``process_weights_after_loading``. 
+ """ + if gemm_type == "bfloat16": + return None + if gemm_type == "fp8_block": + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + # vLLM requires is_checkpoint_fp8_serialized=True for block-scaled + # FP8 (fp8.py raises ValueError otherwise). This routes through + # Fp8LinearMethod (block_quant=True) → W8A8BlockFp8LinearOp → + # DeepGEMM on SM≥89. + return Fp8Config( + is_checkpoint_fp8_serialized=True, + activation_scheme="dynamic", + weight_block_size=[128, 128], + ) + if gemm_type == "nvfp4": + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4Config, + ) + + return ModelOptNvFp4Config( + is_checkpoint_nvfp4_serialized=True, + kv_cache_quant_algo=None, + exclude_modules=[], + ) + raise ValueError(f"Unknown gemm_type: {gemm_type!r}") + + +def _create_attention_module( + model_path: str, + attn_type: str, + num_heads: int, + use_fp8_kv_cache: bool, + use_prefill_fp8: bool, + max_seq_len: int, + max_batch_size: int, + gemm_type: str = "bfloat16", + device: str = "cuda:0", + is_context: bool = True, +): + """ + Create a DeepseekV2MLAAttention module from vLLM's own modeling code. + + Loads a real HF config from model_path, overrides the layer-local attention + dimensions we want to benchmark in-memory, and then constructs the module + with dummy weights. The module includes all projections + attention + + output. + + Args: + model_path: HuggingFace model path (e.g. "deepseek-ai/DeepSeek-V3.2"). + attn_type: Attention type ("mla" or "dsa"). + use_prefill_fp8: When True and on SM100+, enable FP8 prefill + attention via ``attention_config.use_prefill_query_quantization``. + gemm_type: Precision for linear-layer GEMMs — "bfloat16", + "fp8_block", or "nvfp4". 
+ """ + from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLAAttention + + local_model_path = _resolve_model_path(model_path) + + block_size = 64 + max_model_len = max(max_seq_len + 1, 4096) + num_kv_cache_blocks = max( + 1 + math.ceil((max_seq_len + 1) / block_size) * max_batch_size, + 8192, + ) + + # Determine kv cache dtype string for sparse MLA. + # For DSA (DeepSeekV3.2), fp8 uses the custom ``fp8_ds_mla`` 656-byte + # cache format (512B quantized NoPE + 16B scales + 128B RoPE). + # For dense MLA, standard fp8 (fp8_e4m3) is used. + is_dsa = attn_type == "dsa" + + vllm_config = create_vllm_config( + model_name=local_model_path, + max_model_len=max_model_len, + block_size=block_size, + num_gpu_blocks=num_kv_cache_blocks, + max_num_seqs=max_batch_size, + max_num_batched_tokens=max(max_batch_size * max_seq_len, 131072) if is_context else max_batch_size, + use_fp8_kv_cache=use_fp8_kv_cache, + trust_remote_code=True, + ) + + # Override quant_config to control linear-layer GEMM precision. + # DeepSeek-V3.2 ships with FP8 quantisation by default, so we + # must always set quant_config explicitly: None for bf16, + # Fp8Config (blockwise) for fp8_block, ModelOptNvFp4Config for nvfp4. + vllm_config.quant_config = _create_gemm_quant_config(gemm_type) + + # For DSA, mirror the DeepseekV32ForCausalLM.verify_and_update_config() + # logic: fp8 cache must use ``fp8_ds_mla`` format. + if is_dsa and use_fp8_kv_cache: + vllm_config.cache_config.cache_dtype = "fp8_ds_mla" + + # Enable FP8 prefill attention on SM100+ (Blackwell). + # This quantizes Q/K/V to FP8 before sending to the prefill kernel. + if use_prefill_fp8: + vllm_config.attention_config.use_prefill_query_quantization = True + + # Override just the layer-local dimensions we sweep in the collector. 
+ hf_config = vllm_config.model_config.hf_config + hf_config.num_hidden_layers = 1 + hf_config.num_attention_heads = num_heads + hf_config.num_key_value_heads = num_heads + + # Create topk_indices_buffer for DSA + topk_indices_buffer = None + if is_dsa and hasattr(hf_config, "index_topk"): + max_tokens = vllm_config.scheduler_config.max_num_batched_tokens + topk_indices_buffer = torch.empty( + max_tokens, + hf_config.index_topk, + dtype=torch.int32, + device=device, + ) + + # Build the attention module inside set_current_vllm_config() context. + # FP8 quantized Linear layers (QuantFP8 / CustomOp) call + # get_current_vllm_config() during __init__, so the config must be set. + # set_default_torch_dtype is required because MLAAttention.__init__ + # calls torch.get_default_dtype() to select the attention backend + # (MLA backends only support bfloat16, not float32). + from vllm.utils.torch_utils import set_default_torch_dtype + + with set_current_vllm_config(vllm_config), set_default_torch_dtype(vllm_config.model_config.dtype): + attn_module = DeepseekV2MLAAttention( + vllm_config=vllm_config, + config=hf_config, + hidden_size=hf_config.hidden_size, + num_heads=num_heads, + qk_nope_head_dim=hf_config.qk_nope_head_dim, + qk_rope_head_dim=hf_config.qk_rope_head_dim, + v_head_dim=hf_config.v_head_dim, + q_lora_rank=hf_config.q_lora_rank if hasattr(hf_config, "q_lora_rank") else None, + kv_lora_rank=hf_config.kv_lora_rank, + max_position_embeddings=hf_config.max_position_embeddings, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + prefix="model.layers.0.self_attn", + topk_indices_buffer=topk_indices_buffer, + ) + + # Serialized block-scaled FP8 creates weight params on meta device; + # to() cannot copy meta tensors, so use to_empty() when needed. 
+    if any(p.is_meta for p in attn_module.parameters()):
+        attn_module = attn_module.to_empty(device=torch.device(device))
+    else:
+        attn_module = attn_module.to(device)
+    attn_module.eval()
+    attn_module.requires_grad_(False)
+
+    # Initialize with deterministic dummy weights.
+    # FP8 weights → zero (safe dummy value).
+    # Scale params → 0.5 (avoid NaN during process_weights_after_loading).
+    # Everything else → small constant.
+    #
+    # Deterministic init — vLLM 0.17.0 DSA modules leave CUDA graph RNG
+    # offset tracking active after construction (likely from FlashInfer
+    # sparse MLA backend, vllm-project/vllm#33451 / vllm-project/vllm#34457).
+    # Any RNG call (normal_, uniform_, randn) crashes with "Offset increment
+    # outside graph capture". Using fill_() is safe because kernel latency
+    # depends on shapes/dtypes, not values, and dummy weights are overwritten
+    # by process_weights_after_loading() anyway.
+    # See: https://github.com/vllm-project/vllm/issues/39371
+    with torch.no_grad():
+        for name, tensor in list(attn_module.named_parameters()) + list(attn_module.named_buffers()):
+            if tensor.is_meta:
+                continue
+            if tensor.dtype in (torch.float8_e4m3fn, torch.float8_e5m2, torch.uint8):
+                tensor.data.zero_()
+            elif tensor.dtype == torch.float32 and "scale" in name:
+                tensor.data.fill_(0.5)
+            else:
+                tensor.data.fill_(0.01)
+
+    return attn_module, vllm_config
+
+
+def _process_module_weights(attn_module, vllm_config, device):
+    """Process weights after loading, mimicking vLLM's model loader.
+
+    This must be called after module construction to:
+    1. Run FP8 quantization on linear layer weights.
+    2. Create W_UK_T and W_UV matrices in MLAAttention that are
+       required for the forward pass.
+    """
+    from vllm.model_executor.layers.attention.mla_attention import MLAAttention
+    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+
+    with set_current_vllm_config(vllm_config):
+        # 1. Process quantized linear layers (FP8 weight conversion). 
+ for _, module in attn_module.named_modules(): + quant_method = getattr(module, "quant_method", None) + if isinstance(quant_method, QuantizeMethodBase): + quant_method.process_weights_after_loading(module) + + # 2. Process MLAAttention layers (creates W_UK_T, W_UV). + for _, module in attn_module.named_modules(): + if isinstance(module, MLAAttention) and hasattr(module, "process_weights_after_loading"): + module.process_weights_after_loading(vllm_config.model_config.dtype) + + +def _create_context_kv_inputs(batch_spec: BatchSpec, kv_lora_rank: int, qk_rope_head_dim: int, device: str): + """ + Create cached KV tensors for the tokens that already exist in the paged KV + cache before the current forward() call. + + Context phase: cache is empty because all tokens are processed in the same + prefill step. + Generation phase: cache holds seq_len - 1 historical tokens, and the + current forward processes exactly 1 new token per request. + """ + kv_c_contexts = [] + k_pe_contexts = [] + for seq_len, query_len in zip(batch_spec.seq_lens, batch_spec.query_lens, strict=True): + context_len = max(0, int(seq_len) - int(query_len)) + kv_c_contexts.append(torch.full((context_len, kv_lora_rank), 0.01, dtype=torch.bfloat16, device=device)) + k_pe_contexts.append(torch.full((context_len, 1, qk_rope_head_dim), 0.01, dtype=torch.bfloat16, device=device)) + return kv_c_contexts, k_pe_contexts + + +def _populate_indexer_kv_cache( + indexer_kv_cache: torch.Tensor, + common_attn_metadata, + context_lens: list[int], +) -> None: + """ + Populate the DSA indexer cache so generation benchmarks see a realistic + historical K cache instead of an all-zero buffer. 
+ """ + block_table = common_attn_metadata.block_table_tensor + block_size = indexer_kv_cache.shape[1] + entry_dim = indexer_kv_cache.shape[2] + device = indexer_kv_cache.device + + for i, context_len in enumerate(context_lens): + if context_len <= 0: + continue + token_offsets = torch.arange(context_len, dtype=torch.long, device=device) + block_indices = token_offsets // block_size + intra_block_offsets = token_offsets % block_size + block_ids = block_table[i, block_indices] + dummy_cache = torch.full((context_len, entry_dim), 42, dtype=torch.uint8, device=device) + indexer_kv_cache[block_ids, intra_block_offsets, :] = dummy_cache + + +# ═══════════════════════════════════════════════════════════════════════ +# KV Cache + Metadata +# ═══════════════════════════════════════════════════════════════════════ + + +def _create_kv_cache_and_metadata( + vllm_config, + attn_type: str, + batch_size: int, + seq_len: int, + num_heads: int, + is_context: bool, + use_fp8_kv_cache: bool, + device: str = "cuda:0", +): + """Create KV cache and attention metadata for benchmarking.""" + from vllm.v1.kv_cache_interface import MLAAttentionSpec + + hf_config = vllm_config.model_config.hf_config + kv_lora_rank = hf_config.kv_lora_rank + qk_rope_head_dim = hf_config.qk_rope_head_dim + head_dim = kv_lora_rank + qk_rope_head_dim + block_size = vllm_config.cache_config.block_size + is_dsa = attn_type == "dsa" + + if is_context: + batch_spec = BatchSpec( + seq_lens=[seq_len] * batch_size, + query_lens=[seq_len] * batch_size, + ) + else: + batch_spec = BatchSpec( + seq_lens=[seq_len] * batch_size, + query_lens=[1] * batch_size, + ) + + num_kv_cache_blocks = max( + 1 + math.ceil((seq_len + 1) / block_size) * batch_size, + 8192, + ) + + common_attn_metadata = create_common_attn_metadata( + batch_spec, block_size, torch.device(device), arange_block_indices=True + ) + + # Select the correct dtype for cache. 
+ # DSA fp8 uses a custom 656-byte ``fp8_ds_mla`` cache format that + # stores quantised NoPE + per-128-element scales + BF16 RoPE. + # Dense MLA fp8 uses standard fp8_e4m3. + if is_dsa and use_fp8_kv_cache: + cache_dtype = current_platform.fp8_dtype() + kv_cache_dtype_str = "fp8_ds_mla" + elif use_fp8_kv_cache: + cache_dtype = current_platform.fp8_dtype() + kv_cache_dtype_str = "fp8" + else: + cache_dtype = torch.bfloat16 + kv_cache_dtype_str = None + + # Populate KV cache with the tokens that exist before this forward. + kv_c_contexts, k_pe_contexts = _create_context_kv_inputs( + batch_spec=batch_spec, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + device=device, + ) + + kv_cache = create_and_prepopulate_kv_cache_mla( + kv_c_contexts=kv_c_contexts, + k_pe_contexts=k_pe_contexts, + block_size=block_size, + head_size=head_dim, + dtype=cache_dtype, + device=torch.device(device), + num_blocks=num_kv_cache_blocks, + common_attn_metadata=common_attn_metadata, + randomize_blocks=False, + kv_cache_dtype=kv_cache_dtype_str, + ) + + # Build attention metadata via backend builder + backend_cls = _get_attention_backend(vllm_config, head_dim, use_fp8_kv_cache, is_dsa) + builder_cls = backend_cls.get_builder_cls() + + kv_cache_spec = MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, # MLA uses 1 KV head + head_size=head_dim, + dtype=cache_dtype, + sliding_window=None, + cache_dtype_str=kv_cache_dtype_str, + ) + + attn_layer_name = "model.layers.0.self_attn.attn" + layer_names = [attn_layer_name] + builder = builder_cls(kv_cache_spec, layer_names, vllm_config, torch.device(device)) + attn_metadata = builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + + # For DSA, the Indexer has its own KV cache and metadata builder. 
+ indexer_kv_cache = None + indexer_metadata = None + if is_dsa: + from vllm.v1.attention.backends.mla.indexer import ( + DeepseekV32IndexerBackend, + ) + + index_head_dim = hf_config.index_head_dim + quant_block_size = 128 + indexer_head_dim = index_head_dim + index_head_dim // quant_block_size * 4 + + indexer_layer_name = "model.layers.0.self_attn.indexer.k_cache" + indexer_spec = MLAAttentionSpec( + block_size=block_size, + num_kv_heads=1, + head_size=indexer_head_dim, + dtype=torch.uint8, + ) + indexer_kv_cache = torch.zeros( + num_kv_cache_blocks, + block_size, + indexer_head_dim, + dtype=torch.uint8, + device=device, + ) + indexer_builder_cls = DeepseekV32IndexerBackend.get_builder_cls() + indexer_builder = indexer_builder_cls(indexer_spec, [indexer_layer_name], vllm_config, torch.device(device)) + indexer_metadata = indexer_builder.build( + common_prefix_len=0, + common_attn_metadata=common_attn_metadata, + ) + _populate_indexer_kv_cache( + indexer_kv_cache=indexer_kv_cache, + common_attn_metadata=common_attn_metadata, + context_lens=[tensor.shape[0] for tensor in kv_c_contexts], + ) + + return kv_cache, attn_metadata, common_attn_metadata, indexer_kv_cache, indexer_metadata + + +def _get_attention_backend(vllm_config, head_dim, use_fp8_kv_cache, is_dsa): + """Select attention backend based on GPU capability and config. + + The backend selector uses kv_cache_dtype to pick the right implementation: + - DSA fp8 → ``fp8_ds_mla`` (FlashMLA Sparse custom format) + - MLA fp8 → ``fp8`` (standard fp8_e4m3) + - BF16 → None / "auto" + """ + dtype = torch.bfloat16 + + # Compute the kv_cache_dtype token the selector expects. 
+ if is_dsa and use_fp8_kv_cache: + kv_cache_dtype_val = "fp8_ds_mla" + elif use_fp8_kv_cache: + kv_cache_dtype_val = "fp8" + else: + kv_cache_dtype_val = None + + from vllm.utils.import_utils import resolve_obj_by_qualname + from vllm.v1.attention.selector import AttentionSelectorConfig + + attn_selector_config = AttentionSelectorConfig( + head_size=head_dim, + dtype=dtype, + kv_cache_dtype=kv_cache_dtype_val, + block_size=vllm_config.cache_config.block_size, + use_mla=True, + has_sink=False, + use_sparse=is_dsa, + ) + backend = current_platform.get_attn_backend_cls(None, attn_selector_config) + + return resolve_obj_by_qualname(backend) + + +# ═══════════════════════════════════════════════════════════════════════ +# Benchmark Runner +# ═══════════════════════════════════════════════════════════════════════ + + +@with_exit_stack +def run_mla_module( + exit_stack, + seq_len: int, + batch_size: int, + num_heads: int, + kv_cache_dtype: str, + compute_dtype: str, + gemm_type: str, + perf_filename: str, + *, + model_path: str, + attn_type: str, + device: str = "cuda:0", + warming_up: int = 10, + test_ite: int = 6, +): + """Run a single MLA / DSA module-level benchmark point.""" + setup_distributed(device) + torch.cuda.set_device(device) + + # DSA's sparse_attn_indexer requires a WorkspaceManager. + try: + from vllm.v1.worker.workspace import init_workspace_manager + + init_workspace_manager(torch.device(device)) + except (ImportError, RuntimeWarning): + pass + + use_fp8_kv_cache = kv_cache_dtype == "fp8" + use_prefill_fp8 = compute_dtype == "fp8" + is_context = "context" in perf_filename + phase = "context" if is_context else "generation" + variant = attn_type.upper() + print( + f"\n[{variant} module] {phase} b={batch_size}, s={seq_len}, " + f"heads={num_heads}, gemm={gemm_type}, compute={compute_dtype}, kv={kv_cache_dtype}, model={model_path}" + ) + + # 1. 
Create attention module + attn_module, vllm_config = _create_attention_module( + model_path=model_path, + attn_type=attn_type, + num_heads=num_heads, + use_fp8_kv_cache=use_fp8_kv_cache, + use_prefill_fp8=use_prefill_fp8, + max_seq_len=seq_len, + max_batch_size=batch_size, + gemm_type=gemm_type, + device=device, + is_context=is_context, + ) + + # 1b. Process weights (FP8 quantization + create W_UK_T / W_UV for MLA) + _process_module_weights(attn_module, vllm_config, device) + + # 2. Create KV cache + metadata + with set_current_vllm_config(vllm_config): + kv_cache, attn_metadata, _, indexer_kv_cache, indexer_metadata = _create_kv_cache_and_metadata( + vllm_config=vllm_config, + attn_type=attn_type, + batch_size=batch_size, + seq_len=seq_len, + num_heads=num_heads, + is_context=is_context, + use_fp8_kv_cache=use_fp8_kv_cache, + device=device, + ) + + # 2b. Bind KV cache to the attention layer so forward() can access it. + # MLAAttention registers itself in static_forward_context during + # __init__, and reads self.kv_cache[virtual_engine] during forward. + attn_layer_name = "model.layers.0.self_attn.attn" + forward_ctx = vllm_config.compilation_config.static_forward_context + forward_ctx[attn_layer_name].kv_cache = [kv_cache] + + # For DSA, also bind the indexer's KV cache. + indexer_layer_name = "model.layers.0.self_attn.indexer.k_cache" + if indexer_kv_cache is not None and indexer_layer_name in forward_ctx: + forward_ctx[indexer_layer_name].kv_cache = [indexer_kv_cache] + + # 3. 
Input tensors + hidden_size = vllm_config.model_config.hf_config.hidden_size + if is_context: + num_tokens = seq_len * batch_size + positions = ( + torch.arange(seq_len, device=device, dtype=torch.long) + .unsqueeze(0) + .expand(batch_size, -1) + .reshape(-1) + .contiguous() + ) + else: + num_tokens = batch_size + positions = torch.full( + (batch_size,), + seq_len - 1, + device=device, + dtype=torch.long, + ) + + hidden_states = torch.full( + (num_tokens, hidden_size), + 0.01, + dtype=torch.bfloat16, + device=device, + ) + + # 4. Dry run + # set_current_vllm_config — needed by quantised layers and RoPE. + # set_forward_context — provides attn_metadata + kv_cache to the + # MLAAttention.forward() path (it calls get_forward_context()). + exit_stack.enter_context(set_current_vllm_config(vllm_config)) + attn_metadata_dict = {attn_layer_name: attn_metadata} + if indexer_metadata is not None: + attn_metadata_dict[indexer_layer_name] = indexer_metadata + exit_stack.enter_context(set_forward_context(attn_metadata_dict, vllm_config)) + try: + with torch.inference_mode(): + attn_module.forward(positions, hidden_states, None) + except Exception as e: + print(f" Dry run failed: {e}") + traceback.print_exc() + _cleanup() + return None + + # 5. Benchmark + def kernel_func(): + attn_module.forward(positions, hidden_states, None) + + with benchmark_with_power( + device=torch.device(device), + kernel_func=kernel_func, + num_warmups=warming_up, + num_runs=test_ite, + repeat_n=1, + allow_graph_fail=True, + ) as results: + pass + + latency = results["latency_ms"] + + # 6. Log results — schema aligned with TRT-LLM + if is_context: + isl = seq_len + step = 0 + else: + isl = 1 + step = seq_len + + op_name = f"{attn_type}_{phase}_module" + + # Record architecture to distinguish different DSA models in the perf CSV. + # perf_database uses this as a dict key when loading data. + # Aligns with sdk/models.py which uses architectures[0] throughout. 
    # Fallback chain: architectures[0] -> model_type -> "unknown".
    hf_cfg = vllm_config.model_config.hf_config
    architecture = getattr(hf_cfg, "architectures", [getattr(hf_cfg, "model_type", "unknown")])[0]

    log_perf(
        item_list=[
            {
                "model": model_path,
                "architecture": architecture,
                # bfloat16 is recorded as "float16" so the row schema matches
                # the TRT-LLM collector's output (see op_name comment above).
                "mla_dtype": "float16" if compute_dtype == "bfloat16" else compute_dtype,
                "kv_cache_dtype": "float16" if kv_cache_dtype == "bfloat16" else kv_cache_dtype,
                "gemm_type": "float16" if gemm_type == "bfloat16" else gemm_type,
                "num_heads": num_heads,
                "batch_size": batch_size,
                "isl": isl,
                "tp_size": 1,
                "step": step,
                "latency": f"{latency:.4f}",
            }
        ],
        framework="VLLM",
        version=vllm_version,
        device_name=torch.cuda.get_device_name(device),
        op_name=op_name,
        kernel_source="default",
        perf_filename=perf_filename,
        power_stats=results["power_stats"],
    )

    print(
        f" [{phase}] b={batch_size}, s={seq_len}, heads={num_heads}, "
        f"gemm={gemm_type}, compute={compute_dtype}, kv={kv_cache_dtype}: {latency:.4f} ms"
    )

    _cleanup()
    return latency


def run_mla_module_worker(
    seq_len: int,
    batch_size: int,
    num_heads: int,
    kv_cache_dtype: str,
    compute_dtype: str,
    gemm_type: str,
    perf_filename: str,
    model_path: str,
    attn_type: str,
    device: str = "cuda:0",
):
    """Worker-compatible positional wrapper used by collector/collect.py."""
    # Pure pass-through: keeps the positional calling convention that the
    # collector's worker pool expects while run_mla_module stays keyword-only.
    return run_mla_module(
        seq_len=seq_len,
        batch_size=batch_size,
        num_heads=num_heads,
        kv_cache_dtype=kv_cache_dtype,
        compute_dtype=compute_dtype,
        gemm_type=gemm_type,
        perf_filename=perf_filename,
        model_path=model_path,
        attn_type=attn_type,
        device=device,
    )


def _cleanup():
    """Release cached CUDA blocks and collect garbage between benchmark points."""
    torch.cuda.empty_cache()
    gc.collect()


# ═══════════════════════════════════════════════════════════════════════
# CLI
# ═══════════════════════════════════════════════════════════════════════


def main():
    """CLI entry point: run MLA/DSA module benchmarks for one or all models.

    Modes:
      --quick     single (batch, seq) point with CLI-supplied or default dtypes;
      otherwise   the full context/generation test-case sweep for each selected
                  model, filtered by the optional --num-heads/--*-dtype/--gemm-type
                  flags.
    OOM on a point is caught and skipped so the sweep continues.
    """
    model_names = list(SUPPORTED_MODELS.keys())

    parser = argparse.ArgumentParser(
        description="MLA/DSA module-level collector for vLLM",
    )
    parser.add_argument("--mode", choices=["context", "generation"], required=True)
    parser.add_argument(
        "--model",
        type=str,
        default=None,
        choices=model_names,
        help=f"Model to benchmark. If not specified, runs all: {model_names}",
    )
    parser.add_argument("--num-heads", type=int, default=None, help="Filter by number of heads")
    parser.add_argument("--batch-size", type=int, default=None, help="Single batch size (for --quick)")
    parser.add_argument("--seq-len", type=int, default=None, help="Single seq len (for --quick)")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["bfloat16", "fp8"],
        default=None,
        help="KV cache dtype (default: run both bfloat16 and fp8 when GPU supports it)",
    )
    parser.add_argument(
        "--compute-dtype",
        type=str,
        choices=["bfloat16", "fp8"],
        default=None,
        help="Compute dtype for attention (default: auto based on phase and GPU)",
    )
    parser.add_argument(
        "--gemm-type",
        type=str,
        choices=["bfloat16", "fp8_block", "nvfp4"],
        default=None,
        help="GEMM quantisation type for linear layers (default: run all supported by GPU)",
    )
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--quick", action="store_true", help="Quick single-point test")
    args = parser.parse_args()

    # Select models to run
    if args.model:
        models_to_run = {args.model: SUPPORTED_MODELS[args.model]}
    else:
        models_to_run = SUPPORTED_MODELS

    for model_path, attn_type in models_to_run.items():
        print(f"\n{'=' * 60}")
        print(f"Model: {model_path} | Attention: {attn_type.upper()}")
        print(f"{'=' * 60}")

        if args.quick:
            # Single-point smoke test with sensible defaults for any flag
            # the user did not supply.
            b = args.batch_size or 4
            s = args.seq_len or 2048
            h = args.num_heads or 128
            kv_dtype = args.kv_cache_dtype or "bfloat16"
            compute = args.compute_dtype or "bfloat16"
            gemm = args.gemm_type or "bfloat16"
            fname = f"{attn_type}_{args.mode}_module_perf.txt"
            run_mla_module(
                seq_len=s,
                batch_size=b,
                num_heads=h,
                kv_cache_dtype=kv_dtype,
                compute_dtype=compute,
                gemm_type=gemm,
                perf_filename=fname,
                model_path=model_path,
                attn_type=attn_type,
                device=args.device,
            )
            continue

        if args.mode == "context":
            test_cases = get_context_test_cases(attn_type=attn_type)
        else:
            test_cases = get_generation_test_cases(attn_type=attn_type)

        # Test-case tuple layout: (seq_len, batch, heads, kv_dtype, compute,
        # gemm, filename) — the index filters below rely on this order, which
        # matches the unpacking in the loop further down.
        if args.num_heads is not None:
            test_cases = [tc for tc in test_cases if tc[2] == args.num_heads]

        if args.kv_cache_dtype is not None:
            test_cases = [tc for tc in test_cases if tc[3] == args.kv_cache_dtype]

        if args.compute_dtype is not None:
            test_cases = [tc for tc in test_cases if tc[4] == args.compute_dtype]

        if args.gemm_type is not None:
            test_cases = [tc for tc in test_cases if tc[5] == args.gemm_type]

        print(f"Running {len(test_cases)} {args.mode} {attn_type.upper()} module test cases...")

        for i, (s, b, h, kv_dtype, compute, gemm, fname) in enumerate(test_cases):
            print(f"[{i + 1}/{len(test_cases)}]", end="")
            try:
                run_mla_module(
                    seq_len=s,
                    batch_size=b,
                    num_heads=h,
                    kv_cache_dtype=kv_dtype,
                    compute_dtype=compute,
                    gemm_type=gemm,
                    perf_filename=fname,
                    model_path=model_path,
                    attn_type=attn_type,
                    device=args.device,
                )
            except torch.cuda.OutOfMemoryError:
                # An OOM point is expected for the largest shapes; free the
                # allocator and keep sweeping the remaining cases.
                print(f" OOM: b={b}, s={s}, heads={h}, gemm={gemm}, compute={compute}, kv={kv_dtype}")
                torch.cuda.empty_cache()
                gc.collect()
            except Exception as e:
                print(f" FAILED: b={b}, s={s}, heads={h}, gemm={gemm}, compute={compute}, kv={kv_dtype}: {e}")
                traceback.print_exc()
                torch.cuda.empty_cache()
                gc.collect()


if __name__ == "__main__":
    main()
SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +__compat__ = "vllm>=0.17.0" + +import os + +import torch +import torch.nn.functional as F +from vllm.model_executor.layers.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config +from vllm.model_executor.layers.fused_moe.layer import determine_expert_map +from vllm.version import __version__ as vllm_version + +# Compatibility: block FP8 helpers may differ by version. +# Priority: vllm.utils.deep_gemm -> deep_gemm extension -> None. +try: + from vllm.utils.deep_gemm import per_block_cast_to_fp8 +except Exception: + try: + import deep_gemm # type: ignore + + per_block_cast_to_fp8 = getattr(deep_gemm, "per_block_cast_to_fp8", None) + except Exception: + per_block_cast_to_fp8 = None # type: ignore[assignment] + +# vLLM >= 0.14.0 raises AssertionError in get_current_vllm_config() when called +# outside a set_current_vllm_config() context (https://github.com/vllm-project/vllm/pull/31747). +# vLLM's custom ops (e.g. _vllm_ops.scaled_fp4_quant) requires vllm config to decide how to dispatch. +from vllm.config import VllmConfig, set_current_vllm_config + +# NVFP4 support: requires Blackwell (SM>=100) and FlashInfer TRTLLM FP4 kernel. 
trtllm_fp4_block_scale_routed_moe = None
_vllm_ops = None
prepare_static_weights_for_trtllm_fp4_moe = None
_nvfp4_available = False
try:
    from flashinfer.fused_moe import trtllm_fp4_block_scale_routed_moe  # type: ignore[assignment]
    from vllm import _custom_ops as _vllm_ops  # type: ignore[assignment]
    from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
        prepare_static_weights_for_trtllm_fp4_moe,  # type: ignore[assignment]
    )

    _nvfp4_available = True
except Exception:
    # Any import failure (missing flashinfer, older vllm) disables nvfp4.
    trtllm_fp4_block_scale_routed_moe = None
    _vllm_ops = None
    prepare_static_weights_for_trtllm_fp4_moe = None

# MXFP4 support: uses vLLM's high-level FusedMoE module with Mxfp4Config.
# This lets vLLM handle backend selection (FlashInfer/Triton/Marlin) and
# weight swizzle internally, so one code path works on all GPUs.
_mxfp4_available = False
try:
    from vllm.model_executor.layers.fused_moe.layer import FusedMoE
    from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config

    _mxfp4_available = True
except Exception:
    pass

from vllm.forward_context import set_forward_context

from collector.common_test_cases import get_common_moe_test_cases
from collector.helper import balanced_logits, benchmark_with_power, get_sm_version, log_perf, power_law_logits_v3

aic_debug = int(os.getenv("aic_moe_debug", "0"))  # noqa: SIM112


def get_moe_test_cases():
    """Generate MoE test cases.

    Builds the cross product of the common MoE test-case list and the quant
    types supported by the current GPU, dropping combinations the kernels
    cannot run (dimension-alignment and topk constraints documented inline).
    Each entry is a positional-argument list for :func:`run_moe_torch`.
    """

    # Quantization types supported by vLLM
    moe_list = ["float16"]
    if get_sm_version() > 86:
        moe_list += ["fp8"]
    if get_sm_version() >= 90 and per_block_cast_to_fp8 is not None:
        moe_list += ["fp8_block"]
    if get_sm_version() >= 100 and _nvfp4_available:
        moe_list += ["nvfp4"]
    if _mxfp4_available:
        moe_list += ["w4a16_mxfp4"]

    _gpt_oss_models = {"openai/gpt-oss-20b", "openai/gpt-oss-120b"}

    test_cases = []

    for common_moe_testcase in get_common_moe_test_cases():
        model_name = common_moe_testcase.model_name

        # vllm does not support TP when EP is enabled.
        if common_moe_testcase.tp > 1 and common_moe_testcase.ep > 1:
            continue

        for moe_type in moe_list:
            # GPT-OSS models only use mxfp4 quantization in production;
            # skip them for other quant types.
            if model_name in _gpt_oss_models and moe_type != "w4a16_mxfp4":
                continue
            # Conversely, mxfp4 is only collected for GPT-OSS models.
            if moe_type == "w4a16_mxfp4" and model_name not in _gpt_oss_models:
                continue

            # fp8_block requires hidden_size divisible by block group_size (128)
            if moe_type == "fp8_block" and (
                common_moe_testcase.hidden_size % 128 != 0
                or (common_moe_testcase.inter_size // common_moe_testcase.tp) % 128 != 0
            ):
                continue

            # nvfp4 uses TRTLLM FP4 kernel which has stricter constraints:
            #   - hidden_size must be divisible by 512 (GEMM tiling requirement)
            #   - local_inter_size (inter_size // tp) must be divisible by 64
            #     (GEMM1 N = 2*local_inter must be multiple of 128, GEMM2 K must be multiple of 64)
            #   - topk must be <= 10 (MaxNumTopExperts in routing kernel)
            if moe_type == "nvfp4" and (
                common_moe_testcase.hidden_size % 512 != 0
                or (common_moe_testcase.inter_size // common_moe_testcase.tp) % 64 != 0
                or common_moe_testcase.topk > 10
            ):
                continue

            # w4a16_mxfp4 requires dimensions aligned to group_size (32)
            if moe_type == "w4a16_mxfp4" and (
                common_moe_testcase.hidden_size % 32 != 0
                or (common_moe_testcase.inter_size // common_moe_testcase.tp) % 32 != 0
            ):
                continue

            test_cases.append(
                [
                    moe_type,
                    common_moe_testcase.num_tokens_list,
                    common_moe_testcase.hidden_size,
                    common_moe_testcase.inter_size,
                    common_moe_testcase.topk,
                    common_moe_testcase.num_experts,
                    common_moe_testcase.tp,
                    common_moe_testcase.ep,
                    common_moe_testcase.model_name,
                    "moe_perf.txt",
                    common_moe_testcase.token_expert_distribution,
                    common_moe_testcase.power_law_alpha,
                ]
            )

    return test_cases


def run_moe_torch(
    moe_type,
    num_tokens_lists,
    hidden_size,
    inter_size,
    topk,
    num_experts,
    moe_tp_size,
    moe_ep_size,
    model_name,
    perf_filename,
    distributed="power_law",
    power_law_alpha=0.0,
    device="cuda:0",
):
    """Run vLLM MoE performance benchmarking.

    Builds synthetic expert weights for ``moe_type`` (``float16`` / ``fp8`` /
    ``fp8_block`` / ``nvfp4`` / ``w4a16_mxfp4``), generates routing inputs per
    ``distributed`` ("power_law" or "balanced"), then benchmarks one MoE
    forward for every token count in ``num_tokens_lists`` and appends a row
    per point to ``perf_filename`` via ``log_perf``.

    NOTE(review): ``model_name`` is accepted for test-case tuple symmetry but
    is not read inside this function.
    """
    torch.cuda.set_device(device)
    torch.set_default_device(device)

    # Configure quantization parameters
    dtype = torch.float16
    quant_config = None
    block_shape: list[int] | None = None
    a1_scale = None
    a2_scale = None

    # Calculate local number of experts
    local_inter_size = inter_size // moe_tp_size
    local_num_experts, expert_map, _ = determine_expert_map(moe_ep_size, 0, num_experts)

    # Create weight tensors
    # w1: gate + up projection weights [num_experts, 2 * inter_size, hidden_size]
    # w2: down projection weights [num_experts, hidden_size, inter_size]
    w1 = torch.randn(
        local_num_experts,
        2 * local_inter_size,
        hidden_size,
        dtype=torch.float16,
        device=device,
    )
    w2 = torch.randn(
        local_num_experts,
        hidden_size,
        local_inter_size,
        dtype=torch.float16,
        device=device,
    )

    # MXFP4 path: uses vLLM's high-level FusedMoE module with Mxfp4Config.
    # vLLM handles backend selection (FlashInfer/Triton/Marlin) and weight swizzle.
    #
    # We keep a reference to the VllmConfig used during construction because
    # vLLM 0.17.0's MoERunner (vllm-project/vllm#32344) calls
    # get_forward_context() → get_layer_from_name() during forward, which
    # looks up the module in static_forward_context. FusedMoE registers
    # itself there during __init__, so we must pass the *same* config to
    # set_forward_context() at benchmark time.
    use_mxfp4 = moe_type == "w4a16_mxfp4"
    moe_module = None
    mxfp4_vllm_cfg = None

    if use_mxfp4:
        if not _mxfp4_available:
            raise ImportError("MXFP4 MoE requires vllm >= 0.17.0 with Mxfp4Config support.")

        mxfp4_quant_config = Mxfp4Config()

        # pcp_size=1: vLLM 0.17.0 added prefill context parallel to FusedMoE
        # (vllm-project/vllm#32344); without it, __init__ calls get_pcp_group()
        # which requires distributed init.
        mxfp4_vllm_cfg = VllmConfig()
        with set_current_vllm_config(mxfp4_vllm_cfg):
            moe_module = FusedMoE(
                num_experts=num_experts,
                top_k=topk,
                hidden_size=hidden_size,
                intermediate_size=inter_size,
                reduce_results=False,
                renormalize=True,
                quant_config=mxfp4_quant_config,
                tp_size=moe_tp_size,
                dp_size=1,
                ep_size=moe_ep_size,
                prefix="",
                has_bias=True,  # GPT-OSS uses bias
                activation="swigluoai",  # GPT-OSS activation
                pcp_size=1,
            )
        moe_module.to(device)
        moe_module.eval()
        moe_module.requires_grad_(False)

        # Fill synthetic mxfp4 weights (uint8 packed, E2M1 format)
        with torch.no_grad():
            moe_module.w13_weight.data.random_(0, 255)
            moe_module.w2_weight.data.random_(0, 255)
            moe_module.w13_weight_scale.data.random_(0, 255)
            moe_module.w2_weight_scale.data.random_(0, 255)
            if hasattr(moe_module, "w13_bias"):
                moe_module.w13_bias.data.normal_()
            if hasattr(moe_module, "w2_bias"):
                moe_module.w2_bias.data.normal_()

        # Trigger backend selection + weight swizzle for current GPU
        moe_module.quant_method.process_weights_after_loading(moe_module)

        # Free float16 weights; not used for mxfp4.
        del w1, w2

    # NVFP4 path: uses FlashInfer TRTLLM FP4 monolithic kernel (not fused_experts).
    use_nvfp4 = moe_type == "nvfp4"
    nvfp4_data: dict | None = None

    if use_nvfp4:
        _missing = [
            name
            for name, obj in [
                ("trtllm_fp4_block_scale_routed_moe", trtllm_fp4_block_scale_routed_moe),
                ("_vllm_ops", _vllm_ops),
                ("prepare_static_weights_for_trtllm_fp4_moe", prepare_static_weights_for_trtllm_fp4_moe),
            ]
            if obj is None
        ]
        if _missing:
            raise ImportError(
                f"NVFP4 MoE requires flashinfer and vllm >= 0.14.0 with FP4 support, but the following "
                f"could not be imported: {', '.join(_missing)}. "
                f"Install a compatible flashinfer build and ensure vllm >= 0.14.0 with FP4 support."
            )

        # Raw packed FP4 weights and block scales
        w1_raw = torch.randint(
            0, 255, (local_num_experts, 2 * local_inter_size, hidden_size // 2), dtype=torch.uint8, device=device
        )
        w2_raw = torch.randint(
            0, 255, (local_num_experts, hidden_size, local_inter_size // 2), dtype=torch.uint8, device=device
        )
        w1_scale_raw = torch.ones(
            local_num_experts, 2 * local_inter_size, hidden_size // 16, dtype=torch.float8_e4m3fn, device=device
        )
        w2_scale_raw = torch.ones(
            local_num_experts, hidden_size, local_inter_size // 16, dtype=torch.float8_e4m3fn, device=device
        )

        # Shuffle weights and scales for TRTLLM kernel layout
        w1_shuf, w1_scale_shuf, w2_shuf, w2_scale_shuf = prepare_static_weights_for_trtllm_fp4_moe(
            w1_raw,
            w2_raw,
            w1_scale_raw,
            w2_scale_raw,
            hidden_size=hidden_size,
            intermediate_size=local_inter_size,
            num_experts=local_num_experts,
            is_gated_activation=True,
        )
        del w1_raw, w2_raw, w1_scale_raw, w2_scale_raw

        # Per-expert scales
        a13_scale = torch.ones(local_num_experts, dtype=torch.float32, device=device)
        a2_scale_nvfp4 = torch.ones(local_num_experts, dtype=torch.float32, device=device)
        w13_scale_2 = torch.ones(local_num_experts, dtype=torch.float32, device=device)
        w2_scale_2 = torch.ones(local_num_experts, dtype=torch.float32, device=device)

        nvfp4_data = dict(
            w1=w1_shuf,
            w1_scale=w1_scale_shuf,
            w2=w2_shuf,
            w2_scale=w2_scale_shuf,
            g1_scale_c=a13_scale * w13_scale_2 / a2_scale_nvfp4,
            a1_gscale=1.0 / a13_scale,
            g1_alphas=a13_scale * w13_scale_2,
            g2_alphas=a2_scale_nvfp4 * w2_scale_2,
        )
        # Free the float16 weights; they are not used for nvfp4.
        del w1, w2

    elif moe_type in ["fp8", "fp8_block"]:
        dtype = torch.float8_e4m3fn
        if moe_type == "fp8_block":
            block_shape = [128, 128]

            if per_block_cast_to_fp8 is None:
                raise ImportError("per_block_cast_to_fp8 is unavailable; fp8_block requires a newer vLLM build.")

            # Quantize each expert independently and stack the block scales.
            w1_scale_list = []
            w2_scale_list = []
            w1_q = torch.empty_like(w1, dtype=dtype)
            w2_q = torch.empty_like(w2, dtype=dtype)
            for i in range(local_num_experts):
                w1_q[i], w1_scale_i = per_block_cast_to_fp8(w1[i], block_size=block_shape, use_ue8m0=True)
                w2_q[i], w2_scale_i = per_block_cast_to_fp8(w2[i], block_size=block_shape, use_ue8m0=True)
                w1_scale_list.append(w1_scale_i)
                w2_scale_list.append(w2_scale_i)
            w1 = w1_q
            w2 = w2_q
            w1_scale = torch.stack(w1_scale_list)
            w2_scale = torch.stack(w2_scale_list)
        else:
            # Per-tensor scales; values are synthetic since only timing matters.
            w1_scale = torch.randn(local_num_experts, dtype=torch.float32, device=device)
            w2_scale = torch.randn(local_num_experts, dtype=torch.float32, device=device)
            a1_scale = torch.randn(1, dtype=torch.float32, device=device)
            a2_scale = torch.randn(1, dtype=torch.float32, device=device)

        quant_config = fp8_w8a8_moe_quant_config(
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            block_shape=block_shape,
        )

    # No-op for fp8_block (already quantized above); converts the plain-fp8 path.
    if not use_mxfp4 and dtype == torch.float8_e4m3fn:
        w1 = w1.to(dtype)
        w2 = w2.to(dtype)

    # Performance testing for each token count
    for num_tokens_idx, num_tokens in enumerate(num_tokens_lists):
        print("num_tokens", num_tokens)
        print("topk", topk)
        hs_dtype = torch.bfloat16 if use_mxfp4 else torch.float16
        hidden_states = torch.randn([num_tokens, hidden_size], dtype=hs_dtype, device=device)

        # Generate routing inputs.
        # mxfp4 path uses FusedMoE.forward(hidden_states, router_logits) which does
        # routing internally; other paths need pre-computed topk_weights/topk_ids.
        num_iter = 5 if distributed == "power_law" else 1
        if use_mxfp4:
            # FusedMoE.forward() takes raw router logits (num_tokens, num_experts)
            if distributed == "power_law":
                actual_logits_list = [
                    power_law_logits_v3(num_tokens, num_experts, topk, moe_ep_size, power_law_alpha)
                    .to(torch.bfloat16)
                    .to(device)
                    for _ in range(num_iter)
                ]
            elif distributed == "balanced":
                actual_logits = balanced_logits(num_tokens, num_experts, topk).to(torch.bfloat16).to(device)
            else:
                raise ValueError(f"Unsupported distributed mode: {distributed}")
        elif distributed == "power_law":
            topk_weights_list = []
            topk_ids_list = []

            for _ in range(num_iter):
                logits = (
                    power_law_logits_v3(
                        num_tokens,
                        num_experts,
                        topk,
                        moe_ep_size,
                        power_law_alpha,
                    )
                    .half()
                    .to(device)
                )
                weights, ids = torch.topk(logits, topk, dim=-1)
                topk_weights_list.append(F.softmax(weights, dim=-1))
                topk_ids_list.append(ids)

            print("actual num_tokens: ", [topk_ids.shape[0] for topk_ids in topk_ids_list])

        elif distributed == "balanced":
            actual_logits = balanced_logits(num_tokens, num_experts, topk).half().to(device)
            topk_weights, topk_ids = torch.topk(actual_logits, topk, dim=-1)
            topk_weights = F.softmax(topk_weights, dim=-1)

        else:
            raise ValueError(f"Unsupported distributed mode: {distributed}")

        # power_law folds num_iter routing draws into one measured run, so
        # fewer warmups/runs are needed; latency is divided back out below.
        num_warmups = 3
        num_runs = 6
        if distributed == "power_law":
            num_warmups = 1
            num_runs = 1

        def _run_nvfp4_once(hs, tw, ti):
            """Run a single nvfp4 MoE iteration via FlashInfer TRTLLM FP4 kernel."""
            # Quantize input to FP4
            x_fp4, x_scale = _vllm_ops.scaled_fp4_quant(
                hs.to(torch.bfloat16),
                nvfp4_data["a1_gscale"][0:1],
                is_sf_swizzled_layout=False,
            )
            num_tok = x_fp4.shape[0]
            scale_cols = hs.shape[1] // 16
            # Pack topk: (expert_id << 16) | bf16_weight_as_int16
            packed = (ti.to(torch.int32) << 16) | tw.to(torch.bfloat16).view(torch.int16).to(torch.int32)
            trtllm_fp4_block_scale_routed_moe(
                topk_ids=packed,
                routing_bias=None,
                hidden_states=x_fp4,
                hidden_states_scale=x_scale.view(num_tok, scale_cols).to(torch.float8_e4m3fn),
                gemm1_weights=nvfp4_data["w1"],
                gemm1_weights_scale=nvfp4_data["w1_scale"].view(torch.float8_e4m3fn),
                gemm1_bias=None,
                gemm1_alpha=None,
                gemm1_beta=None,
                gemm1_clamp_limit=None,
                gemm2_weights=nvfp4_data["w2"],
                gemm2_weights_scale=nvfp4_data["w2_scale"].view(torch.float8_e4m3fn),
                gemm2_bias=None,
                output1_scale_scalar=nvfp4_data["g1_scale_c"],
                output1_scale_gate_scalar=nvfp4_data["g1_alphas"],
                output2_scale_scalar=nvfp4_data["g2_alphas"],
                num_experts=num_experts,
                top_k=topk,
                n_group=0,
                topk_group=0,
                intermediate_size=local_inter_size,
                local_expert_offset=0,
                local_num_experts=local_num_experts,
                routed_scaling_factor=None,
                routing_method_type=1,  # Renormalize
                do_finalize=True,
            )

        def run_single_iteration():
            if use_mxfp4:
                # FusedMoE.forward(hidden_states, router_logits) does routing internally.
                if distributed == "power_law":
                    for logits in actual_logits_list:
                        # logits[: logits.shape[0]] is a full-length view (identity slice).
                        moe_module.forward(hidden_states[: logits.shape[0]], logits[: logits.shape[0]])
                else:
                    moe_module.forward(hidden_states, actual_logits)
            elif use_nvfp4:
                if distributed == "power_law":
                    for tw, ti in zip(topk_weights_list, topk_ids_list, strict=True):
                        _run_nvfp4_once(hidden_states[: tw.shape[0]], tw, ti)
                else:
                    _run_nvfp4_once(hidden_states, topk_weights, topk_ids)
            elif distributed == "power_law":
                for i, (tw, ti) in enumerate(zip(topk_weights_list, topk_ids_list, strict=True)):
                    local_num_tokens = tw.shape[0]
                    _ = fused_experts(
                        hidden_states[:local_num_tokens],
                        w1,
                        w2,
                        tw,
                        ti,
                        inplace=False,
                        quant_config=quant_config,
                        global_num_experts=num_experts,
                        expert_map=expert_map,
                    )
            else:
                _ = fused_experts(
                    hidden_states,
                    w1,
                    w2,
                    topk_weights,
                    topk_ids,
                    inplace=False,
                    quant_config=quant_config,
                    global_num_experts=num_experts,
                    expert_map=expert_map,
                )

        def run_iterations():
            # Use benchmark_with_power context manager
            with benchmark_with_power(
                device=device,
                kernel_func=run_single_iteration,
                num_warmups=num_warmups,
                num_runs=num_runs,
                repeat_n=1,
                allow_graph_fail=True,
            ) as results:
                pass

            # power_law runs num_iter routing draws per measured iteration;
            # divide to report per-forward latency.
            return results["latency_ms"] / num_iter, results["power_stats"]

        try:
            vllm_cfg = mxfp4_vllm_cfg if use_mxfp4 else VllmConfig()
            with set_current_vllm_config(vllm_cfg), set_forward_context({}, vllm_cfg):
                latency, power_stats = run_iterations()
        except torch.OutOfMemoryError:
            # If OOM, check if we had at least one successful run.
            if num_tokens_idx > 0:
                break
            raise

        print(f"moe latency: {latency}")

        if use_mxfp4:
            source = "vllm_mxfp4_moe"
        elif use_nvfp4:
            source = "vllm_flashinfer_trtllm_moe_fp4"
        else:
            source = "vllm_fused_moe"

        log_perf(
            item_list=[
                {
                    "moe_dtype": moe_type,
                    "num_tokens": num_tokens,
                    "hidden_size": hidden_size,
                    "inter_size": inter_size,
                    "topk": topk,
                    "num_experts": num_experts,
                    "moe_tp_size": moe_tp_size,
                    "moe_ep_size": moe_ep_size,
                    "distribution": "power_law_" + str(power_law_alpha) if distributed == "power_law" else distributed,
                    "latency": latency,
                }
            ],
            framework="VLLM",
            version=vllm_version,
            device_name=torch.cuda.get_device_name(device),
            op_name="moe",
            kernel_source=source,
            perf_filename=perf_filename,
            power_stats=power_stats,
        )


if __name__ == "__main__":
    # Smoke-test driver: run only the first 4 generated cases.
    test_cases = get_moe_test_cases()
    print(f"Total test cases: {len(test_cases)}")

    for test_case in test_cases[:4]:
        print(f"Running test case: {test_case}")
        try:
            run_moe_torch(*test_case)
        except Exception as e:
            print(f"Test case failed: {test_case}")
            print(f"Error: {e}")
            continue
""" -from collector.registry_types import OpEntry +from collector.registry_types import OpEntry, VersionRoute REGISTRY: list[OpEntry] = [ OpEntry( @@ -31,9 +33,12 @@ ), OpEntry( op="moe", - module="collector.vllm.collect_moe", get_func="get_moe_test_cases", run_func="run_moe_torch", + versions=( + VersionRoute("0.17.0", "collector.vllm.collect_moe_v2"), + VersionRoute("0.0.0", "collector.vllm.collect_moe_v1"), + ), ), OpEntry( op="mla_context", @@ -49,27 +54,39 @@ ), OpEntry( op="mla_context_module", - module="collector.vllm.collect_mla_module", get_func="get_mla_context_module_test_cases", run_func="run_mla_module_worker", + versions=( + VersionRoute("0.17.0", "collector.vllm.collect_mla_module_v2"), + VersionRoute("0.0.0", "collector.vllm.collect_mla_module_v1"), + ), ), OpEntry( op="mla_generation_module", - module="collector.vllm.collect_mla_module", get_func="get_mla_generation_module_test_cases", run_func="run_mla_module_worker", + versions=( + VersionRoute("0.17.0", "collector.vllm.collect_mla_module_v2"), + VersionRoute("0.0.0", "collector.vllm.collect_mla_module_v1"), + ), ), OpEntry( op="dsa_context_module", - module="collector.vllm.collect_mla_module", get_func="get_dsa_context_module_test_cases", run_func="run_mla_module_worker", + versions=( + VersionRoute("0.17.0", "collector.vllm.collect_mla_module_v2"), + VersionRoute("0.0.0", "collector.vllm.collect_mla_module_v1"), + ), ), OpEntry( op="dsa_generation_module", - module="collector.vllm.collect_mla_module", get_func="get_dsa_generation_module_test_cases", run_func="run_mla_module_worker", + versions=( + VersionRoute("0.17.0", "collector.vllm.collect_mla_module_v2"), + VersionRoute("0.0.0", "collector.vllm.collect_mla_module_v1"), + ), ), OpEntry( op="gdn", diff --git a/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/dsa_context_module_perf.txt b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/dsa_context_module_perf.txt index d766f613..0553c1df 100644 --- 
a/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/dsa_context_module_perf.txt +++ b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/dsa_context_module_perf.txt @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:37f795075e094edf7f618fb8f37b4531893740b5a2a802c0232f34de619d8601 -size 67799 +oid sha256:a804d1d476feebc1a71aa8c6580918fdac813ebf8a15d57950130c9385e98c90 +size 1257999 diff --git a/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/dsa_generation_module_perf.txt b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/dsa_generation_module_perf.txt index 6e5c7d67..00b8d1df 100644 --- a/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/dsa_generation_module_perf.txt +++ b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/dsa_generation_module_perf.txt @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d739ac792e2f0e23269c85f530f92235f0438adba32b4cca1e5ed54028f9417b -size 2000605 +oid sha256:c6d6822e5489b8c4dec2ba69c731f53a1f804e007b8757c6984e4cb2a6c75214 +size 2099980 diff --git a/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/mla_context_module_perf.txt b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/mla_context_module_perf.txt new file mode 100644 index 00000000..62292657 --- /dev/null +++ b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/mla_context_module_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:170f1cc5da124d1dd45731a3423a00113c1037b958d11f2e34d097bace7f6ac5 +size 752294 diff --git a/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/mla_generation_module_perf.txt b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/mla_generation_module_perf.txt new file mode 100644 index 00000000..f42c7a4c --- /dev/null +++ b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/mla_generation_module_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39e2c41106d6b10a6d6e6b0790324c9b214dc6a6c1c7571a48a5e7a10bc538f0 +size 
820417 diff --git a/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/moe_perf.txt b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/moe_perf.txt index fc331b24..76ec72b7 100644 --- a/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/moe_perf.txt +++ b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.17.0/moe_perf.txt @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4d997e4d3e2cf43c4065ec9c92a741c05b4062d12815581fb06f40ab080d68c9 -size 4219365 +oid sha256:09f8ab80d076704e872c1a7051c3ceb9a11bd76b8ca0e675a6ed31630817f3c8 +size 4219410 diff --git a/src/aiconfigurator/systems/data/b200_sxm/vllm/0.19.0/gemm_perf.txt b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.19.0/gemm_perf.txt new file mode 100644 index 00000000..99ea94fc --- /dev/null +++ b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.19.0/gemm_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1d74c42e07e4bf04db6cbce6d06d503957cb14c2e8eb5ba634cfb486412eaf9 +size 9602794 diff --git a/src/aiconfigurator/systems/data/b200_sxm/vllm/0.19.0/moe_perf.txt b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.19.0/moe_perf.txt new file mode 100644 index 00000000..75b149e8 --- /dev/null +++ b/src/aiconfigurator/systems/data/b200_sxm/vllm/0.19.0/moe_perf.txt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f8b89a18b154bd6d33018c10b16041a4947ecff437886e6d7d9b241ec0f9d440 +size 8432718