Skip to content

Conversation

@yyihuang
Copy link
Collaborator

@yyihuang yyihuang commented Nov 13, 2025

📌 Description

Refactor fused_moe test.

Split test on model+precision.

Part [1]:

  • test deepseek (kimi, lite) fp8 block-scaled fused moe
  • default TP8
  • PDL enabled
  • MajorK weight layout
  • higher tolerance and matching percentage

Next Part [2]:

  • add BlockMajorK weight layout

Next Part [x]:

  • Per Tensor FP8 MoE, FP4MoE

later:

  • refactor llama4, topk?, renormalize? routing tests

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Tests
    • Added a comprehensive FP8 block-scale fused Mixture-of-Experts test validating end-to-end correctness across many routing, expert and precision configurations. Includes randomized inputs, per-token/per-expert workflows, extensive parameterizations, diagnostic statistics, autotune-path checks, and a minimal sanity run.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Nov 13, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds a new FP8 block-scale fused MoE test module implementing FP8 quantization helpers, a reference fused-MoE run path (dequant, per-expert block scaling, no-aux routing, local-expert compute), randomized input generation, and a parameterized test comparing a fused kernel to the reference.

Changes

Cohort / File(s) Summary
FP8 Fused MoE Test Module
tests/moe/test_dpsk_fused_moe_fp8.py
New test file adding: run() reference (FP8 dequantization, per-expert block scaling, DeepSeek-V3 no-aux routing: sigmoid+bias, group top-k, keep groups, global top-k), local-expert compute path (GEMM1 → SwiGLU → GEMM2) with per-token accumulation; FP8 helpers (_fp8_block_quant_1d, _fp8_block_quant_2d), utilities (next_power_of_2, get_tile_tokens_dim), input generator generate_random_inputs_moe(), parameterized test test_correctness_dpsk_fp8_fused_moe() with diagnostics and kernel vs. reference comparisons.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test Harness
    participant Gen as Input Generator
    participant Ref as Reference Impl
    participant Kernel as Fused Kernel
    participant Comp as Comparator

    Test->>Gen: request randomized FP8 inputs, weights, scales
    Gen-->>Test: routing_logits, routing_bias, hidden_states, scales, quantized_weights

    rect rgb(230,245,255)
      Note over Ref: Reference flow (dequant → scale → routing → local-expert compute)
      Test->>Ref: inputs & weights
      Ref->>Ref: Dequantize FP8 activations & weights
      Ref->>Ref: Apply per-expert block scaling
      Ref->>Ref: Compute routing scores: sigmoid(logits + bias)
      Ref->>Ref: Group top-k selection → keep groups → global top-k among kept
      Ref->>Ref: For routed tokens: GEMM1 → SwiGLU → GEMM2, accumulate outputs per token
      Ref-->>Test: reference_output
    end

    rect rgb(245,230,250)
      Note over Kernel: Fused kernel execution (optional autotune)
      Test->>Kernel: same inputs & weights
      Kernel->>Kernel: Execute fused FP8 MoE kernel
      Kernel-->>Test: kernel_output
    end

    Test->>Comp: compare outputs (abs/rel, cosine, MSE, hit ratio)
    Comp-->>Test: diagnostics & assertion result
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Pay attention to FP8 quantization helpers and scale management.
  • Verify routing implementation: sigmoid+bias, group top-k and global top-k across kept groups.
  • Check per-token accumulation, local_expert_offset and tiled token handling.
  • Review statistical comparisons, autotune path and deterministic input generation.

Possibly related PRs

Suggested reviewers

  • yzh119
  • jiahanc
  • djmmoss
  • wenscarl
  • cyx-6

Poem

🐰 I hop through FP8 fields of light,
routing tokens through day and night,
experts chosen, outputs blend,
reference checks the kernel's end,
tiny hops, big tests — delight! 🎉

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 60.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title check ✅ Passed The title accurately describes the main change: refactoring the dpsk fused_moe test with part [1] designation, which matches the addition of the new FP8 block-scaled test module.
Description check ✅ Passed The description provides adequate context about the refactoring effort, clearly identifies the PR as part [1] of a series, outlines the specific FP8 block-scaled fused MoE test additions, and confirms pre-commit checks and tests are complete.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@yyihuang yyihuang changed the title refactor: update fused_moe test [1] refactor: update dpsk fused_moe test [1] Nov 14, 2025
@yyihuang yyihuang marked this pull request as ready for review November 14, 2025 23:02
@yyihuang yyihuang requested review from Copilot and yzh119 November 14, 2025 23:02
Copilot finished reviewing on behalf of yyihuang November 14, 2025 23:04
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR refactors and adds new tests for DeepSeek (Kimi, Lite) FP8 block-scaled fused MoE operations, splitting tests by model and precision configurations.

Key changes:

  • Adds comprehensive FP8 block-scale MoE testing for DeepSeek variants (V3, Lite, Kimi-K2)
  • Implements reference implementation with FP8 block quantization and DeepSeek-V3 no-aux routing
  • Includes parametrized tests with multiple sequence lengths, expert configurations, and intermediate sizes

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 247 to 268
scales = torch.empty(
(*prefix, nb_r, nb_c), dtype=torch.float32, device=w_bf16.device
)

it = np.ndindex(*prefix) if prefix else [()]
for idx in it:
sel = idx if isinstance(idx, tuple) else (idx,)
for i in range(nb_r):
rs = slice(i * block, (i + 1) * block)
for j in range(nb_c):
cs = slice(j * block, (j + 1) * block)
blk = w_f32[(*sel, rs, cs)] # [128, 128]
amax = torch.amax(torch.abs(blk))
s = (
(amax / max_fp8)
if amax > 0
else torch.tensor(1.0, device=w_bf16.device)
)
q = (blk / s).to(torch.float8_e4m3fn)
w_fp8[(*sel, rs, cs)] = q
scales[(*sel, i, j)] = s
return w_fp8, scales
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Performance concern: The nested loops for block-wise quantization iterate over each block individually (lines 217-226 and 255-268). For large tensors, this could be slow. Consider vectorizing these operations using torch.chunk() or reshape/view operations to process blocks in parallel, which would be more efficient.

