-
Notifications
You must be signed in to change notification settings - Fork 0
DSV4 on MI300X: throughput tuning (optimise, +6 commits on port) #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
29aaf1b
266c6c8
27888ec
501595f
a767c4d
3309790
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,11 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
|
|
||
| import json | ||
| import os | ||
| from contextlib import contextmanager | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
|
|
||
| import vllm.model_executor.layers.fused_moe.modular_kernel as mk | ||
|
|
@@ -33,16 +38,159 @@ | |
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| _MOE_SHAPE_DUMP_COUNT = 0 | ||
| _MOE_SHAPE_DUMP_WARNED = False | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: Boolean flag without atomic protection. Why it matters: In a multi-threaded scenario, multiple threads could theoretically pass the `if not _MOE_SHAPE_DUMP_WARNED` check before any of them sets the flag, so the warning could be logged more than once. Suggested fix: Consider using a module-level lock or accepting that duplicate warnings are harmless for this debug feature. |
||
| _ogs_opt_flags = None | ||
|
|
||
|
|
||
def _env_int(name: str, default: int) -> int:
    """Read an integer tuning knob from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is unset, empty, or
            contains only whitespace.

    Returns:
        The parsed integer, or ``default`` when no value is provided.

    Raises:
        ValueError: If the variable is set to text that ``int()`` cannot
            parse. The message names the variable so a mistyped override is
            easy to trace back to its source.
    """
    raw = os.environ.get(name)
    # int() already ignores surrounding whitespace, so an all-whitespace value
    # is treated the same way as an empty one: "not configured".
    if raw is None or not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from err
|
|
||
|
|
||
def _dsv4_flash_rocm_ogs_constraints(
    *,
    m: int,
    k: int,
    n: int,
    e: int,
    topk: int,
    activation: MoEActivation,
) -> dict[str, int] | None:
    """Pick OGS tile constraints for the tuned DeepSeek-V4-Flash shape on ROCm.

    Args:
        m: Row count of the problem; must be at least 512 to be tuned.
        k: Must equal 4096 for the tuned shape.
        n: Must equal 4096 for the tuned shape.
        e: Expert count; must equal 128 for the tuned shape.
        topk: Experts selected per token; must equal 6 for the tuned shape.
        activation: Must be ``MoEActivation.SILU`` for the tuned shape.

    Returns:
        ``None`` when not running on ROCm, when
        ``VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_TUNED`` is ``"0"``, or when the
        arguments do not match the tuned shape. Otherwise a dict with
        ``block_m``, ``block_n`` and ``block_k``, plus ``epilogue_subtile``
        when ``m >= 1024``. Every entry can be overridden through its
        ``VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_*`` environment variable.
    """
    if not current_platform.is_rocm():
        return None

    tuning_switch = os.environ.get("VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_TUNED", "1")
    if tuning_switch == "0":
        return None

    # These are the high-throughput DeepSeek-V4-Flash routed-expert shapes on
    # MI300X. The default OGS tile is 128x256x128; measured serving-shaped
    # microbenchmarks are faster with a smaller M tile on CDNA3, including the
    # prefill/ramp shapes seen in the fixed 512/512 benchmark.
    is_tuned_shape = (
        m >= 512
        and k == 4096
        and n == 4096
        and e == 128
        and topk == 6
        and activation == MoEActivation.SILU
    )
    if not is_tuned_shape:
        return None

    large_m = m >= 1024
    tile = {
        "block_m": _env_int(
            "VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_BLOCK_M", 64 if large_m else 32
        ),
        "block_n": _env_int("VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_BLOCK_N", 128),
        "block_k": _env_int("VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_BLOCK_K", 128),
    }
    # The epilogue subtile is only constrained for the larger-M bucket.
    if large_m:
        tile["epilogue_subtile"] = _env_int(
            "VLLM_ROCM_DSV4_FLASH_MXFP4_OGS_EPILOGUE_SUBTILE", 16
        )
    return tile
