[WIP] feat: support multi-B weight tensors (DWDP) in CuTe DSL NVFP4 MoE #3041

Open
yhyang201 wants to merge 2 commits into flashinfer-ai:main from yhyang201:feat/cute-dsl-moe-multi-b-weights

Conversation

@yhyang201

@yhyang201 yhyang201 commented Apr 13, 2026

Extend the Blackwell NVFP4 fused MoE (gather SwiGLU + finalize) kernels and their Python wrappers to accept w1/w2 weight, weight_sf and alpha as either a single tensor or a list of up to 4 tensors split along the expert dimension. The compiled kernel is specialized per multi-B config via b_tensor_l_sizes, with kernel-side branching selecting the right B tensor from the runtime expert index.

Also adds end-to-end tests verifying multi-B results match the single stacked-tensor baseline.

📌 Description

WIP — opening early for visibility / review feedback. Not ready to merge: perf parity vs. TRT-LLM still TBD, and I'm still sweeping the unit tests.

During FlashInfer's port of the TRT-LLM gather+SwiGLU kernel to CuTe DSL Python, the tile_size=256 path (use_2cta_instrs=True, where two CTAs cooperate on a larger MMA operation) produces numerically incorrect results — the kernel runs but gives wrong answers. An NVIDIA engineer discovered this in PR #2775 and disabled it as a workaround, leaving only tile_size=128. Since TRT-LLM's original kernel works correctly with tile_size=256, this is a bug introduced during the porting process. It doesn't affect DWDP functionality, but it halves the autotuner's tactic search space and may cost some performance on large-batch workloads.

Summary

Updates the CuTe DSL NVFP4 MoE kernels to accept weights as a list of tensors split along the expert dimension (up to 4), in addition to the existing single-tensor layout. This lands the DWDP (Distributed Weight Data Parallelism) support that the CUTLASS/TRT-LLM side already has.

  • w1_weight / w1_weight_sf / w1_alpha and the w2_* counterparts now accept Union[Tensor, List[Tensor]]
  • Kernel is specialized per multi-B config via b_tensor_l_sizes; the right B tensor is selected from the runtime expert index on the kernel side
  • Wrapper / functional / tuner paths all updated for the new layout
  • Adds end-to-end tests verifying multi-B results match the single stacked-tensor baseline

Ports the approach from NVIDIA/TensorRT-LLM#12136.
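The offset bookkeeping behind `b_tensor_l_sizes` can be sketched in plain Python (a hedged illustration with hypothetical helper names — the real kernels specialize these branches at compile time via `cutlass.const_expr` rather than looping):

```python
# Sketch of multi-B split/offset bookkeeping (illustrative names, not the
# actual FlashInfer code). Offsets are prefix sums of the per-tensor expert
# counts; a runtime expert index is mapped to (tensor index, local index).
from typing import List, Tuple

MAX_B_TENSORS = 4

def make_offsets(b_tensor_l_sizes: List[int]) -> Tuple[int, ...]:
    """Prefix sums: offsets[i] is the first expert index owned by tensor i."""
    if not 1 <= len(b_tensor_l_sizes) <= MAX_B_TENSORS:
        raise ValueError(f"expected 1..{MAX_B_TENSORS} splits")
    offsets = [0]
    for size in b_tensor_l_sizes:
        offsets.append(offsets[-1] + size)
    return tuple(offsets)

def select_b_tensor(expert_idx: int, offsets: Tuple[int, ...]) -> Tuple[int, int]:
    """Return (tensor index, local expert index) for a runtime expert index."""
    for i in range(len(offsets) - 1):
        if offsets[i] <= expert_idx < offsets[i + 1]:
            return i, expert_idx - offsets[i]
    raise IndexError(f"expert_idx {expert_idx} out of range")

offsets = make_offsets([8, 8, 16])   # 32 experts split across 3 B tensors
print(select_b_tensor(0, offsets))   # → (0, 0)
print(select_b_tensor(20, offsets))  # → (2, 4)
```

In the kernel itself the split count is a compile-time constant, so the loop above becomes a small tree of `const_expr`-guarded comparisons with no runtime branching over the tensor tuple.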

🔍 Related Issues

#3036

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or using your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Support for splitting expert weights/scale/alpha across up to 4 tensors for NVFP4 MoE ops; inputs may be a single tensor or a list with automatic expert-dimension handling and backward compatibility.
  • Documentation
    • Public docstrings updated to describe the single-tensor-or-list convention and expert-splitting behavior.
  • Tests
    • Added tests covering multi-tensor partitions, backward compatibility, and runtime/autotune execution.

Extend the Blackwell NVFP4 fused MoE (gather SwiGLU + finalize) kernels
and their Python wrappers to accept w1/w2 weight, weight_sf and alpha as
either a single tensor or a list of up to 4 tensors split along the
expert dimension. The compiled kernel is specialized per multi-B config
via b_tensor_l_sizes, with kernel-side branching selecting the right B
tensor from the runtime expert index.

Also adds end-to-end tests verifying multi-B results match the single
stacked-tensor baseline.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@coderabbitai
Contributor

coderabbitai bot commented Apr 13, 2026

📝 Walkthrough
🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
  • Title check — ✅ Passed: The title clearly and concisely summarizes the main feature (support for multi-B weight tensors (DWDP) in CuTe DSL NVFP4 MoE), matching the core changeset.
  • Description check — ✅ Passed: The description comprehensively covers the changes, references related issue #3036, and includes the required checklist sections, though some checklist items remain unchecked as this is a WIP PR.
  • Docstring Coverage — ✅ Passed: Docstring coverage is 90.00%, which is sufficient; the required threshold is 80.00%.


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for multiple B weight tensors (Distributed Weight Data Parallelism) in the Blackwell blockscaled MoE kernels. The changes enable the selection of B tensors and alpha values at runtime based on expert indices. Review feedback identified several critical issues: potential out-of-bounds accesses when retrieving alpha values in both gather and finalize kernels, and logic errors in the wrapper functions where the default single-B case (when b_tensor_l_sizes is None) leads to TypeError or incorrect layout dimensions due to offset padding.

Comment on lines +2977 to +3016
                alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]]
                if cutlass.const_expr(self.num_b_tensors == 1):
                    pass  # Already initialized above
                elif cutlass.const_expr(self.num_b_tensors == 2):
                    if expert_idx >= self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                elif cutlass.const_expr(self.num_b_tensors == 3):
                    if (
                        expert_idx >= self.b_tensor_l_offsets[1]
                        and expert_idx < self.b_tensor_l_offsets[2]
                    ):
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                    elif expert_idx >= self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[2][
                            expert_idx - self.b_tensor_l_offsets[2]
                        ]
                else:
                    # 4 B tensors
                    if (
                        expert_idx >= self.b_tensor_l_offsets[1]
                        and expert_idx < self.b_tensor_l_offsets[2]
                    ):
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                    elif (
                        expert_idx >= self.b_tensor_l_offsets[2]
                        and expert_idx < self.b_tensor_l_offsets[3]
                    ):
                        alpha_val = alpha_tuple[2][
                            expert_idx - self.b_tensor_l_offsets[2]
                        ]
                    elif expert_idx >= self.b_tensor_l_offsets[3]:
                        alpha_val = alpha_tuple[3][
                            expert_idx - self.b_tensor_l_offsets[3]
                        ]
Contributor


high

The initial assignment to alpha_val at line 2977 uses alpha_tuple[0] with an index that could be out of bounds if expert_idx belongs to a subsequent tensor (i.e., expert_idx >= self.b_tensor_l_offsets[1]). While alpha_tuple[0] is a cute.Tensor and indexing might just perform pointer arithmetic, it is safer and more correct to guard the access within the num_b_tensors and expert_idx branches to ensure only the valid tensor for the current expert is accessed.

                # Select alpha from correct tensor based on expert_idx
                if cutlass.const_expr(self.num_b_tensors == 1):
                    alpha_val = alpha_tuple[0][expert_idx]
                elif cutlass.const_expr(self.num_b_tensors == 2):
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    else:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                elif cutlass.const_expr(self.num_b_tensors == 3):
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    elif expert_idx < self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                    else:
                        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
                else:
                    # 4 B tensors
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    elif expert_idx < self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                    elif expert_idx < self.b_tensor_l_offsets[3]:
                        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
                    else:
                        alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]]

scale_k = k // scaling_vector_size
interm_size = n // 2
num_tiles = m // tile_size
total_l = self.b_tensor_l_offsets[self.num_b_tensors]
Contributor


high

The calculation of total_l using self.b_tensor_l_offsets[self.num_b_tensors] is problematic when b_tensor_l_sizes is None (the generic single-B case). In that case, self.num_b_tensors is 1 and self.b_tensor_l_offsets[1] is padded with 2**30 (line 540), leading to an incorrect and massive dimension for the c_sf layout. Since l was removed from the wrapper signature, there is no runtime fallback for the expert count. Consider restoring l to the signature or ensuring b_tensor_l_sizes is always a valid tuple in __init__.
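The failure mode is easy to reproduce in a few lines of plain Python (a sketch using the padding value 2**30 quoted in the review; names mirror the wrapper but this is not the real code):

```python
# Sketch of the reported bug: in the generic single-B path the offset table
# is padded with 2**30, so offsets[num_b_tensors] is the pad, not the true
# expert count, and the c_sf layout gets a bogus L dimension.
PAD = 2 ** 30
MAX_B_TENSORS = 4

def build_offsets(b_tensor_l_sizes):
    if b_tensor_l_sizes is None:
        # Generic single-B case: no sizes known, everything past slot 0 is padding.
        return (0,) + (PAD,) * MAX_B_TENSORS, 1
    offsets = [0]
    for size in b_tensor_l_sizes:
        offsets.append(offsets[-1] + size)
    offsets += [PAD] * (MAX_B_TENSORS - len(b_tensor_l_sizes))
    return tuple(offsets), len(b_tensor_l_sizes)

offsets, num_b_tensors = build_offsets(None)
total_l = offsets[num_b_tensors]  # intended to be the expert count...
print(total_l)                    # → 1073741824 (2**30), a massive bogus dim
```

With `l` removed from the wrapper signature there is no runtime value to fall back on, which is why the review suggests either restoring `l` or requiring a valid `b_tensor_l_sizes` tuple in `__init__`.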

Collaborator


+1

Comment on lines +2392 to +2431
                alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]]
                if cutlass.const_expr(self.num_b_tensors == 1):
                    pass  # Already initialized above
                elif cutlass.const_expr(self.num_b_tensors == 2):
                    if expert_idx >= self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                elif cutlass.const_expr(self.num_b_tensors == 3):
                    if (
                        expert_idx >= self.b_tensor_l_offsets[1]
                        and expert_idx < self.b_tensor_l_offsets[2]
                    ):
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                    elif expert_idx >= self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[2][
                            expert_idx - self.b_tensor_l_offsets[2]
                        ]
                else:
                    # 4 B tensors
                    if (
                        expert_idx >= self.b_tensor_l_offsets[1]
                        and expert_idx < self.b_tensor_l_offsets[2]
                    ):
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                    elif (
                        expert_idx >= self.b_tensor_l_offsets[2]
                        and expert_idx < self.b_tensor_l_offsets[3]
                    ):
                        alpha_val = alpha_tuple[2][
                            expert_idx - self.b_tensor_l_offsets[2]
                        ]
                    elif expert_idx >= self.b_tensor_l_offsets[3]:
                        alpha_val = alpha_tuple[3][
                            expert_idx - self.b_tensor_l_offsets[3]
                        ]
Contributor


high

Similar to the gather kernel, the initial assignment to alpha_val at line 2392 performs a potentially out-of-bounds access on alpha_tuple[0] when expert_idx belongs to a later tensor. The access should be moved inside the conditional branches.

                # Select alpha from correct tensor based on expert_idx
                if cutlass.const_expr(self.num_b_tensors == 1):
                    alpha_val = alpha_tuple[0][expert_idx]
                elif cutlass.const_expr(self.num_b_tensors == 2):
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    else:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                elif cutlass.const_expr(self.num_b_tensors == 3):
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    elif expert_idx < self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                    else:
                        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
                else:
                    # 4 B tensors
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    elif expert_idx < self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                    elif expert_idx < self.b_tensor_l_offsets[3]:
                        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
                    else:
                        alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]]

alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,)))

# Create B and alpha tensors using const_expr conditions
l_0 = self.b_tensor_l_sizes[0]
Contributor


high

Accessing self.b_tensor_l_sizes[0] will raise a TypeError if b_tensor_l_sizes is None (the generic single-B case). The wrapper should handle the case where expert sizes are not provided at initialization, likely by restoring the l parameter to the signature.
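A minimal sketch of the guard being asked for (and which the later commit message says was adopted as a ValueError on None); the names mirror the wrapper, but this is illustrative rather than the actual FlashInfer code:

```python
# Hypothetical prologue for wrapper(): fail fast with a clear message instead
# of letting self.b_tensor_l_sizes[0] raise a bare TypeError on None.
def wrapper_prologue(b_tensor_l_sizes):
    if b_tensor_l_sizes is None:
        raise ValueError(
            "b_tensor_l_sizes must be set at __init__ for multi-B fused "
            "kernels; pass a one-element tuple (l,) for the single-B layout"
        )
    l_0 = b_tensor_l_sizes[0]  # now safe to subscript
    return l_0

print(wrapper_prologue((8, 8)))  # → 8
```

The alternative the review mentions — making `b_tensor_l_sizes` a required constructor parameter — moves the same check to `__init__` so every subsequent subscript is unconditionally safe.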

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 2975-3016: The initial unconditional assignment to alpha_val from
alpha_tuple[0] can index out-of-bounds for expert_idx >=
self.b_tensor_l_offsets[1]; update the selection to mirror the B-tensor
selection pattern used elsewhere: remove the unconditional alpha_tuple[0] read
and implement explicit range checks against self.b_tensor_l_offsets for each
branch of cutlass.const_expr(self.num_b_tensors) so alpha_val is only read from
alpha_tuple[i] when expert_idx is within that tensor's [start, end) range; use
the same ordering and guards involving expert_idx, self.b_tensor_l_offsets,
num_b_tensors, alpha_tuple and alpha_val as in the B-tensor selection logic to
ensure safe indexing.

In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 3192-3194: wrapper() currently indexes self.b_tensor_l_sizes[0]
(l_0) without ensuring b_tensor_l_sizes was provided in __init__, causing a
NoneType subscript; add an explicit guard at the start of wrapper() that checks
if self.b_tensor_l_sizes is None and raises a clear ValueError explaining that
b_tensor_l_sizes must be set for multi-B fused kernels (or alternatively make
b_tensor_l_sizes a required constructor parameter in __init__); update
references around l_0 and alpha_0 to rely on this validation so the subsequent
cute.make_tensor(...) call is safe.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 378-381: The code now accepts list inputs for b, b_scale, and
alpha but lacks validation: ensure b_list, b_scale_list, and alpha_list are
non-empty, have identical lengths, and that every corresponding tensor split
agrees on the non-expert dimensions (shapes except the expert-split dimension)
before any indexing or kernel compilation (e.g., before using b_list[0] or
passing these lists into the compiled kernel path). Add explicit checks that
raise a clear error if any list is empty, if len(b_list) != len(b_scale_list) !=
len(alpha_list), or if any pairwise tensor shape mismatch exists on non-expert
dims; apply the same validation logic around the other normalization sites noted
(the blocks around the other occurrences you flagged: the b/b_scale/alpha
normalization at the later sections).

In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 1061-1062: Replace the file-local test gate decorator
`sm100_required` with the repo-standard check
`flashinfer.utils.is_sm100a_supported()` on the new test class so it uses the
canonical GPU support helper; locate the class decorated with
`@cute_dsl_available` and `@sm100_required`, remove `@sm100_required` and apply
the skip/require helper that calls `flashinfer.utils.is_sm100a_supported()` (or
the equivalent test-skip decorator that invokes it) so the test skips correctly
on unsupported SM100a devices instead of relying on the local `sm100_required`
function.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f08ff5cf-440b-421f-b80f-4ac6097cb798

📥 Commits

Reviewing files that changed from the base of the PR and between b75740d and 038bf93.

📒 Files selected for processing (7)
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • flashinfer/fused_moe/cute_dsl/tuner.py
  • tests/moe/test_cute_dsl_fused_moe.py

Comment on lines +378 to +381
# Normalize to lists for multi-B support
b_list = [b] if isinstance(b, torch.Tensor) else b
b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else b_scale
alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else alpha
Contributor


⚠️ Potential issue | 🟠 Major

Validate the new multi-B inputs before indexing and compiling.

This path accepts lists now, but it never checks that they are non-empty, that b/b_scale/alpha have the same number of splits, or that every split agrees on the non-expert dimensions. Right now [] fails at b_list[0], and mismatched split counts/shapes fall through to obscure tuple-index/layout errors in the compiled kernel path.

🧩 Suggested validation
     b_list = [b] if isinstance(b, torch.Tensor) else b
     b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else b_scale
     alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else alpha
+
+    if not b_list:
+        raise ValueError("b must be a tensor or a non-empty list of tensors")
+    if not (len(b_list) == len(b_scale_list) == len(alpha_list)):
+        raise ValueError("b, b_scale, and alpha must use the same number of splits")
+    if len(b_list) > 4:
+        raise ValueError("at most 4 B tensors are supported")
+
+    ref_n = b_list[0].shape[1]
+    ref_packed_k = b_list[0].shape[2]
+    for i, (bi, bsi, ai) in enumerate(zip(b_list, b_scale_list, alpha_list)):
+        if bi.shape[1:] != (ref_n, ref_packed_k):
+            raise ValueError(f"split {i} has inconsistent B shape: {tuple(bi.shape)}")
+        if bsi.shape[-1] != bi.shape[0]:
+            raise ValueError(
+                f"split {i} has inconsistent B-scale expert dim: {bsi.shape[-1]} != {bi.shape[0]}"
+            )
+        if ai.numel() != bi.shape[0]:
+            raise ValueError(
+                f"split {i} has inconsistent alpha length: {ai.numel()} != {bi.shape[0]}"
+            )

Also applies to: 385-390, 456-486

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 378 - 381, The code now accepts list inputs for b, b_scale, and
alpha but lacks validation: ensure b_list, b_scale_list, and alpha_list are
non-empty, have identical lengths, and that every corresponding tensor split
agrees on the non-expert dimensions (shapes except the expert-split dimension)
before any indexing or kernel compilation (e.g., before using b_list[0] or
passing these lists into the compiled kernel path). Add explicit checks that
raise a clear error if any list is empty, if len(b_list) != len(b_scale_list) !=
len(alpha_list), or if any pairwise tensor shape mismatch exists on non-expert
dims; apply the same validation logic around the other normalization sites noted
(the blocks around the other occurrences you flagged: the b/b_scale/alpha
normalization at the later sections).

Comment on lines +1061 to +1062
@cute_dsl_available
@sm100_required
Contributor


⚠️ Potential issue | 🟡 Minor

Use the repo-standard SM100 skip helper for this new coverage.

These tests still rely on the file-local sm100_required gate, which only checks props.major == 10. That can enable the new multi-B cases on unsupported 10.x parts and create spurious failures. Please switch the new class to flashinfer.utils.is_sm100a_supported() instead.

As per coding guidelines: "Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures"
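The gap between the two gates can be shown with a stand-in sketch (the helper names mirror flashinfer.utils, but the exact capability logic here is an assumption inferred from the review comment, not the library's real implementation):

```python
# Stand-in sketch (assumption: is_sm100a_supported accepts only the SM100a
# part, compute capability 10.0, while the file-local gate checks only the
# major version).
def is_sm100a_supported(cc):
    # Canonical helper: accept only SM 10.0 (sm_100a).
    return cc == (10, 0)

def sm100_required_local(cc):
    # File-local gate from the test file: props.major == 10 only.
    return cc[0] == 10

# A hypothetical 10.x part other than 10.0 slips through the local gate, so
# the multi-B tests would run (and potentially fail) instead of skipping.
print(sm100_required_local((10, 3)))  # → True  (test runs)
print(is_sm100a_supported((10, 3)))   # → False (test would skip)
```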

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 1061 - 1062, Replace the
file-local test gate decorator `sm100_required` with the repo-standard check
`flashinfer.utils.is_sm100a_supported()` on the new test class so it uses the
canonical GPU support helper; locate the class decorated with
`@cute_dsl_available` and `@sm100_required`, remove `@sm100_required` and apply
the skip/require helper that calls `flashinfer.utils.is_sm100a_supported()` (or
the equivalent test-skip decorator that invokes it) so the test skips correctly
on unsupported SM100a devices instead of relying on the local `sm100_required`
function.

- Safe alpha indexing with pre-initialization before const_expr branches
- NoneType guard: raise ValueError when b_tensor_l_sizes=None
- Input validation for multi-B weight lists (empty, max 4, length match)
- Fix test imports to use top-level flashinfer module

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
tests/moe/test_cute_dsl_fused_moe.py (1)

1420-1421: Inconsistent import pattern for autotune.

Line 1420 uses from flashinfer.autotuner import autotune, but the rest of the file (e.g., line 408) uses from flashinfer import autotune. For consistency within this file, prefer the top-level import.

♻️ Proposed fix for import consistency
-        from flashinfer.autotuner import autotune
+        from flashinfer import autotune
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 1420 - 1421, The file
mixes import styles for autotune; change the line that reads "from
flashinfer.autotuner import autotune" to the top-level form "from flashinfer
import autotune" so imports are consistent with other uses in this test (e.g.,
the earlier import at line ~408); leave the import of cute_dsl_fused_moe_nvfp4
unchanged.
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)

463-465: Use ValueError instead of assert for input validation consistency.

Lines 458-462 correctly use ValueError for the None check, but lines 463-465 use assert for the length check. Assertions can be disabled with Python's -O flag, making this validation bypassable in optimized builds.

♻️ Proposed fix for consistent error handling
-        assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, (
-            f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}"
-        )
+        if len(b_tensor_l_sizes) > self.MAX_B_TENSORS:
+            raise ValueError(
+                f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 463 - 465, Replace the runtime assertion with a ValueError to
ensure input validation cannot be disabled: where the code currently does
"assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ..." raise a ValueError
with a descriptive message (referencing b_tensor_l_sizes and self.MAX_B_TENSORS)
so the length check always executes (e.g., in the same scope as the existing
None check that uses ValueError).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 4063-4067: The SFC tensor c_sf is being created with its L
dimension set to total_l which mismatches the backing allocation; change the
layout passed to cute.make_ordered_layout when creating c_sf (in the c_sf =
cute.make_tensor(...) call using c_sf_ptr) so the last dimension is 1 instead of
total_l — i.e. use (32, 4, m // 128, 4, interm_size // (scaling_vector_size *
4), 1) with the same order=(2, 1, 4, 0, 3, 5) so the tensor’s L dimension
remains 1 while keeping the same ordering.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 418-439: Loop over each split in b_list (using enumerate) and
validate per-split properties instead of only checking the first element: assert
each bi (from b_list) is non-empty, on CUDA, has the same second-dimension as
b_list[0] (bi.shape[1] == n), and that bi.size(0) matches the length of the
corresponding alpha_list[i]; also assert corresponding b_scale_list[i] is
present and on CUDA. Update validations around b_list, b_scale_list, alpha_list,
a, n and k (references: b_list, b_scale_list, alpha_list, a, n, k, num_experts)
so the kernel never receives an empty split, a device-mismatched tensor, or
mismatched per-split dimensions.

---

Nitpick comments:
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 463-465: Replace the runtime assertion with a ValueError to ensure
input validation cannot be disabled: where the code currently does "assert
len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ..." raise a ValueError with a
descriptive message (referencing b_tensor_l_sizes and self.MAX_B_TENSORS) so the
length check always executes (e.g., in the same scope as the existing None check
that uses ValueError).

In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 1420-1421: The file mixes import styles for autotune; change the
line that reads "from flashinfer.autotuner import autotune" to the top-level
form "from flashinfer import autotune" so imports are consistent with other uses
in this test (e.g., the earlier import at line ~408); leave the import of
cute_dsl_fused_moe_nvfp4 unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 99f12819-35d9-4237-9b2e-4093fe716fc2

📥 Commits

Reviewing files that changed from the base of the PR and between 038bf93 and 9edbe9b.

📒 Files selected for processing (5)
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • tests/moe/test_cute_dsl_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Comment on lines 4063 to 4067
        c_sf = cute.make_tensor(
            c_sf_ptr,
            layout=cute.make_ordered_layout(
-               (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), l),
+               (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), total_l),
                order=(2, 1, 4, 0, 3, 5),
Contributor


⚠️ Potential issue | 🟠 Major

Keep the SFC tensor's L dimension at 1.

out_scale is still allocated by the public wrapper as (..., 1), but this wrapper now reinterprets the same buffer as (..., total_l). On multi-B FP4 paths that gives CuTe a different layout than the backing allocation, and the output tensor still has flattened M x N x 1 semantics here.

Proposed fix
         c_sf = cute.make_tensor(
             c_sf_ptr,
             layout=cute.make_ordered_layout(
-                (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), total_l),
+                (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), 1),
                 order=(2, 1, 4, 0, 3, 5),
             ),
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`
around lines 4063 - 4067, The SFC tensor c_sf is being created with its L
dimension set to total_l which mismatches the backing allocation; change the
layout passed to cute.make_ordered_layout when creating c_sf (in the c_sf =
cute.make_tensor(...) call using c_sf_ptr) so the last dimension is 1 instead of
total_l — i.e. use (32, 4, m // 128, 4, interm_size // (scaling_vector_size *
4), 1) with the same order=(2, 1, 4, 0, 3, 5) so the tensor’s L dimension
remains 1 while keeping the same ordering.

Comment on lines +418 to 439
     # Normalize to lists for multi-B support
     b_list = [b] if isinstance(b, torch.Tensor) else list(b)
     b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else list(b_scale)
     alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else list(alpha)

     # Validate multi-B inputs
     assert len(b_list) > 0, "Weight tensor list must not be empty"
     assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}"
     assert len(b_list) == len(b_scale_list) == len(alpha_list), (
         f"b, b_scale, alpha lists must have same length: "
         f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}"
     )

     # Validate inputs
     assert a.device.type == "cuda", "Input tensors must be on CUDA device"
-    assert b.device.type == "cuda", "Input tensors must be on CUDA device"
+    assert b_list[0].device.type == "cuda", "Input tensors must be on CUDA device"

     # Get dimensions
     seq_len = a.shape[0]
-    num_experts = b.shape[0]
-    n = b.shape[1]  # This is 2*intermediate_size
+    num_experts = sum(bi.size(0) for bi in b_list)
+    n = b_list[0].shape[1]  # This is 2*intermediate_size
     k = a.shape[1]
⚠️ Potential issue | 🟠 Major

Validate each split tensor, not just the list length.

The new checks still allow invalid multi-B configurations: an empty first split, a later split on a different device, a later split with different (N, K), or an alpha[i] whose length does not match b[i].size(0). Those all flow into the kernel, which assumes split 0 exists and reuses the first split's shape for every other split.

Proposed fix
-    assert len(b_list) > 0, "Weight tensor list must not be empty"
-    assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}"
-    assert len(b_list) == len(b_scale_list) == len(alpha_list), (
-        f"b, b_scale, alpha lists must have same length: "
-        f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}"
-    )
+    if len(b_list) == 0:
+        raise ValueError("Weight tensor list must not be empty")
+    if len(b_list) > 4:
+        raise ValueError(f"Maximum 4 weight tensors supported, got {len(b_list)}")
+    if len(b_list) != len(b_scale_list) or len(b_list) != len(alpha_list):
+        raise ValueError(
+            "b, b_scale, and alpha must contain the same number of splits"
+        )
+
+    ref_nk = b_list[0].shape[1:]
+    for i, (bi, bsi, ai) in enumerate(zip(b_list, b_scale_list, alpha_list)):
+        if bi.size(0) == 0:
+            raise ValueError(f"b[{i}] must contain at least one expert")
+        if bi.device != a.device or bsi.device != a.device or ai.device != a.device:
+            raise ValueError(
+                f"All split tensors must be on {a.device}; "
+                f"got b[{i}]={bi.device}, b_scale[{i}]={bsi.device}, alpha[{i}]={ai.device}"
+            )
+        if bi.shape[1:] != ref_nk:
+            raise ValueError(
+                f"All B splits must share the same (N, K); "
+                f"expected {ref_nk}, got {bi.shape[1:]} for b[{i}]"
+            )
+        if ai.numel() != bi.size(0):
+            raise ValueError(
+                f"alpha[{i}] must have {bi.size(0)} entries, got {ai.numel()}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`
around lines 418-439, loop over each split in b_list (using enumerate) and
validate per-split properties instead of only checking the first element:
assert that each bi is non-empty, is on CUDA, has the same trailing (N, K)
shape as b_list[0] (bi.shape[1] == n), and that bi.size(0) matches the number
of entries in the corresponding alpha_list[i]; also assert that the
corresponding b_scale_list[i] is present and on CUDA. Update the validations
around b_list, b_scale_list, alpha_list, a, n, k, and num_experts so the
kernel never receives an empty split, a device-mismatched tensor, or
mismatched per-split dimensions.
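The proposed per-split checks can be exercised without the kernel or a GPU. Below is a minimal stand-alone sketch: `FakeTensor` and `validate_splits` are hypothetical names introduced here only to mimic the torch calls used in the diff, and the shapes are illustrative:

```python
class FakeTensor:
    """Hypothetical stand-in mimicking the torch.Tensor calls the checks use."""
    def __init__(self, shape, device="cuda:0"):
        self.shape = tuple(shape)
        self.device = device
    def size(self, dim):
        return self.shape[dim]
    def numel(self):
        n = 1
        for s in self.shape:
            n *= s
        return n

def validate_splits(a, b_list, b_scale_list, alpha_list):
    """Validate every split, not just the list lengths; returns num_experts."""
    if not b_list:
        raise ValueError("Weight tensor list must not be empty")
    if len(b_list) > 4:
        raise ValueError(f"Maximum 4 weight tensors supported, got {len(b_list)}")
    if len(b_list) != len(b_scale_list) or len(b_list) != len(alpha_list):
        raise ValueError("b, b_scale, and alpha must contain the same number of splits")
    ref_nk = b_list[0].shape[1:]
    for i, (bi, bsi, ai) in enumerate(zip(b_list, b_scale_list, alpha_list)):
        if bi.size(0) == 0:
            raise ValueError(f"b[{i}] must contain at least one expert")
        if bi.device != a.device or bsi.device != a.device or ai.device != a.device:
            raise ValueError(f"All split tensors must be on {a.device}")
        if bi.shape[1:] != ref_nk:
            raise ValueError(f"b[{i}] has (N, K) {bi.shape[1:]}, expected {ref_nk}")
        if ai.numel() != bi.size(0):
            raise ValueError(f"alpha[{i}] must have {bi.size(0)} entries")
    return sum(bi.size(0) for bi in b_list)

# Two 64-expert splits -> 128 experts total.
a = FakeTensor((128, 7168))
b = [FakeTensor((64, 4096, 7168)), FakeTensor((64, 4096, 7168))]
sf = [FakeTensor((64, 4096, 448)), FakeTensor((64, 4096, 448))]
alpha = [FakeTensor((64,)), FakeTensor((64,))]
assert validate_splits(a, b, sf, alpha) == 128
```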

@nv-yunzheq
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !550 has been created, and the CI pipeline #48542959 is currently running. I'll report back once the pipeline job completes.

Collaborator

@nv-yunzheq nv-yunzheq left a comment
Thanks for contributing to the project!
I think we want to know whether there is a performance regression in the non-data-parallel case.
Moreover, the autotuner design might need more consideration. Right now it keys on the number of experts, as before, which might yield sub-optimal performance in the data-parallel case. Do we also need to keep the degree of parallelism as a tuning parameter?
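One way to read this suggestion is to make the split count part of the tuning key, so that a data-parallel config with the same total expert count is tuned separately from the single stacked-tensor layout. This is a hypothetical cache-key sketch, not the actual autotuner API:

```python
def tuning_cache_key(num_experts, tile_size, num_b_tensors):
    # Including num_b_tensors distinguishes e.g. 256 experts in one
    # stacked tensor from 4 x 64-expert splits, which may prefer
    # different tactics under data parallelism even though the total
    # expert count is identical.
    return (num_experts, tile_size, num_b_tensors)

# Same expert count, different parallelism -> distinct tuning entries.
assert tuning_cache_key(256, 128, 1) != tuning_cache_key(256, 128, 4)
```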

# as a tuple (even for single-B, e.g. (256,)).
if b_tensor_l_sizes is None:
raise ValueError(
"b_tensor_l_sizes is required. Pass a tuple with the number of "
Collaborator
The description of b_tensor_l_sizes is inconsistent: the error message here says it is required, while the function signature marks it Optional.

topk: cutlass.Int64,
raster_along_m: bool = False,
enable_pdl: bool = True,
b_tensor_l_sizes: Optional[Tuple[int, ...]] = None,
Collaborator

I'd prefer to put enable_pdl last, since TensorRT-LLM doesn't have that parameter; keeping the shared parameters in the same order makes it easier for developers porting optimizations.

scale_k = k // scaling_vector_size
interm_size = n // 2
num_tiles = m // tile_size
total_l = self.b_tensor_l_offsets[self.num_b_tensors]
Collaborator

+1
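For context on the snippet above: total_l is read from the last entry of an exclusive prefix sum over b_tensor_l_sizes, and the PR description says the kernel branches on the runtime expert index to select the right B tensor. A pure-Python sketch of that bookkeeping (function and variable names here are illustrative, not the kernel's actual identifiers):

```python
def make_offsets(b_tensor_l_sizes):
    # Exclusive prefix sum; offsets[-1] is total_l, matching
    # self.b_tensor_l_offsets[self.num_b_tensors] in the snippet.
    offsets = [0]
    for s in b_tensor_l_sizes:
        offsets.append(offsets[-1] + s)
    return offsets

def locate_expert(expert_idx, offsets):
    # Linear scan mirrors kernel-side branching over up to 4 B tensors:
    # returns (which split, local expert index within that split).
    for t in range(len(offsets) - 1):
        if expert_idx < offsets[t + 1]:
            return t, expert_idx - offsets[t]
    raise IndexError(f"expert {expert_idx} out of range {offsets[-1]}")

offsets = make_offsets((64, 64, 64, 64))
assert offsets[-1] == 256                  # total_l
assert locate_expert(0, offsets) == (0, 0)
assert locate_expert(130, offsets) == (2, 2)
```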