Suggested change
scales = torch.empty(
(*prefix, nb_r, nb_c), dtype=torch.float32, device=w_bf16.device
)
it = np.ndindex(*prefix) if prefix else [()]
for idx in it:
sel = idx if isinstance(idx, tuple) else (idx,)
for i in range(nb_r):
rs = slice(i * block, (i + 1) * block)
for j in range(nb_c):
cs = slice(j * block, (j + 1) * block)
blk = w_f32[(*sel, rs, cs)] # [128, 128]
amax = torch.amax(torch.abs(blk))
s = (
(amax / max_fp8)
if amax > 0
else torch.tensor(1.0, device=w_bf16.device)
)
q = (blk / s).to(torch.float8_e4m3fn)
w_fp8[(*sel, rs, cs)] = q
scales[(*sel, i, j)] = s
return w_fp8, scales
# Reshape to [..., nb_r, block, nb_c, block]
new_shape = (*prefix, nb_r, block, nb_c, block)
w_blocks = w_f32.view(new_shape)
# Compute amax for each block
amax = torch.amax(torch.abs(w_blocks), dim=(-1, -3), keepdim=False) # shape: [..., nb_r, nb_c]
# Compute scale for each block
scales = torch.where(
amax > 0,
amax / max_fp8,
torch.ones_like(amax, device=w_bf16.device)
) # shape: [..., nb_r, nb_c]
# Expand scales for broadcasting
scales_expanded = scales.unsqueeze(-1).unsqueeze(-3) # shape: [..., nb_r, 1, nb_c, 1]
# Quantize blocks
w_fp8_blocks = (w_blocks / scales_expanded).to(torch.float8_e4m3fn)
# Reshape back to original shape
w_fp8 = w_fp8_blocks.view(*prefix, R, C)

Copilot uses AI. Check for mistakes.
@@ -0,0 +1,574 @@
import pytest
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing copyright header. All other test files in this repository include the Apache 2.0 copyright header at the top of the file. Please add the standard copyright header to maintain consistency with the codebase.

Copilot uses AI. Check for mistakes.
Comment on lines +435 to +439
N_GROUP = routing_config["n_groups"] # deepseek v3: 8
TOPK_GROUP = routing_config["top_k_groups"] # deepseek v3: 4

if local_expert_offset + E_LOCAL > E_GLOBAL:
pytest.skip(
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function returns True or False but the test framework expects assertions or exceptions. The early returns with boolean values (lines 435, 439) are not appropriate for pytest tests. Instead, use pytest.skip() for conditions that should skip the test.

Suggested change
N_GROUP = routing_config["n_groups"] # deepseek v3: 8
TOPK_GROUP = routing_config["top_k_groups"] # deepseek v3: 4
if local_expert_offset + E_LOCAL > E_GLOBAL:
pytest.skip(
pytest.skip("CUDA not available, skipping test.")
if trtllm_fp8_block_scale_moe is None:
print("WARNING: flashinfer fused_moe kernel not available.")
pytest.skip("flashinfer fused_moe kernel not available.")

Copilot uses AI. Check for mistakes.
N_GROUP,
TOPK_GROUP,
I,
inputs["local_expert_offset"],
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commented-out code should be removed. The line calculating tile_tokens_dim is commented out but not used. If this is intentional for future use, add a comment explaining why it's kept; otherwise, remove it to keep the codebase clean.

Suggested change
inputs["local_expert_offset"],

Copilot uses AI. Check for mistakes.
# Generate random but consistent inputs
print("Generating random inputs")
inputs = generate_random_inputs_moe(
seq_len,
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment contains a typo: "todo(yingyi)" should follow proper TODO format. The standard format is # TODO(yingyi): with uppercase TODO, parentheses, and a colon.

Suggested change
seq_len,
E_LOCAL = 32 # TODO(yingyi): default to tp8 for now, update later

Copilot uses AI. Check for mistakes.
Comment on lines +365 to +407
),
pytest.param(
{
"num_experts": 256,
"top_k": 8,
"padding": 8,
"n_groups": 8,
"top_k_groups": 4,
"routed_scaling": 2.5,
"compatible_intermediate_size": [512, 1024, 2048],
"enable_autotune": True,
},
id="DSv3",
),
pytest.param(
{
"num_experts": 72,
"top_k": 6,
"padding": 8,
"n_groups": 1,
"top_k_groups": 1,
"routed_scaling": 2.5,
"compatible_intermediate_size": [384, 768],
"enable_autotune": False,
},
id="DSLite",
),
],
)
@pytest.mark.parametrize("enable_pdl", [True, False])
def test_correctness_dpsk_fp8_fused_moe(
seq_len,
local_expert_offset,
use_bias,
intermediate_size,
routing_config,
enable_pdl,
atol: float = 1e-1,
rtol: float = 2e-1,
percent: float = 0.85,
):
compatible_intermediate_size = routing_config["compatible_intermediate_size"]
if intermediate_size not in compatible_intermediate_size:
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The routing config dictionary includes a 'padding' field (lines 372, 385, 398) that is never used in the test function. Either use this parameter in the test logic or remove it from the configuration to avoid confusion.

Copilot uses AI. Check for mistakes.
gemm1_weights_scale=inputs["gemm1_weights_scale"],
gemm2_weights=inputs["gemm2_weights"],
gemm2_weights_scale=inputs["gemm2_weights_scale"],
local_expert_offset=inputs["local_expert_offset"],
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The routed_scaling_factor is hardcoded to 2.5 here, but it should be using routing_config["routed_scaling"] to be consistent with the test parametrization. This inconsistency could lead to testing with the wrong parameter values.