|
|
||
|
|
||
@contextmanager
def _temporary_ogs_constraints(constraints: dict[str, int] | None):
    """Apply extra OGS opt-flag constraints for the duration of a ``with`` body.

    Does nothing when ``constraints`` is empty/``None`` or when the
    ``_ogs_opt_flags`` module global is ``None``. Otherwise the module's
    current ``_opt_flags_constraints`` (an empty dict if the attribute is
    missing) is copied, ``constraints`` is layered on top of that copy, and
    on exit -- including when the body raises -- the constraints are reset
    and the copy is re-applied.

    Args:
        constraints: Mapping of constraint name to value, or ``None``.
    """
    if not constraints or _ogs_opt_flags is None:
        yield
        return

    saved = getattr(_ogs_opt_flags, "_opt_flags_constraints", {}).copy()

    def _reinstate_saved() -> None:
        # Start from a clean slate, then put back whatever was set before.
        _ogs_opt_flags.reset_opt_flags_constraints()
        if saved:
            _ogs_opt_flags.update_opt_flags_constraints(saved)

    try:
        _reinstate_saved()
        _ogs_opt_flags.update_opt_flags_constraints(constraints)
        yield
    finally:
        _reinstate_saved()
|
|
||
|
|
||
def _maybe_dump_dsv4_moe_shape(
    *,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_ids: torch.Tensor,
    topk: int,
    activation: MoEActivation,
    global_num_experts: int,
) -> None:
    """Append one JSON line describing this MoE call, when dumping is enabled.

    Controlled entirely by environment variables:

    * ``DSV4_MOE_SHAPE_DUMP_DIR``: output directory; dumping is off when unset
      or empty.
    * ``DSV4_MOE_SHAPE_DUMP_LIMIT``: when positive, stop once the module-level
      call counter has reached this value.
    * ``DSV4_MOE_SHAPE_DUMP_STRIDE``: only every N-th counted call is
      considered (values below 1 are treated as 1).
    * ``DSV4_MOE_SHAPE_DUMP_MIN_M``: skip calls whose first ``hidden_states``
      dimension is smaller than this.

    The record holds process/rank identifiers, the problem dimensions, the
    weight shapes and a histogram of the non-negative ``topk_ids`` over
    ``w1.shape[0]`` bins. Failures while building or writing the record are
    caught and logged as a warning once per process.

    Args:
        hidden_states: Tensor whose shape unpacks into ``(M, K)``.
        w1: Weight tensor; ``w1.shape[0]`` is the histogram bin count.
        w2: Weight tensor; only its shape is recorded.
        topk_ids: Expert ids; negative entries are excluded from the histogram.
        topk: Recorded as-is.
        activation: Its ``name`` is recorded.
        global_num_experts: Recorded as-is.
    """
    global _MOE_SHAPE_DUMP_COUNT, _MOE_SHAPE_DUMP_WARNED

    def _int_from_env(var: str, fallback: str) -> int:
        # Unset and empty values both collapse to the fallback literal.
        return int(os.environ.get(var, fallback) or fallback)

    out_dir = os.environ.get("DSV4_MOE_SHAPE_DUMP_DIR")
    if not out_dir:
        return

    # Host copies inside graph capture are illegal on ROCm and would also
    # perturb the graph. Shape collection is an eager/profiling-only mode.
    if torch.cuda.is_available() and torch.cuda.is_current_stream_capturing():
        return

    max_calls = _int_from_env("DSV4_MOE_SHAPE_DUMP_LIMIT", "0")
    if max_calls > 0 and _MOE_SHAPE_DUMP_COUNT >= max_calls:
        return

    every_nth = max(1, _int_from_env("DSV4_MOE_SHAPE_DUMP_STRIDE", "1"))
    # The counter advances for every call that reaches this point, whether or
    # not the call ends up being written.
    call_index = _MOE_SHAPE_DUMP_COUNT
    _MOE_SHAPE_DUMP_COUNT = call_index + 1
    if call_index % every_nth:
        return

    m_floor = _int_from_env("DSV4_MOE_SHAPE_DUMP_MIN_M", "0")
    M, K = hidden_states.shape
    if M < m_floor:
        return

    try:
        local_num_experts = int(w1.shape[0])
        routed_ids = topk_ids[topk_ids >= 0].reshape(-1).to(torch.int64)
        counts = torch.bincount(routed_ids, minlength=local_num_experts)
        # Ids at or beyond local_num_experts are dropped by the slice.
        hist = counts[:local_num_experts].cpu()
        busy = hist[hist > 0]
        if busy.numel():
            p90_nonzero = int(torch.quantile(busy.float(), 0.9).round().item())
            hist_max = int(busy.max().item())
        else:
            p90_nonzero = 0
            hist_max = 0

        pid = os.getpid()
        rank = os.environ.get("RANK")
        rec = {
            "pid": pid,
            "rank": rank,
            "local_rank": os.environ.get("LOCAL_RANK"),
            "count": _MOE_SHAPE_DUMP_COUNT,
            "activation": activation.name,
            "M": int(M),
            "K": int(K),
            "topk": int(topk),
            "global_num_experts": int(global_num_experts),
            "local_num_experts": local_num_experts,
            "w1_shape": list(w1.shape),
            "w2_shape": list(w2.shape),
            "hist_sum": int(hist.sum().item()),
            "hist_nonzero": int(busy.numel()),
            "hist_max": hist_max,
            "p90_nonzero": p90_nonzero,
            "hist": list(map(int, hist.tolist())),
        }

        out_path = Path(out_dir)
        out_path.mkdir(parents=True, exist_ok=True)
        # One file per (rank, pid); "x" stands in when RANK is unset or empty.
        filename = f"moe_shapes_rank{rank or 'x'}_pid{pid}.jsonl"
        with (out_path / filename).open("a") as sink:
            sink.write(json.dumps(rec, separators=(",", ":")) + "\n")
    except Exception as exc:
        # Best-effort diagnostics: report the first failure, then stay quiet.
        if not _MOE_SHAPE_DUMP_WARNED:
            _MOE_SHAPE_DUMP_WARNED = True
            logger.warning("Failed to dump DeepSeek V4 MoE shape: %s", exc)
|
|
||
|
|
||
| def _triton_kernel_moe_supports_current_device() -> bool: | ||
| # Shared device gate for the OAI Triton MoE expert classes. | ||
|
|
@@ -245,6 +393,7 @@ def _make_bitmatrix_metadata_pow2_safe(nonzero_indx, bitmatrix): | |
| if has_triton_kernels(): | ||
| try: | ||
| import triton_kernels.swiglu | ||
| import triton_kernels.matmul_ogs_details.opt_flags as _ogs_opt_flags | ||
| from triton_kernels.matmul_ogs import ( | ||
| FnSpecs, | ||
| FusedActivation, | ||
|
|
@@ -884,6 +1033,16 @@ def apply( | |
| if global_num_experts == -1: | ||
| global_num_experts = E | ||
|
|
||
| _maybe_dump_dsv4_moe_shape( | ||
| hidden_states=hidden_states, | ||
| w1=w1, | ||
| w2=w2, | ||
| topk_ids=topk_ids, | ||
| topk=topk, | ||
| activation=activation, | ||
| global_num_experts=global_num_experts, | ||
| ) | ||
|
|
||
| # Note that the output tensor might be in workspace13 | ||
| intermediate_cache1 = _resize_cache(workspace2, (batch_dim, M * topk, N)) | ||
| intermediate_cache3 = _resize_cache(workspace2, (batch_dim, M * topk, K)) | ||
|
|
@@ -892,18 +1051,94 @@ def apply( | |
|
|
||
| gammas = routing_data.gate_scal if routing_data else None | ||
|
|
||
| ogs_constraints = _dsv4_flash_rocm_ogs_constraints( | ||
| m=M, k=K, n=N, e=E, topk=topk, activation=activation | ||
| ) | ||
| with _temporary_ogs_constraints(ogs_constraints): | ||
| matmul_ogs( | ||
| hidden_states, | ||
| w1, | ||
| quant_config.w1_bias, | ||
| routing_data, | ||
| gather_indx=gather_indx, | ||
| precision_config=quant_config.w1_precision, | ||
| gammas=gammas if apply_router_weight_on_input else None, | ||
| fused_activation=None, | ||
| y=intermediate_cache1, | ||
| ) | ||
|
|
||
| sorted_token_ids_lora = None | ||
| expert_ids_lora = None | ||
| num_tokens_post_padded_lora = None | ||
| token_lora_mapping = None | ||
| lora_context = self._lora_context | ||
| if lora_context is None: | ||
| # W1 writes in expert-sorted order. The old no-LoRA path gathered | ||
| # back to token-topk order for activation, then gathered back to | ||
| # expert-sorted order for W2; those two gathers cancel. | ||
| self.activation( | ||
| activation, | ||
| intermediate_cache2, | ||
| intermediate_cache1.view(-1, N), | ||
| ) | ||
| with _temporary_ogs_constraints(ogs_constraints): | ||
| matmul_ogs( | ||
| intermediate_cache2, | ||
| w2, | ||
| quant_config.w2_bias, | ||
| routing_data, | ||
| scatter_indx=scatter_indx, | ||
| precision_config=quant_config.w2_precision, | ||
| gammas=None if apply_router_weight_on_input else gammas, | ||
| y=output, | ||
| ) | ||
| return | ||
|
|
||
| # w13 LoRA: gather the activation input from expert-sorted | ||
| # intermediate_cache1, then add the LoRA delta in-place on that copy | ||
| # before passing it to activation — exactly mirroring the old | ||
| # decorator approach which modified the gathered tensor in-place. | ||
| act_input = intermediate_cache1.view(-1, N)[gather_indx.dst_indx] | ||
| ( | ||
| sorted_token_ids_lora, | ||
| expert_ids_lora, | ||
| num_tokens_post_padded_lora, | ||
| token_lora_mapping, | ||
| ) = self.apply_w13_lora( | ||
| lora_context, | ||
| y=act_input, | ||
| x=hidden_states, | ||
| topk_ids=global_topk_ids, | ||
| topk_weights=topk_weights, | ||
| expert_map=expert_map, | ||
| w1=w1, | ||
| w2=w2, | ||
| num_tokens=M, | ||
| top_k_num=topk, | ||
| ) | ||
|
|
||
| self.activation( | ||
| activation, | ||
| intermediate_cache2, | ||
| act_input, | ||
| ) | ||
|
|
||
| # matmul_ogs grouped reduction fuses sum across multiple experts: | ||
| # y[dst_indx // n_expts_act, :] += x | ||
| # Set n_expts_act to 1 to unfuse the sum so we can do it manually via moe_sum. | ||
| routing_data.n_expts_act = 1 | ||
|
|
||
| with _temporary_ogs_constraints(ogs_constraints): | ||
| matmul_ogs( | ||
| intermediate_cache2[gather_indx.src_indx], | ||
| w2, | ||
| quant_config.w2_bias, | ||
| routing_data, | ||
| scatter_indx=scatter_indx, | ||
| precision_config=quant_config.w2_precision, | ||
| gammas=None if apply_router_weight_on_input else gammas, | ||
| y=intermediate_cache3, | ||
| ) | ||
|
|
||
| # w2 LoRA: after matmul_ogs with scatter_indx, intermediate_cache3 is | ||
| # in token-topk order, matching the (M, topk, K) layout add_lora_w2 expects. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -173,6 +173,7 @@ class DeepseekV32IndexerPrefillChunkMetadata: | |
| cu_seq_lens: torch.Tensor | ||
| token_to_seq: torch.Tensor | ||
| total_seq_lens: int | ||
| max_seq_len: int | ||
| token_start: int | ||
| token_end: int | ||
| num_reqs: int | ||
|
|
@@ -192,6 +193,7 @@ class DeepSeekV32IndexerDecodeMetadata: | |
| # - native MTP path: 2D (B, next_n) where [b,j] = L_b - next_n + j + 1 | ||
| # Both fp8_fp4_paged_mqa_logits and the topk kernels accept both shapes. | ||
| seq_lens: torch.Tensor | ||
| max_seq_len: int | ||
| decode_lens: torch.Tensor | ||
| requires_padding: bool | ||
| schedule_metadata: torch.Tensor | ||
|
|
@@ -553,6 +555,7 @@ def build( | |
|
|
||
| decode_metadata = None | ||
| if num_decodes > 0: | ||
| assert common_attn_metadata.seq_lens_cpu_upper_bound is not None | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Non-blocking: Assertion depends on upstream metadata builder. Why it matters: This assertion (and the similar one at line 513) assumes `common_attn_metadata.seq_lens_cpu_upper_bound` has already been populated by the upstream metadata builder. Suggested fix: No change required. The assertion is good defensive programming. Consider adding a comment referencing where this field is populated (e.g., the upstream metadata builder). |
||
| torch.diff( | ||
| common_attn_metadata.query_start_loc[: num_decodes + 1], | ||
| out=self.decode_lens_buffer[:num_decodes], | ||
|
|
@@ -563,6 +566,7 @@ def build( | |
| ) | ||
|
|
||
| seq_lens = common_attn_metadata.seq_lens[:num_decodes] | ||
| seq_lens_cpu = common_attn_metadata.seq_lens_cpu_upper_bound[:num_decodes] | ||
| block_table = common_attn_metadata.block_table_tensor[:num_decodes, ...] | ||
|
|
||
| max_decode_len = int(decode_lens_cpu.max().item()) | ||
|
|
@@ -587,6 +591,7 @@ def build( | |
| # For DeepseekV4 (compress_ratio > 1), the indexer KV cache stores | ||
| # compressed tokens. Convert uncompressed seq_lens to compressed. | ||
| if self.compress_ratio > 1: | ||
| seq_lens_cpu = seq_lens_cpu // self.compress_ratio | ||
| # True iff seq_lens aliases decode_seq_lens_buffer (flatten or | ||
| # native wrote it); False iff it aliases common_attn_metadata. | ||
| seq_lens_is_local_view = (use_native and next_n > 1) or ( | ||
|
|
@@ -619,6 +624,7 @@ def build( | |
| decode_metadata = DeepSeekV32IndexerDecodeMetadata( | ||
| block_table=block_table, | ||
| seq_lens=seq_lens, | ||
| max_seq_len=int(seq_lens_cpu.max().item()), | ||
| decode_lens=decode_lens, | ||
| requires_padding=requires_padding, | ||
| schedule_metadata=self.scheduler_metadata_buffer, | ||
|
|
@@ -655,6 +661,7 @@ def build_prefill_chunk_metadata( | |
| total_seq_lens = compressed_seq_lens_cpu[start_idx:end_idx].sum().item() | ||
| if total_seq_lens == 0: | ||
| return None | ||
| max_seq_len = int(compressed_seq_lens_cpu[start_idx:end_idx].max().item()) | ||
|
|
||
| num_reqs = end_idx - start_idx | ||
| device = block_table.device | ||
|
|
@@ -710,6 +717,7 @@ def build_prefill_chunk_metadata( | |
| cu_seq_lens=cu_seq_lens, | ||
| token_to_seq=token_to_seq, | ||
| total_seq_lens=total_seq_lens, | ||
| max_seq_len=max_seq_len, | ||
| block_table=block_table[start_idx:end_idx], | ||
| token_start=token_start, | ||
| token_end=token_end, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Non-blocking: Module-level global counter without explicit synchronization.
Why it matters: While Python's GIL protects simple integer increment operations, this counter could exhibit unexpected behavior in multi-process scenarios (each process gets its own copy) or if the module is reloaded. For the intended profiling use case, this is acceptable.
Suggested fix: If cross-process coordination is ever needed, consider using a file-based counter or multiprocessing.Value. For now, a comment documenting the expected single-process-per-GPU pattern would suffice.