[WIP] feat: support multi-B weight tensors (DWDP) in CuTe DSL NVFP4 MoE #3041
yhyang201 wants to merge 2 commits into flashinfer-ai:main
Conversation
Extend the Blackwell NVFP4 fused MoE (gather SwiGLU + finalize) kernels and their Python wrappers to accept w1/w2 weight, weight_sf, and alpha as either a single tensor or a list of up to 4 tensors split along the expert dimension. The compiled kernel is specialized per multi-B config via b_tensor_l_sizes, with kernel-side branching selecting the right B tensor from the runtime expert index. Also adds end-to-end tests verifying multi-B results match the single stacked-tensor baseline.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
🚥 Pre-merge checks: ✅ 3 passed
Code Review
This pull request introduces support for multiple B weight tensors (Distributed Weight Data Parallelism) in the Blackwell blockscaled MoE kernels. The changes enable the selection of B tensors and alpha values at runtime based on expert indices. Review feedback identified several critical issues: potential out-of-bounds accesses when retrieving alpha values in both gather and finalize kernels, and logic errors in the wrapper functions where the default single-B case (when b_tensor_l_sizes is None) leads to TypeError or incorrect layout dimensions due to offset padding.
```python
alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]]
if cutlass.const_expr(self.num_b_tensors == 1):
    pass  # Already initialized above
elif cutlass.const_expr(self.num_b_tensors == 2):
    if expert_idx >= self.b_tensor_l_offsets[1]:
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
elif cutlass.const_expr(self.num_b_tensors == 3):
    if (
        expert_idx >= self.b_tensor_l_offsets[1]
        and expert_idx < self.b_tensor_l_offsets[2]
    ):
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
    elif expert_idx >= self.b_tensor_l_offsets[2]:
        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
else:
    # 4 B tensors
    if (
        expert_idx >= self.b_tensor_l_offsets[1]
        and expert_idx < self.b_tensor_l_offsets[2]
    ):
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
    elif (
        expert_idx >= self.b_tensor_l_offsets[2]
        and expert_idx < self.b_tensor_l_offsets[3]
    ):
        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
    elif expert_idx >= self.b_tensor_l_offsets[3]:
        alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]]
```
The initial assignment to alpha_val at line 2977 uses alpha_tuple[0] with an index that could be out of bounds if expert_idx belongs to a subsequent tensor (i.e., expert_idx >= self.b_tensor_l_offsets[1]). While alpha_tuple[0] is a cute.Tensor and indexing might just perform pointer arithmetic, it is safer and more correct to guard the access within the num_b_tensors and expert_idx branches to ensure only the valid tensor for the current expert is accessed.
```python
# Select alpha from correct tensor based on expert_idx
if cutlass.const_expr(self.num_b_tensors == 1):
    alpha_val = alpha_tuple[0][expert_idx]
elif cutlass.const_expr(self.num_b_tensors == 2):
    if expert_idx < self.b_tensor_l_offsets[1]:
        alpha_val = alpha_tuple[0][expert_idx]
    else:
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
elif cutlass.const_expr(self.num_b_tensors == 3):
    if expert_idx < self.b_tensor_l_offsets[1]:
        alpha_val = alpha_tuple[0][expert_idx]
    elif expert_idx < self.b_tensor_l_offsets[2]:
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
    else:
        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
else:
    # 4 B tensors
    if expert_idx < self.b_tensor_l_offsets[1]:
        alpha_val = alpha_tuple[0][expert_idx]
    elif expert_idx < self.b_tensor_l_offsets[2]:
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
    elif expert_idx < self.b_tensor_l_offsets[3]:
        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
    else:
        alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]]
```

```python
scale_k = k // scaling_vector_size
interm_size = n // 2
num_tiles = m // tile_size
total_l = self.b_tensor_l_offsets[self.num_b_tensors]
```
The calculation of total_l using self.b_tensor_l_offsets[self.num_b_tensors] is problematic when b_tensor_l_sizes is None (the generic single-B case). In that case, self.num_b_tensors is 1 and self.b_tensor_l_offsets[1] is padded with 2**30 (line 540), leading to an incorrect and massive dimension for the c_sf layout. Since l was removed from the wrapper signature, there is no runtime fallback for the expert count. Consider restoring l to the signature or ensuring b_tensor_l_sizes is always a valid tuple in __init__.
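To make the offset bookkeeping concrete, here is a framework-free sketch (names mirror the PR; the `2**30` padding value and the selection loop are my reading of the described design, not the actual kernel code): offsets are the exclusive prefix sums of `b_tensor_l_sizes`, so `offsets[num_b_tensors]` recovers the total expert count and `offsets[i] <= expert_idx < offsets[i + 1]` identifies the split that owns a given expert.

```python
import itertools

MAX_B_TENSORS = 4
PAD = 2**30  # sentinel for unused offset slots, as described above


def make_offsets(b_tensor_l_sizes):
    # Exclusive prefix sums prefixed with 0: (0, s0, s0 + s1, ...).
    offsets = list(itertools.accumulate(b_tensor_l_sizes, initial=0))
    # Pad unused slots so the tuple length is fixed at compile time.
    offsets += [PAD] * (MAX_B_TENSORS + 1 - len(offsets))
    return tuple(offsets)


def split_for_expert(offsets, num_b_tensors, expert_idx):
    # Mirrors the kernel-side branching: find i such that
    # offsets[i] <= expert_idx < offsets[i + 1], and return the
    # split index plus the local expert index within that split.
    for i in range(num_b_tensors):
        if expert_idx < offsets[i + 1]:
            return i, expert_idx - offsets[i]
    raise IndexError(f"expert {expert_idx} out of range")
```

With `b_tensor_l_sizes=(64, 64, 128)`, expert 70 resolves to split 1, local index 6, which is exactly the `alpha_tuple[1][expert_idx - offsets[1]]` access in the suggested fix.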
```python
alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]]
if cutlass.const_expr(self.num_b_tensors == 1):
    pass  # Already initialized above
elif cutlass.const_expr(self.num_b_tensors == 2):
    if expert_idx >= self.b_tensor_l_offsets[1]:
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
elif cutlass.const_expr(self.num_b_tensors == 3):
    if (
        expert_idx >= self.b_tensor_l_offsets[1]
        and expert_idx < self.b_tensor_l_offsets[2]
    ):
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
    elif expert_idx >= self.b_tensor_l_offsets[2]:
        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
else:
    # 4 B tensors
    if (
        expert_idx >= self.b_tensor_l_offsets[1]
        and expert_idx < self.b_tensor_l_offsets[2]
    ):
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
    elif (
        expert_idx >= self.b_tensor_l_offsets[2]
        and expert_idx < self.b_tensor_l_offsets[3]
    ):
        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
    elif expert_idx >= self.b_tensor_l_offsets[3]:
        alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]]
```
Similar to the gather kernel, the initial assignment to alpha_val at line 2392 performs a potentially out-of-bounds access on alpha_tuple[0] when expert_idx belongs to a later tensor. The access should be moved inside the conditional branches.
```python
# Select alpha from correct tensor based on expert_idx
if cutlass.const_expr(self.num_b_tensors == 1):
    alpha_val = alpha_tuple[0][expert_idx]
elif cutlass.const_expr(self.num_b_tensors == 2):
    if expert_idx < self.b_tensor_l_offsets[1]:
        alpha_val = alpha_tuple[0][expert_idx]
    else:
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
elif cutlass.const_expr(self.num_b_tensors == 3):
    if expert_idx < self.b_tensor_l_offsets[1]:
        alpha_val = alpha_tuple[0][expert_idx]
    elif expert_idx < self.b_tensor_l_offsets[2]:
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
    else:
        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
else:
    # 4 B tensors
    if expert_idx < self.b_tensor_l_offsets[1]:
        alpha_val = alpha_tuple[0][expert_idx]
    elif expert_idx < self.b_tensor_l_offsets[2]:
        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
    elif expert_idx < self.b_tensor_l_offsets[3]:
        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
    else:
        alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]]
```

```python
alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,)))
# ...
# Create B and alpha tensors using const_expr conditions
l_0 = self.b_tensor_l_sizes[0]
```
Actionable comments posted: 4
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 2975-3016: The initial unconditional assignment to alpha_val from
alpha_tuple[0] can index out-of-bounds for expert_idx >=
self.b_tensor_l_offsets[1]; update the selection to mirror the B-tensor
selection pattern used elsewhere: remove the unconditional alpha_tuple[0] read
and implement explicit range checks against self.b_tensor_l_offsets for each
branch of cutlass.const_expr(self.num_b_tensors) so alpha_val is only read from
alpha_tuple[i] when expert_idx is within that tensor's [start, end) range; use
the same ordering and guards involving expert_idx, self.b_tensor_l_offsets,
num_b_tensors, alpha_tuple and alpha_val as in the B-tensor selection logic to
ensure safe indexing.
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 3192-3194: wrapper() currently indexes self.b_tensor_l_sizes[0]
(l_0) without ensuring b_tensor_l_sizes was provided in __init__, causing a
NoneType subscript; add an explicit guard at the start of wrapper() that checks
if self.b_tensor_l_sizes is None and raises a clear ValueError explaining that
b_tensor_l_sizes must be set for multi-B fused kernels (or alternatively make
b_tensor_l_sizes a required constructor parameter in __init__); update
references around l_0 and alpha_0 to rely on this validation so the subsequent
cute.make_tensor(...) call is safe.
In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 378-381: The code now accepts list inputs for b, b_scale, and
alpha but lacks validation: ensure b_list, b_scale_list, and alpha_list are
non-empty, have identical lengths, and that every corresponding tensor split
agrees on the non-expert dimensions (shapes except the expert-split dimension)
before any indexing or kernel compilation (e.g., before using b_list[0] or
passing these lists into the compiled kernel path). Add explicit checks that
raise a clear error if any list is empty, if len(b_list) != len(b_scale_list) !=
len(alpha_list), or if any pairwise tensor shape mismatch exists on non-expert
dims; apply the same validation logic around the other normalization sites noted
(the blocks around the other occurrences you flagged: the b/b_scale/alpha
normalization at the later sections).
In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 1061-1062: Replace the file-local test gate decorator
`sm100_required` with the repo-standard check
`flashinfer.utils.is_sm100a_supported()` on the new test class so it uses the
canonical GPU support helper; locate the class decorated with
`@cute_dsl_available` and `@sm100_required`, remove `@sm100_required` and apply
the skip/require helper that calls `flashinfer.utils.is_sm100a_supported()` (or
the equivalent test-skip decorator that invokes it) so the test skips correctly
on unsupported SM100a devices instead of relying on the local `sm100_required`
function.
📒 Files selected for processing (7)

- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
- flashinfer/fused_moe/cute_dsl/fused_moe.py
- flashinfer/fused_moe/cute_dsl/tuner.py
- tests/moe/test_cute_dsl_fused_moe.py
```python
# Normalize to lists for multi-B support
b_list = [b] if isinstance(b, torch.Tensor) else b
b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else b_scale
alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else alpha
```
Validate the new multi-B inputs before indexing and compiling.
This path accepts lists now, but it never checks that they are non-empty, that b/b_scale/alpha have the same number of splits, or that every split agrees on the non-expert dimensions. Right now [] fails at b_list[0], and mismatched split counts/shapes fall through to obscure tuple-index/layout errors in the compiled kernel path.
🧩 Suggested validation
```diff
 b_list = [b] if isinstance(b, torch.Tensor) else b
 b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else b_scale
 alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else alpha
+
+if not b_list:
+    raise ValueError("b must be a tensor or a non-empty list of tensors")
+if not (len(b_list) == len(b_scale_list) == len(alpha_list)):
+    raise ValueError("b, b_scale, and alpha must use the same number of splits")
+if len(b_list) > 4:
+    raise ValueError("at most 4 B tensors are supported")
+
+ref_n = b_list[0].shape[1]
+ref_packed_k = b_list[0].shape[2]
+for i, (bi, bsi, ai) in enumerate(zip(b_list, b_scale_list, alpha_list)):
+    if bi.shape[1:] != (ref_n, ref_packed_k):
+        raise ValueError(f"split {i} has inconsistent B shape: {tuple(bi.shape)}")
+    if bsi.shape[-1] != bi.shape[0]:
+        raise ValueError(
+            f"split {i} has inconsistent B-scale expert dim: {bsi.shape[-1]} != {bi.shape[0]}"
+        )
+    if ai.numel() != bi.shape[0]:
+        raise ValueError(
+            f"split {i} has inconsistent alpha length: {ai.numel()} != {bi.shape[0]}"
+        )
```

Also applies to: 385-390, 456-486
```python
@cute_dsl_available
@sm100_required
```
Use the repo-standard SM100 skip helper for this new coverage.
These tests still rely on the file-local sm100_required gate, which only checks props.major == 10. That can enable the new multi-B cases on unsupported 10.x parts and create spurious failures. Please switch the new class to flashinfer.utils.is_sm100a_supported() instead.
As per coding guidelines: "Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures"
- Safe alpha indexing with pre-initialization before const_expr branches
- NoneType guard: raise ValueError when b_tensor_l_sizes=None
- Input validation for multi-B weight lists (empty, max 4, length match)
- Fix test imports to use top-level flashinfer module

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/moe/test_cute_dsl_fused_moe.py (1)

1420-1421: Inconsistent import pattern for `autotune`. Line 1420 uses `from flashinfer.autotuner import autotune`, but the rest of the file (e.g., line 408) uses `from flashinfer import autotune`. For consistency within this file, prefer the top-level import.

♻️ Proposed fix for import consistency

```diff
- from flashinfer.autotuner import autotune
+ from flashinfer import autotune
```

flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)

463-465: Use `ValueError` instead of `assert` for input validation consistency. Lines 458-462 correctly use `ValueError` for the `None` check, but lines 463-465 use `assert` for the length check. Assertions can be disabled with Python's `-O` flag, making this validation bypassable in optimized builds.

♻️ Proposed fix for consistent error handling

```diff
-assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, (
-    f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}"
-)
+if len(b_tensor_l_sizes) > self.MAX_B_TENSORS:
+    raise ValueError(
+        f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}"
+    )
```
📒 Files selected for processing (5)

- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
- tests/moe/test_cute_dsl_fused_moe.py

🚧 Files skipped from review as they are similar to previous changes (1)

- flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
```diff
 c_sf = cute.make_tensor(
     c_sf_ptr,
     layout=cute.make_ordered_layout(
-        (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), l),
+        (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), total_l),
         order=(2, 1, 4, 0, 3, 5),
```
Keep the SFC tensor's L dimension at 1.
out_scale is still allocated by the public wrapper as (..., 1), but this wrapper now reinterprets the same buffer as (..., total_l). On multi-B FP4 paths that gives CuTe a different layout than the backing allocation, and the output tensor still has flattened M x N x 1 semantics here.
Proposed fix
c_sf = cute.make_tensor(
c_sf_ptr,
layout=cute.make_ordered_layout(
- (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), total_l),
+ (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), 1),
order=(2, 1, 4, 0, 3, 5),
),
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| c_sf = cute.make_tensor( | |
| c_sf_ptr, | |
| layout=cute.make_ordered_layout( | |
| (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), l), | |
| (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), total_l), | |
| order=(2, 1, 4, 0, 3, 5), | |
| c_sf = cute.make_tensor( | |
| c_sf_ptr, | |
| layout=cute.make_ordered_layout( | |
| (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), 1), | |
| order=(2, 1, 4, 0, 3, 5), | |
| ), | |
| ) |
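To see why the trailing extent matters: in a compact ordered layout, the addressable span is the product of all extents, so raising the L extent from 1 to total_l multiplies the span beyond the `(..., 1)` backing allocation. Here is a plain-Python sketch of ordered-stride assignment (an illustrative model of what an ordered layout computes, not the actual `cute.make_ordered_layout` implementation):

```python
def ordered_strides(shape, order):
    # Assign compact strides: the dimension with the smallest order value
    # varies fastest (stride 1); each subsequent dimension's stride is the
    # product of the extents already placed. Returns (strides, span).
    strides = [0] * len(shape)
    stride = 1
    for dim in sorted(range(len(shape)), key=lambda d: order[d]):
        strides[dim] = stride
        stride *= shape[dim]
    return tuple(strides), stride


# With L = 1 the span matches an (..., 1) allocation; with L = 4 (standing
# in for total_l) the span grows fourfold, overrunning the same buffer.
span_l1 = ordered_strides((32, 4, 2, 4, 8, 1), (2, 1, 4, 0, 3, 5))[1]
span_l4 = ordered_strides((32, 4, 2, 4, 8, 4), (2, 1, 4, 0, 3, 5))[1]
```

The extents `(32, 4, 2, 4, 8, ...)` are a small stand-in for the real `(32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), L)` shape; only the ratio between the two spans is the point.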
```diff
 # Normalize to lists for multi-B support
 b_list = [b] if isinstance(b, torch.Tensor) else list(b)
 b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else list(b_scale)
 alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else list(alpha)

 # Validate multi-B inputs
 assert len(b_list) > 0, "Weight tensor list must not be empty"
 assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}"
 assert len(b_list) == len(b_scale_list) == len(alpha_list), (
     f"b, b_scale, alpha lists must have same length: "
     f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}"
 )

 # Validate inputs
 assert a.device.type == "cuda", "Input tensors must be on CUDA device"
-assert b.device.type == "cuda", "Input tensors must be on CUDA device"
+assert b_list[0].device.type == "cuda", "Input tensors must be on CUDA device"

 # Get dimensions
 seq_len = a.shape[0]
-num_experts = b.shape[0]
-n = b.shape[1]  # This is 2*intermediate_size
+num_experts = sum(bi.size(0) for bi in b_list)
+n = b_list[0].shape[1]  # This is 2*intermediate_size
 k = a.shape[1]
```
Validate each split tensor, not just the list length.
The new checks still allow invalid multi-B configurations: an empty first split, a later split on a different device, a later split with different (N, K), or an alpha[i] whose length does not match b[i].size(0). Those all flow into the kernel, which assumes split 0 exists and reuses the first split's shape for every other split.
Proposed fix
```diff
-assert len(b_list) > 0, "Weight tensor list must not be empty"
-assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}"
-assert len(b_list) == len(b_scale_list) == len(alpha_list), (
-    f"b, b_scale, alpha lists must have same length: "
-    f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}"
-)
+if len(b_list) == 0:
+    raise ValueError("Weight tensor list must not be empty")
+if len(b_list) > 4:
+    raise ValueError(f"Maximum 4 weight tensors supported, got {len(b_list)}")
+if len(b_list) != len(b_scale_list) or len(b_list) != len(alpha_list):
+    raise ValueError(
+        "b, b_scale, and alpha must contain the same number of splits"
+    )
+
+ref_nk = b_list[0].shape[1:]
+for i, (bi, bsi, ai) in enumerate(zip(b_list, b_scale_list, alpha_list)):
+    if bi.size(0) == 0:
+        raise ValueError(f"b[{i}] must contain at least one expert")
+    if bi.device != a.device or bsi.device != a.device or ai.device != a.device:
+        raise ValueError(
+            f"All split tensors must be on {a.device}; "
+            f"got b[{i}]={bi.device}, b_scale[{i}]={bsi.device}, alpha[{i}]={ai.device}"
+        )
+    if bi.shape[1:] != ref_nk:
+        raise ValueError(
+            f"All B splits must share the same (N, K); "
+            f"expected {ref_nk}, got {bi.shape[1:]} for b[{i}]"
+        )
+    if ai.numel() != bi.size(0):
+        raise ValueError(
+            f"alpha[{i}] must have {bi.size(0)} entries, got {ai.numel()}"
+        )
```
/bot run
Thanks for contributing to the project!
I think we want to know whether there is a performance regression in the non-data-parallelism case.
Moreover, the autotuner design might need more consideration. Right now it depends on the number of experts, as before, which might give sub-optimal performance in the data-parallelism case. Do we also need to keep the degree of parallelism as a parameter?
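One way to frame this concern (a hypothetical sketch; `make_tuning_key` is not a FlashInfer API): if the tuning-cache key carries only the total expert count, two deployments with the same total but different DWDP splits would share one tuning entry, so including the per-split sizes in the key keeps them distinct.

```python
def make_tuning_key(num_experts, b_tensor_l_sizes=None):
    # Treat a missing split spec as a single stacked tensor, so the
    # single-B path keys identically to an explicit (num_experts,) split.
    splits = tuple(b_tensor_l_sizes) if b_tensor_l_sizes else (num_experts,)
    return ("nvfp4_moe", num_experts, splits)
```

Under this scheme a 256-expert single-tensor run and a 128+128 DWDP run tune separately, while single-tensor and explicit `(256,)` configurations still share an entry.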
```python
# as a tuple (even for single-B, e.g. (256,)).
if b_tensor_l_sizes is None:
    raise ValueError(
        "b_tensor_l_sizes is required. Pass a tuple with the number of "
```
It seems the description of `b_tensor_l_sizes` differs between here (required) and the function signature (optional).
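One way to reconcile the two (an illustrative sketch, not the PR's code; `default_l` is a hypothetical fallback for the expert count): normalize once in `__init__` so the signature stays optional while downstream code always sees a valid tuple.

```python
def normalize_b_tensor_l_sizes(b_tensor_l_sizes, default_l):
    # Hypothetical normalization: if the caller omits b_tensor_l_sizes,
    # fall back to a single split covering all `default_l` experts, so
    # the parameter can remain Optional without a None reaching the
    # wrapper's l_0 = self.b_tensor_l_sizes[0] access.
    if b_tensor_l_sizes is None:
        return (default_l,)
    return tuple(b_tensor_l_sizes)
```

This keeps the single-B call sites unchanged while making the "required" wording in the error message unnecessary.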
```python
topk: cutlass.Int64,
raster_along_m: bool = False,
enable_pdl: bool = True,
b_tensor_l_sizes: Optional[Tuple[int, ...]] = None,
```
I'd prefer to put the `enable_pdl` parameter last, since TensorRT-LLM doesn't have such a parameter. That would make it easier for developers when porting the optimization.
```python
scale_k = k // scaling_vector_size
interm_size = n // 2
num_tiles = m // tile_size
total_l = self.b_tensor_l_offsets[self.num_b_tensors]
```
📌 Description

⚠️ During FlashInfer's port of the TRT-LLM gather+SwiGLU kernel to CuTe DSL Python, the `tile_size=256` path (`use_2cta_instrs=True`, where two CTAs cooperate on a larger MMA operation) produces numerically incorrect results: the kernel runs but gives wrong answers. An NVIDIA engineer discovered this in PR #2775 and disabled it as a workaround, leaving only `tile_size=128`. Since TRT-LLM's original kernel works correctly with `tile_size=256`, this is a bug introduced during the porting process. It doesn't affect DWDP functionality, but it halves the autotuner's tactic search space and may cost some performance on large-batch workloads.

Summary

Updates the CuTe DSL NVFP4 MoE kernels to accept weights as a list of tensors split along the expert dimension (up to 4), in addition to the existing single-tensor layout. This lands the DWDP (Distributed Weight Data Parallelism) support that the CUTLASS/TRT-LLM side already has.

- `w1_weight`/`w1_weight_sf`/`w1_alpha` and the `w2_*` counterparts now accept `Union[Tensor, List[Tensor]]`
- The compiled kernel is specialized per multi-B config via `b_tensor_l_sizes`; the right B tensor is selected from the runtime expert index on the kernel side

Ports the approach from NVIDIA/TensorRT-LLM#12136.
🔍 Related Issues
#3036
🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

🧪 Tests

- Tests have been added or updated as needed (`unittest`, etc.).

Reviewer Notes