Suggested change
local_expert_offset=inputs["local_expert_offset"],
routed_scaling_factor=routing_config["routed_scaling"],

Copilot uses AI. Check for mistakes.
Comment on lines 28 to 41
• FP8 block-scale dequantization: float ≈ fp8 * scale
• DeepSeek-V3 no-aux routing:
s = sigmoid(logits)
s_with_bias = s + bias
group by n_group=8; per group take top-2 sum → pick topk_group=4 groups
on the kept groups, take global top_k=8 experts
combine with weights derived from s (without bias), normalized and
scaled by routed_scaling_factor
• Local computation:
only experts in [local_expert_offset, local_expert_offset + E_local) are
computed on this rank (GEMM1 → SwiGLU → GEMM2), then per-token weighted
accumulation.
"""

Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docstring incorrectly uses bullet points with special characters () which are not standard in Python docstrings. Use standard ASCII characters like - or * for bullet points to ensure proper rendering in documentation tools and editors.

Copilot uses AI. Check for mistakes.
Comment on lines 410 to 420
)

print("\n" + "=" * 70)
print(
f"Testing MoE FP8 Block-Scale: seq_len={seq_len}, offset={local_expert_offset}, use_bias={use_bias}"
)
print("=" * 70)

if not torch.cuda.is_available():
print("WARNING: CUDA not available, skipping test.")
return True
Copy link

Copilot AI Nov 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.

Copilot uses AI. Check for mistakes.
@yyihuang yyihuang requested a review from jiahanc November 14, 2025 23:06
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (4)
tests/moe/test_dpsk_fused_moe_fp8.py (4)

229-269: Make FP8 2D quantization branch more idiomatic and avoid per-block tensor creation

The scalar path in _fp8_block_quant_2d currently relies on if amax > 0 with a tensor and allocates a new torch.tensor(1.0, ...) for every zero block. That works but is a bit non-idiomatic and slightly wasteful.

You can simplify and make the intent clearer:

-                amax = torch.amax(torch.abs(blk))
-                s = (
-                    (amax / max_fp8)
-                    if amax > 0
-                    else torch.tensor(1.0, device=w_bf16.device)
-                )
+                amax = torch.amax(torch.abs(blk))
+                if amax.item() > 0:
+                    s = amax / max_fp8
+                else:
+                    s = torch.tensor(1.0, device=w_bf16.device)

(or keep s as a plain Python float 1.0 and rely on broadcasting). This avoids depending on bool(tensor) semantics and cuts a small allocation.


289-342: Reuse routing_config["routed_scaling"] instead of hardcoding 2.5

generate_random_inputs_moe already takes routed_scaling_factor, but the test always passes a literal 2.5 even though routing_config carries a routed_scaling field. This duplication can easily drift if configs change.

You can wire the config through to the generator:

-    inputs = generate_random_inputs_moe(
+    inputs = generate_random_inputs_moe(
         seq_len,
         num_experts_global=E_GLOBAL,
         num_local_experts=E_LOCAL,
         hidden_size=H,
         intermediate_size=I,
         use_bias=use_bias,
         local_expert_offset=local_expert_offset,
-        routed_scaling_factor=2.5,
+        routed_scaling_factor=routing_config["routed_scaling"],
         device=device,
     )

This keeps the reference path and fused kernel in sync with a single source of truth for the routed scaling factor.

Also applies to: 458-470


444-452: Be aware of high GPU memory footprint for DS-V3-sized weights

With E_LOCAL = 32, H = 7168, and intermediate_size up to 2048, the weight tensors (w13_*, w2_*) plus their FP32 copies and FP8 variants are on the order of multiple GB per test invocation. That’s realistic for DS‑V3 but can stress smaller CI GPUs when combined with the parameter sweep.

If you hit OOMs in CI, consider:

  • Adding a smaller “smoke” configuration (reduced H, E_LOCAL, or intermediate_size) behind a flag/marker, or
  • Restricting some parameter combinations for CI while keeping the full geometry for dedicated perf/correctness runs.

No change is strictly required if your CI hardware is sized for this.

Also applies to: 318-328


45-45: Optional: rename I and O to satisfy Ruff E741 and improve readability

Ruff is flagging the single-letter variables I (intermediate size) and O (expert output) as ambiguous. While fine functionally, renaming them slightly improves readability and resolves E741:

-    I = intermediate_size  # deepseek v3: 2048
+    intermediate_dim = intermediate_size
@@
-        O = C.matmul(W2_e.t())  # [Tk, H]
+        expert_output = C.matmul(W2_e.t())  # [Tk, H]
@@
-    T, H, I = seq_len, hidden_size, intermediate_size
+    T, H, intermediate_dim = seq_len, hidden_size, intermediate_size
@@
-    I = intermediate_size  # deepseek v3: 2048
+    intermediate_dim = intermediate_size  # deepseek v3: 2048

You’d then update uses of I/O accordingly. This is cosmetic and can be deferred if you’re not enforcing E741.

Also applies to: 186-186, 302-302, 448-448

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 54101e9 and 7f24824.

📒 Files selected for processing (1)
  • tests/moe/test_dpsk_fused_moe_fp8.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/moe/test_dpsk_fused_moe_fp8.py (5)
flashinfer/fused_moe/core.py (1)
  • WeightLayout (162-169)
flashinfer/autotuner.py (1)
  • autotune (251-262)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
  • routing_logits (147-155)
  • routing_bias (158-164)
include/flashinfer/trtllm/fused_moe/runner.h (6)
  • local_expert_offset (276-276)
  • hidden_size (265-265)
  • intermediate_size (275-275)
  • n_group (271-271)
  • topk_group (273-273)
  • num_experts (263-263)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
🪛 Ruff (0.14.4)
tests/moe/test_dpsk_fused_moe_fp8.py

45-45: Ambiguous variable name: I

(E741)


186-186: Ambiguous variable name: O

(E741)


302-302: Ambiguous variable name: I

(E741)


448-448: Ambiguous variable name: I

(E741)

🔇 Additional comments (1)
tests/moe/test_dpsk_fused_moe_fp8.py (1)

522-570: Comparison logic and hit-ratio assertion look solid

The combination of full-tensor stats (abs/rel diff, cosine similarity, MSE), a strict allclose check for diagnostics, and a relaxed per-element hit ratio for assertion is a good fit for FP8 numerical noise. The tolerance and percent threshold are explicit and easy to tune if kernels evolve.

No changes needed here.

Comment on lines +66 to +80
assert hidden_states.shape == (T, H)
assert hidden_states_scale.shape == (num_hidden_blocks, T)
assert gemm1_weights.shape == (E_local, 2 * I, H)
assert gemm1_weights_scale.shape == (
E_local,
num_gemm1_out_blocks,
num_hidden_blocks,
)
assert gemm2_weights.shape == (E_local, H, I)
assert gemm2_weights_scale.shape == (
E_local,
num_hidden_blocks,
num_intermediate_blocks,
)
assert routing_bias.shape[-1] == E_global
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

Tighten routing/bias shape checks and guard group_size divisibility

The routing logic assumes E_global is divisible by N_GROUP and that routing_bias is 1D, but those assumptions are not enforced. A misconfigured routing_config would currently fail with a less-informative .view error, and the bias assertion is looser than the kernel’s 1D contract.

Consider making these checks explicit:

@@
-    E_global = routing_logits.shape[1]
-    T = routing_logits.shape[0]
+    E_global = routing_logits.shape[1]
+    T = routing_logits.shape[0]
@@
-    assert routing_bias.shape[-1] == E_global
+    assert routing_bias.ndim == 1 and routing_bias.shape[0] == E_global
@@
-    group_size = E_global // N_GROUP  # 32
-    s_wb_grouped = s_with_bias.view(T, N_GROUP, group_size)  # [T, 8, 32]
+    assert (
+        E_global % N_GROUP == 0
+    ), "E_global must be divisible by N_GROUP for group routing"
+    group_size = E_global // N_GROUP  # 32
+    s_wb_grouped = s_with_bias.view(T, N_GROUP, group_size)  # [T, N_GROUP, group_size]

This keeps the reference implementation closer to the fused kernel’s expectations and fails fast on bad configs.

Also applies to: 120-137


🏁 Script executed:

#!/bin/bash
# Examine the test file to verify review comment concerns
echo "=== Lines 66-80 (shown assertions) ===" 
sed -n '66,80p' tests/moe/test_dpsk_fused_moe_fp8.py

echo -e "\n=== Lines 110-150 (routing logic area) ===" 
sed -n '110,150p' tests/moe/test_dpsk_fused_moe_fp8.py

echo -e "\n=== Lines 115-145 (find run function, routing checks) ===" 
sed -n '115,145p' tests/moe/test_dpsk_fused_moe_fp8.py

Length of output: 3339


🏁 Script executed:

#!/bin/bash
# Search for routing_bias, E_global, N_GROUP in the file
echo "=== Routing bias assertions ===" 
rg -n "routing_bias" tests/moe/test_dpsk_fused_moe_fp8.py -A2 -B2

echo -e "\n=== E_global and N_GROUP usage ===" 
rg -n "E_global|N_GROUP" tests/moe/test_dpsk_fused_moe_fp8.py -B1 -A1

echo -e "\n=== Group size calculation ===" 
rg -n "group_size" tests/moe/test_dpsk_fused_moe_fp8.py -B2 -A2

Length of output: 3943


🏁 Script executed:

#!/bin/bash
# Look at function containing the routing logic to understand context
echo "=== Search for 'run' function definition ===" 
rg -n "def run" tests/moe/test_dpsk_fused_moe_fp8.py -A50 | head -80

Length of output: 1896


Add explicit guards for routing_bias 1D constraint and E_global divisibility

The routing_bias shape check (line 80) only validates shape[-1], allowing 2D or higher tensors through; the reshape at line 113 silently flattens them. Additionally, the group_size calculation at line 120 has no guard that E_global % N_GROUP == 0, so misconfigured routing would fail with an uninformative view/reshape error.

Add explicit checks to fail fast with clear messages:

-    assert routing_bias.shape[-1] == E_global
+    assert routing_bias.ndim == 1 and routing_bias.shape[0] == E_global, \
+        f"routing_bias must be 1D with shape [E_global={E_global}], got shape {routing_bias.shape}"
+    assert E_global % N_GROUP == 0, \
+        f"E_global ({E_global}) must be divisible by N_GROUP ({N_GROUP}) for group routing"
     group_size = E_global // N_GROUP  # 32

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tests/moe/test_dpsk_fused_moe_fp8.py around lines 66 to 80 (with related
operations at ~113 and ~120), the current assert only checks
routing_bias.shape[-1] allowing multi-dimensional tensors and later reshaping
hides that, and there is no guard that E_global is divisible by N_GROUP before
computing group_size; add an explicit check that routing_bias.ndim == 1 and
routing_bias.shape[0] == E_global (raise a clear AssertionError mentioning
expected 1D length E_global), and add a precondition asserting E_global %
N_GROUP == 0 (raise a clear AssertionError indicating E_global must be divisible
by N_GROUP) before computing group_size so failures are immediate and
informative.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (8)
tests/moe/test_dpsk_fused_moe_fp8.py (8)

27-40: Normalize docstring bullets to plain ASCII for tooling compatibility

The docstring uses bullet characters; many doc/tools expect simple - or * ASCII bullets. Consider rewriting the bullets with - to match common Python docstring style.


228-268: Consider vectorizing 2D FP8 block quantization to avoid triple nested loops

_fp8_block_quant_2d currently iterates over prefix × nb_r × nb_c with Python loops, which can be slow for large expert/hidden dims.

If performance of test-side quantization becomes an issue, you could reshape to an explicit block grid and compute scales in a single vectorized pass, e.g.:

-    it = np.ndindex(*prefix) if prefix else [()]
-    for idx in it:
-        sel = idx if isinstance(idx, tuple) else (idx,)
-        for i in range(nb_r):
-            rs = slice(i * block, (i + 1) * block)
-            for j in range(nb_c):
-                cs = slice(j * block, (j + 1) * block)
-                blk = w_f32[(*sel, rs, cs)]
-                ...
+    # Example sketch: reshape to [..., nb_r, block, nb_c, block] and
+    # compute amax/scales with torch.amax over the block dims, then
+    # broadcast scales back and quantize in one or a few tensor ops.

Not strictly required for correctness, but worth considering if these tests are run at large scale.


353-392: Unused padding field in routing_config may confuse future readers

Each routing_config dict defines a "padding" entry, but the value is never used in the test logic or in the call to trtllm_fp8_block_scale_moe.

If padding is not needed here, consider removing that key from the test configs (and from the __main__ example at the bottom) or add a short comment explaining that it’s reserved for future use.


430-437: Clarify TODO style and constant E_LOCAL comment

Line 431 uses a non-standard TODO style and hardcodes the number of local experts:

E_LOCAL = 32  # todo(yingyi): default to tp8 for now, update later

For consistency with common conventions and past comments, consider:

-    E_LOCAL = 32  # todo(yingyi): default to tp8 for now, update later
+    E_LOCAL = 32  # TODO(yingyi): default to tp8 for now, update later

If there is a future plan to derive E_LOCAL from topology (e.g., TP size), you might also expand the comment slightly to mention that.


443-455: Wire routed_scaling_factor from routing_config instead of hardcoding 2.5

The test currently hardcodes routed_scaling_factor=2.5 when generating inputs:

inputs = generate_random_inputs_moe(
    ...,
    routed_scaling_factor=2.5,
    device=device,
)

Even though all current routing_config entries use 2.5, this means the "routed_scaling" field in the config is effectively ignored.

To keep the test consistent with the configuration and future-proof against config changes, pass the value through:

-        routed_scaling_factor=2.5,
+        routed_scaling_factor=routing_config["routed_scaling"],

Because both the reference run() and the fused kernel use inputs["routed_scaling_factor"], this single change propagates the config value everywhere it’s needed.


1-6: Add standard Apache 2.0 header to match other test files

This test file is missing the usual Apache 2.0 copyright/license header used elsewhere in the repo. Please add the standard header for consistency and compliance.


48-80: Tighten routing preconditions: 1D bias and E_global divisibility by N_GROUP

The core routing invariants are currently only partially enforced:

  • Line 79: routing_bias.shape[-1] == E_global allows higher-rank tensors; the fused kernel expects a 1D bias of length E_global.
  • Lines 118–120: group_size = E_global // N_GROUP assumes E_global % N_GROUP == 0 but does not check it, so misconfigurations will fail later with opaque view/shape errors.

Strengthening these asserts will make the reference more robust and better aligned with the kernel’s shape checks:

-    assert routing_bias.shape[-1] == E_global
+    assert (
+        routing_bias.ndim == 1 and routing_bias.shape[0] == E_global
+    ), f"routing_bias must be 1D with length {E_global}, got shape {routing_bias.shape}"

@@
-    group_size = E_global // N_GROUP  # 32
+    assert (
+        E_global % N_GROUP == 0
+    ), f"E_global ({E_global}) must be divisible by N_GROUP ({N_GROUP}) for group routing"
+    group_size = E_global // N_GROUP  # 32

This fails fast with clear messages if routing_config is mis-specified.

Also applies to: 118-121


418-425: Use pytest.skip instead of returning booleans for unavailable CUDA/kernel

The early exits currently return True/False:

if not torch.cuda.is_available():
    print("WARNING: CUDA not available, skipping test.")
    return True

if trtllm_fp8_block_scale_moe is None:
    print("WARNING: flashinfer fused_moe kernel not available.")
    return False

Pytest ignores test return values, so these cases are reported as passing tests rather than explicitly skipped, which can hide misconfigured environments.

Switch to pytest.skip() to mark them as skipped:

-    if not torch.cuda.is_available():
-        print("WARNING: CUDA not available, skipping test.")
-        return True
-
-    if trtllm_fp8_block_scale_moe is None:
-        print("WARNING: flashinfer fused_moe kernel not available.")
-        return False
+    if not torch.cuda.is_available():
+        pytest.skip("CUDA not available, skipping fused MoE FP8 test.")
+
+    if trtllm_fp8_block_scale_moe is None:
+        pytest.skip("flashinfer fused_moe FP8 kernel not available.")

This also removes the mixed return/no-return paths from the test function.

In pytest, what happens to non-None return values from test functions, and what is the recommended pattern for conditionally skipping tests (e.g., when CUDA or an optional kernel is unavailable)?
🧹 Nitpick comments (1)
tests/moe/test_dpsk_fused_moe_fp8.py (1)

557-574: __main__ entry is handy; just be aware of interaction with pytest.skip

The if __name__ == "__main__": block is useful for ad-hoc runs, but once the early CUDA/kernel checks use pytest.skip(), calling test_correctness_dpsk_fp8_fused_moe directly without pytest will raise a skip exception if those conditions fail.

That’s acceptable for internal tooling, but it’s worth keeping in mind if you expect this script to be run outside pytest.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7f24824 and 9043383.

📒 Files selected for processing (1)
  • tests/moe/test_dpsk_fused_moe_fp8.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/moe/test_dpsk_fused_moe_fp8.py (5)
flashinfer/fused_moe/core.py (1)
  • WeightLayout (162-169)
flashinfer/autotuner.py (1)
  • autotune (251-262)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
  • routing_logits (147-155)
  • routing_bias (158-164)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • local_expert_offset (276-276)
  • hidden_size (265-265)
  • intermediate_size (275-275)
  • n_group (271-271)
  • topk_group (273-273)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
🪛 Ruff (0.14.4)
tests/moe/test_dpsk_fused_moe_fp8.py

44-44: Ambiguous variable name: I

(E741)


185-185: Ambiguous variable name: O

(E741)


287-287: Ambiguous variable name: I

(E741)


433-433: Ambiguous variable name: I

(E741)

🔇 Additional comments (4)
tests/moe/test_dpsk_fused_moe_fp8.py (4)

197-225: FP8 1D block quantization logic looks correct

The _fp8_block_quant_1d helper cleanly enforces [T, H] input, block divisibility, and uses per-(token, block) amax to derive dequant scales consistent with run()’s dequantization path. No changes needed here.


274-327: Random input generator is coherent with run(), no issues

generate_random_inputs_moe produces routing logits/bias, FP8-quantized activations and weights, and corresponding scales with shapes that match the assertions and dequantization logic in run(). The hidden/intermediate size multiples-of-128 guard is also appropriate. No changes needed here.


395-405: Parametrized test interface and tolerances look reasonable

The test parametrization (sequence lengths, local offsets, bias flag, intermediate sizes, routing configs, enable_pdl) and relaxed (atol, rtol, percent) thresholds look appropriate for FP8 block-scale behavior and should provide good coverage across DeepSeek variants.


457-554: Comparison metrics and hit-ratio check are well-structured

The diagnostics (abs/rel diff statistics, cosine similarity, MSE, top-5 largest errors) and the final hit-ratio assertion relative to (atol, rtol, percent) form a solid, informative correctness check for FP8 behavior. This section looks good as-is.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (5)
tests/moe/test_dpsk_fused_moe_fp8.py (5)

1-1: Add the Apache 2.0 copyright header.

All other test files in this repository include the standard Apache 2.0 copyright header. Please add it at the top of the file to maintain consistency with the codebase.


79-79: Tighten routing_bias shape constraint to enforce 1D.

The current assertion only checks shape[-1], allowing multi-dimensional tensors (e.g., [1, 256]) to pass. The kernel contract (as shown in csrc/trtllm_fused_moe_kernel_launcher.cu:158-163) requires routing_bias to be strictly 1D. The reshape at line 112 silently flattens multi-dim tensors, hiding the issue.

Based on relevant code snippets

Apply this diff to enforce the 1D constraint:

-    assert routing_bias.shape[-1] == E_global
+    assert routing_bias.ndim == 1 and routing_bias.shape[0] == E_global, \
+        f"routing_bias must be 1D with shape [E_global={E_global}], got shape {routing_bias.shape}"

120-120: Add divisibility guard for E_global % N_GROUP.

The group_size calculation assumes E_global is divisible by N_GROUP, but there's no guard. Misconfigured routing would fail at the .view() operation with an uninformative reshape error.

Add this check before the calculation:

+    assert E_global % N_GROUP == 0, \
+        f"E_global ({E_global}) must be divisible by N_GROUP ({N_GROUP}) for group routing"
     group_size = E_global // N_GROUP  # 32

350-393: Remove or document the unused 'padding' field in routing configs.

Each routing configuration dictionary includes a 'padding' field (lines 357, 370, 383) that is never used in the test function. Either remove it to avoid confusion, or add a comment explaining it's reserved for future use.


423-423: Use standard TODO format.

The comment uses lowercase "todo" which doesn't follow the standard convention. Use uppercase "TODO" with a colon for better tool support and consistency.

Apply this diff:

-    E_LOCAL = 32  # todo(yingyi): default to tp8 for now, update later
+    E_LOCAL = 32  # TODO(yingyi): default to tp8 for now, update later
🧹 Nitpick comments (3)
tests/moe/test_dpsk_fused_moe_fp8.py (3)

216-224: Consider vectorizing block-wise quantization for better performance.

The nested loop processes blocks individually, which can be slow for large tensors. Since this is test code, performance impact is limited, but vectorization would improve both speed and readability.

Apply this diff to vectorize the operation:

-    for j in range(nb):
-        sl = slice(j * block, (j + 1) * block)
-        blk = x_f32[:, sl]  # [T, 128]
-        amax = torch.amax(torch.abs(blk), dim=1)  # [T]
-        # dequant scale s = amax / max_fp8  (float ≈ fp8 * s)
-        s = torch.where(amax > 0, amax / max_fp8, torch.ones_like(amax))
-        q = (blk / s.unsqueeze(1)).to(torch.float8_e4m3fn)  # quantization
-        x_fp8[:, sl] = q
-        scales[:, j] = s
+    # Reshape to [T, nb, block]
+    x_blocks = x_f32.view(T, nb, block)
+    # Compute amax for each block: [T, nb]
+    amax = torch.amax(torch.abs(x_blocks), dim=2)
+    # Compute scales: [T, nb]
+    scales = torch.where(amax > 0, amax / max_fp8, torch.ones_like(amax))
+    # Quantize blocks: [T, nb, block]
+    x_fp8_blocks = (x_blocks / scales.unsqueeze(2)).to(torch.float8_e4m3fn)
+    # Reshape back: [T, H]
+    x_fp8 = x_fp8_blocks.view(T, H)

252-267: Consider vectorizing 2D block-wise quantization for better performance.

The triple-nested loop processes each block individually. Vectorization would improve performance, though the implementation is more complex due to arbitrary prefix dimensions and 2D blocking.

Apply this diff to vectorize:

-    it = np.ndindex(*prefix) if prefix else [()]
-    for idx in it:
-        sel = idx if isinstance(idx, tuple) else (idx,)
-        for i in range(nb_r):
-            rs = slice(i * block, (i + 1) * block)
-            for j in range(nb_c):
-                cs = slice(j * block, (j + 1) * block)
-                blk = w_f32[(*sel, rs, cs)]  # [128, 128]
-                amax = torch.amax(torch.abs(blk))
-                s = (
-                    (amax / max_fp8)
-                    if amax > 0
-                    else torch.tensor(1.0, device=w_bf16.device)
-                )
-                q = (blk / s).to(torch.float8_e4m3fn)
-                w_fp8[(*sel, rs, cs)] = q
-                scales[(*sel, i, j)] = s
+    # Reshape to [..., nb_r, block, nb_c, block]
+    new_shape = (*prefix, nb_r, block, nb_c, block)
+    w_blocks = w_f32.view(new_shape)
+    # Compute amax for each block: [..., nb_r, nb_c]
+    amax = torch.amax(torch.abs(w_blocks), dim=(-1, -3))
+    # Compute scales: [..., nb_r, nb_c]
+    scales = torch.where(
+        amax > 0,
+        amax / max_fp8,
+        torch.ones_like(amax)
+    )
+    # Expand scales for broadcasting: [..., nb_r, 1, nb_c, 1]
+    scales_expanded = scales.unsqueeze(-1).unsqueeze(-3)
+    # Quantize blocks
+    w_fp8_blocks = (w_blocks / scales_expanded).to(torch.float8_e4m3fn)
+    # Reshape back: [..., R, C]
+    w_fp8 = w_fp8_blocks.view(*prefix, R, C)

44-44: Consider using more descriptive variable names.

Static analysis flags I as ambiguous (can be confused with 1 or l). While common in ML contexts for intermediate_size, consider using intermediate_size or inter_size directly for better readability.

Similarly, at line 185, O could be renamed to expert_out or output_e for clarity.

Also applies to: 287-287, 425-425

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9043383 and 2471a8a.

📒 Files selected for processing (1)
  • tests/moe/test_dpsk_fused_moe_fp8.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/moe/test_dpsk_fused_moe_fp8.py (4)
flashinfer/fused_moe/core.py (1)
  • WeightLayout (162-169)
flashinfer/autotuner.py (1)
  • autotune (251-262)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
  • routing_logits (147-155)
  • routing_bias (158-164)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • local_expert_offset (276-276)
  • hidden_size (265-265)
  • intermediate_size (275-275)
  • n_group (271-271)
  • topk_group (273-273)
🪛 Ruff (0.14.4)
tests/moe/test_dpsk_fused_moe_fp8.py

44-44: Ambiguous variable name: I

(E741)


185-185: Ambiguous variable name: O

(E741)


287-287: Ambiguous variable name: I

(E741)


425-425: Ambiguous variable name: I

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (9)
tests/moe/test_dpsk_fused_moe_fp8.py (9)

8-40: LGTM!

The function signature and docstring clearly document the FP8 block-scale dequantization workflow and DeepSeek-V3 no-aux routing logic.


83-108: LGTM!

The FP8 block-scale dequantization logic correctly handles scale expansion for both 1D activations and 2D weights using block-wise scaling.


110-150: LGTM!

The DeepSeek-V3 no-aux routing logic correctly implements:

  • Sigmoid activation with bias for group selection
  • Group-wise top-2 scoring and topk_group filtering
  • Global top-k expert selection within kept groups
  • Weight normalization using sigmoid scores (without bias) and routed_scaling_factor

152-191: LGTM!

The local expert computation correctly:

  • Filters tokens per expert using routing top-k results
  • Executes the MoE forward pass: GEMM1 → SwiGLU → GEMM2
  • Accumulates outputs with per-token routing weights

274-328: LGTM!

The input generator correctly creates random routing inputs, FP8-quantized activations and weights with appropriate block-wise scales for testing.


412-416: LGTM!

The test correctly uses pytest.skip() to handle missing CUDA or kernel support, ensuring CI properly reports these conditions instead of silently passing.


435-493: LGTM!

The test execution correctly:

  • Generates consistent random inputs
  • Runs the reference implementation
  • Invokes the FlashInfer fused kernel with autotune when enabled
  • Passes matching parameters to both implementations for fair comparison

495-543: LGTM!

The comparison logic is comprehensive, providing:

  • Multiple error metrics (absolute, relative, cosine similarity, MSE)
  • Detailed diagnostic output including top-5 error locations
  • Flexible tolerance checking via hit ratio (85% threshold) appropriate for FP8 quantization

546-563: LGTM!

The main guard enables quick manual testing with a representative configuration (DSv3, seq_len=1, enable_pdl=True).

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (7)
tests/moe/test_dpsk_fused_moe_fp8.py (7)

1-1: Add copyright header for consistency.

As noted in previous reviews, this file is missing the Apache 2.0 copyright header present in other test files in the repository. Please add it to maintain codebase consistency.


41-43: Ambiguous variable name flagged by linter.

Static analysis flags I as an ambiguous variable name. While I for intermediate_size is conventional in MoE contexts, consider using a more descriptive name like inter_size if it improves readability for your team.


78-78: Consider tightening routing_bias shape validation.

The current check routing_bias.shape[-1] == E_global allows multi-dimensional tensors, though line 111 reshapes to 1D. As suggested in previous reviews, explicitly enforcing 1D would fail fast with a clearer error message.

Based on learnings


118-119: Add divisibility guard for group_size calculation.

As noted in previous reviews, there's no explicit check that E_global % N_GROUP == 0 before computing group_size. Adding an assertion would provide a clearer error message if routing is misconfigured.

Based on learnings


196-224: Block quantization logic is correct; consider vectorization for performance.

The loop-based block quantization is straightforward and correct. As noted in previous reviews, vectorizing using reshape/view operations could improve performance for large tensors, but for test code this optimization is optional.

Based on learnings


430-430: Use standard TODO format.

The comment todo(yingyi) should follow the standard format: # TODO(yingyi): (uppercase, with colon).


364-364: Remove unused 'padding' field from routing configurations.

As noted in previous reviews, the 'padding' field appears in all three routing configs but is never used in the test logic. Consider removing it to avoid confusion, or document why it's included.

Based on learnings

Also applies to: 377-377, 390-390

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2471a8a and 541b80b.

📒 Files selected for processing (1)
  • tests/moe/test_dpsk_fused_moe_fp8.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/moe/test_dpsk_fused_moe_fp8.py (5)
flashinfer/fused_moe/core.py (1)
  • WeightLayout (162-169)
flashinfer/autotuner.py (1)
  • autotune (251-262)
csrc/trtllm_fused_moe_kernel_launcher.cu (2)
  • routing_logits (147-155)
  • routing_bias (158-164)
include/flashinfer/trtllm/fused_moe/runner.h (5)
  • local_expert_offset (276-276)
  • hidden_size (265-265)
  • intermediate_size (275-275)
  • n_group (271-271)
  • topk_group (273-273)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
🪛 Ruff (0.14.4)
tests/moe/test_dpsk_fused_moe_fp8.py

43-43: Ambiguous variable name: I

(E741)


184-184: Ambiguous variable name: O

(E741)


250-255: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


294-294: Ambiguous variable name: I

(E741)


432-432: Ambiguous variable name: I

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (9)
tests/moe/test_dpsk_fused_moe_fp8.py (9)

26-39: Consider using standard ASCII bullet points in docstring.

The docstring uses - for bullet points, which is correct. However, ensure this renders properly in all documentation tools.


82-108: FP8 dequantization logic is well-structured.

The block-scale expansion and dequantization for activations and weights follow a consistent pattern and correctly implement the FP8 block-scale semantics described in the docstring.


109-149: Routing logic correctly implements DeepSeek-V3 no-aux routing.

The implementation properly handles:

  • Sigmoid with bias for group/expert selection
  • Group-based top-k routing (top-2 per group → top-4 groups → top-8 global)
  • Weight computation using sigmoid without bias, as documented

151-190: Local expert computation correctly implements GEMM1→SwiGLU→GEMM2.

The per-expert loop, token selection, and weighted accumulation are correctly implemented for a reference path. The SwiGLU activation silu(x2) * x1 matches the standard formulation.

Note: Static analysis flags O as an ambiguous variable name (line 184), but it's clear in context as "output".


227-275: 2D block quantization uses efficient vectorized operations.

This implementation correctly uses vectorized operations to process all blocks in parallel, avoiding the nested loops present in _fp8_block_quant_1d. The reshape/permute logic correctly handles arbitrary prefix dimensions.


281-334: Input generator is well-structured with keyword-only arguments.

The function cleanly separates random generation and FP8 quantization, and the use of keyword-only arguments (after *) prevents accidental positional argument misuse.


443-474: Test execution correctly uses routing_config parameters.

The input generation (line 451) correctly uses routing_config["routed_scaling"], and both the reference (line 466) and fused kernel (line 494) calls consistently use inputs["routed_scaling_factor"]. The previous concern about hardcoded values appears to have been addressed.


502-550: Comprehensive comparison with multiple diagnostic metrics.

The test provides excellent diagnostic output including absolute/relative differences, cosine similarity, MSE, top error locations, and hit ratio. The flexible hit ratio check (default 85%) is appropriate for FP8 numerical precision.


553-570: Main guard enables convenient manual testing.

The main guard allows direct execution of the test file with a representative configuration, which is helpful for development and debugging.

@yzh119
Copy link
Collaborator

yzh119 commented Nov 16, 2025

/bot run

Copy link
Collaborator

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM overall, cc @IwakuraRein for viz.

@flashinfer-bot
Copy link
Collaborator

GitLab MR !140 has been created, and the CI pipeline #38575770 is currently running. I'll report back once the pipeline job completes.

Copy link
Collaborator

@jiahanc jiahanc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution! Can you remove the duplicate test in test_trtllm_gen_fused_moe.py ? (test_deepseekv3_routing)

@yyihuang
Copy link
Collaborator Author

Thanks for the contribution! Can you remove the duplicate test in test_trtllm_gen_fused_moe.py ? (test_deepseekv3_routing)

We could remove the tests in test_trtllm_gen_fused_moe.py once we have all dpskv3 tests refactored.

@yyihuang yyihuang merged commit 4ddf71d into flashinfer-ai:main Nov 16, 2025
4 checks passed
qsang-nv pushed a commit to qsang-nv/flashinfer that referenced this pull request Nov 18, 2025
<!-- .github/pull_request_template.md -->

## 📌 Description

Refactor fused_moe test.

Split test on model+precision.

Part [1]:
- test deepseek (kimi, lite) fp8 block-scaled fused moe
- default TP8
- PDL enabled
- MajorK weight layout
- higher tolerance and matching percentage

Next Part [2]:
- add BlockMajorK weight layout

Next Part [x]:
- Per Tensor FP8 MoE,  FP4MoE

later:
- refactor llama4, topk?, renormalize? routing tests

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull
request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit`
(or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files`
and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the
pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

<!-- Optional: anything you'd like reviewers to focus on, concerns, etc.
-->


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Tests**
* Added a comprehensive FP8 block-scale fused Mixture-of-Experts test
validating end-to-end correctness across many routing, expert and
precision configurations. Includes randomized inputs,
per-token/per-expert workflows, extensive parameterizations, diagnostic
statistics, autotune-path checks, and a minimal sanity run.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants