Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
- Gather: Uses LDGSTS to gather A directly using token_id_mapping, no moe_permute needed
"""

from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

import cutlass
import cutlass.cute as cute
Expand Down Expand Up @@ -191,15 +191,14 @@ def _get_compiled_gather_kernel(
permuted_m: int,
n: int, # This is 2*intermediate_size
k: int,
num_experts: int,
# Tensor pointers (runtime parameters - NOT in cache key)
a_ptr,
b_ptr,
b_ptr, # tuple of pointers
a_sf_ptr,
b_sf_ptr,
b_sf_ptr, # tuple of pointers
c_ptr,
c_sf_ptr,
alpha_ptr,
alpha_ptr, # tuple of pointers
tile_idx_ptr,
mn_limit_ptr,
token_id_ptr,
Expand All @@ -221,6 +220,7 @@ def _get_compiled_gather_kernel(
vectorized_f32: bool,
raster_along_m: bool,
enable_pdl: bool = True,
b_tensor_l_sizes: Optional[Tuple[int, ...]] = None,
):
"""Get or compile the gather grouped GEMM with SwiGLU kernel.

Expand All @@ -234,10 +234,14 @@ def _get_compiled_gather_kernel(
This matches TRT-LLM's approach where the same compiled kernel can be
reused for different problem sizes, significantly reducing JIT compilation
overhead during autotuning.

Supports multiple B weight tensors via b_tensor_l_sizes parameter.
When b_tensor_l_sizes is provided, b_ptr/b_sf_ptr/alpha_ptr are tuples.
"""
global _gather_kernel_cache

# Cache key includes dtype and tactic parameters, NOT problem dimensions
# Also includes b_tensor_l_sizes since kernel is specialized per multi-B config
cache_key = (
ab_dtype,
sf_dtype,
Expand All @@ -250,6 +254,7 @@ def _get_compiled_gather_kernel(
vectorized_f32,
raster_along_m,
enable_pdl,
b_tensor_l_sizes,
)

if cache_key not in _gather_kernel_cache:
Expand All @@ -262,16 +267,17 @@ def _get_compiled_gather_kernel(
topk=topk,
raster_along_m=raster_along_m,
enable_pdl=enable_pdl,
b_tensor_l_sizes=b_tensor_l_sizes,
)

# Compile with runtime parameters - they can vary across calls
# Order must match wrapper signature:
# (a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, c_sf_ptr, alpha_ptr,
# (a_ptr, b_ptr_tuple, a_sf_ptr, b_sf_ptr_tuple, c_ptr, c_sf_ptr, alpha_ptr_tuple,
# tile_idx_to_group_idx_ptr, tile_idx_to_mn_limit_ptr, token_id_mapping_ptr,
# num_non_exiting_tiles_ptr, global_sf_ptr, orig_m, m, n, k, l,
# num_non_exiting_tiles_ptr, norm_const_ptr, orig_m, m, n, k, l,
# tile_size, scaling_vector_size, max_active_clusters, stream)
compiled_gemm = cute.compile(
gemm.wrapper,
num_experts = sum(b_tensor_l_sizes)
compile_args = [
a_ptr,
b_ptr,
a_sf_ptr,
Expand All @@ -289,6 +295,11 @@ def _get_compiled_gather_kernel(
n,
k,
num_experts,
]

compiled_gemm = cute.compile(
gemm.wrapper,
*compile_args,
tile_size=tile_size,
scaling_vector_size=sf_vec_size,
max_active_clusters=max_active_clusters,
Expand All @@ -302,10 +313,10 @@ def _get_compiled_gather_kernel(

def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
a: torch.Tensor,
b: torch.Tensor,
b: Union[torch.Tensor, List[torch.Tensor]],
a_scale: torch.Tensor,
b_scale: torch.Tensor,
alpha: torch.Tensor,
b_scale: Union[torch.Tensor, List[torch.Tensor]],
alpha: Union[torch.Tensor, List[torch.Tensor]],
tile_idx_to_expert_idx: torch.Tensor,
tile_idx_to_mn_limit: torch.Tensor,
token_id_mapping: torch.Tensor,
Expand Down Expand Up @@ -406,14 +417,27 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
... topk=topk,
... ) # out shape: (valid_m, intermediate_dim)
"""
# Normalize to lists for multi-B support
b_list = [b] if isinstance(b, torch.Tensor) else list(b)
b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else list(b_scale)
alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else list(alpha)

# Validate multi-B inputs
assert len(b_list) > 0, "Weight tensor list must not be empty"
assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}"
assert len(b_list) == len(b_scale_list) == len(alpha_list), (
f"b, b_scale, alpha lists must have same length: "
f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}"
)

# Validate inputs
assert a.device.type == "cuda", "Input tensors must be on CUDA device"
assert b.device.type == "cuda", "Input tensors must be on CUDA device"
assert b_list[0].device.type == "cuda", "Input tensors must be on CUDA device"

# Get dimensions
seq_len = a.shape[0]
num_experts = b.shape[0]
n = b.shape[1] # This is 2*intermediate_size
num_experts = sum(bi.size(0) for bi in b_list)
n = b_list[0].shape[1] # This is 2*intermediate_size
k = a.shape[1]
Comment on lines +420 to 441
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Validate each split tensor, not just the list length.

The new checks still allow invalid multi-B configurations: an empty first split, a later split on a different device, a later split with different (N, K), or an alpha[i] whose length does not match b[i].size(0). Those all flow into the kernel, which assumes split 0 exists and reuses the first split's shape for every other split.

Proposed fix
-    assert len(b_list) > 0, "Weight tensor list must not be empty"
-    assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}"
-    assert len(b_list) == len(b_scale_list) == len(alpha_list), (
-        f"b, b_scale, alpha lists must have same length: "
-        f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}"
-    )
+    if len(b_list) == 0:
+        raise ValueError("Weight tensor list must not be empty")
+    if len(b_list) > 4:
+        raise ValueError(f"Maximum 4 weight tensors supported, got {len(b_list)}")
+    if len(b_list) != len(b_scale_list) or len(b_list) != len(alpha_list):
+        raise ValueError(
+            "b, b_scale, and alpha must contain the same number of splits"
+        )
+
+    ref_nk = b_list[0].shape[1:]
+    for i, (bi, bsi, ai) in enumerate(zip(b_list, b_scale_list, alpha_list)):
+        if bi.size(0) == 0:
+            raise ValueError(f"b[{i}] must contain at least one expert")
+        if bi.device != a.device or bsi.device != a.device or ai.device != a.device:
+            raise ValueError(
+                f"All split tensors must be on {a.device}; "
+                f"got b[{i}]={bi.device}, b_scale[{i}]={bsi.device}, alpha[{i}]={ai.device}"
+            )
+        if bi.shape[1:] != ref_nk:
+            raise ValueError(
+                f"All B splits must share the same (N, K); "
+                f"expected {ref_nk}, got {bi.shape[1:]} for b[{i}]"
+            )
+        if ai.numel() != bi.size(0):
+            raise ValueError(
+                f"alpha[{i}] must have {bi.size(0)} entries, got {ai.numel()}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`
around lines 418 - 439, Loop over each split in b_list (using enumerate) and
validate per-split properties instead of only checking the first element: assert
each bi (from b_list) is non-empty, on CUDA, has the same second-dimension as
b_list[0] (bi.shape[1] == n), and that bi.size(0) matches the length of the
corresponding alpha_list[i]; also assert corresponding b_scale_list[i] is
present and on CUDA. Update validations around b_list, b_scale_list, alpha_list,
a, n and k (references: b_list, b_scale_list, alpha_list, a, n, k, num_experts)
so the kernel never receives an empty split, a device-mismatched tensor, or
mismatched per-split dimensions.

if ab_dtype == "float4_e2m1fn":
k = k * 2 # FP4 is packed 2 elements per byte
Expand Down Expand Up @@ -500,19 +524,16 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
# Get tile_size from mma_tiler_mn
tile_size = mma_tiler_mn[0]

# Compute b_tensor_l_sizes for multi-B support
b_tensor_l_sizes = tuple(bi.size(0) for bi in b_list)

# Create raw pointers (TRT-LLM style) - allows same compiled kernel for different sizes
a_ptr = make_ptr(
ab_dtype_cutlass, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
)
b_ptr = make_ptr(
ab_dtype_cutlass, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
)
a_sf_ptr = make_ptr(
sf_dtype_cutlass, a_scale.data_ptr(), cute.AddressSpace.gmem, assumed_align=16
)
b_sf_ptr = make_ptr(
sf_dtype_cutlass, b_scale.data_ptr(), cute.AddressSpace.gmem, assumed_align=16
)
c_ptr = make_ptr(
c_dtype_cutlass, out.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
)
Expand All @@ -531,7 +552,24 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
c_sf_ptr = None
norm_const_ptr = None

alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(), cute.AddressSpace.gmem)
# Create pointer tuples for B tensors
b_ptr = tuple(
make_ptr(
ab_dtype_cutlass, bi.data_ptr(), cute.AddressSpace.gmem, assumed_align=32
)
for bi in b_list
)
b_sf_ptr = tuple(
make_ptr(
sf_dtype_cutlass, bsi.data_ptr(), cute.AddressSpace.gmem, assumed_align=16
)
for bsi in b_scale_list
)
alpha_ptr = tuple(
make_ptr(cutlass.Float32, ai.data_ptr(), cute.AddressSpace.gmem)
for ai in alpha_list
)

tile_idx_ptr = make_ptr(
cutlass.Int32, tile_idx_to_expert_idx.data_ptr(), cute.AddressSpace.gmem
)
Expand All @@ -549,15 +587,12 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
torch_stream = torch.cuda.current_stream()
stream = cuda.CUstream(torch_stream.cuda_stream)

# Get or compile the kernel (cached by dtype and tactic parameters)
# Get or compile the kernel
compiled_gemm = _get_compiled_gather_kernel(
# Runtime parameters (problem dimensions)
orig_m=seq_len,
permuted_m=permuted_m,
n=n,
k=k,
num_experts=num_experts,
# Tensor pointers (order must match wrapper signature)
a_ptr=a_ptr,
b_ptr=b_ptr,
a_sf_ptr=a_sf_ptr,
Expand All @@ -572,11 +607,9 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
norm_const_ptr=norm_const_ptr,
max_active_clusters=max_active_clusters,
stream=stream,
# Dtype parameters (compile-time, in cache key)
ab_dtype=ab_dtype,
sf_dtype=sf_dtype,
c_dtype=c_dtype,
# Tactic parameters (compile-time, cached)
sf_vec_size=sf_vec_size,
tile_size=tile_size,
topk=topk,
Expand All @@ -585,14 +618,12 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
vectorized_f32=vectorized_f32,
raster_along_m=raster_along_m,
enable_pdl=enable_pdl,
b_tensor_l_sizes=b_tensor_l_sizes,
)

# Execute kernel with runtime parameters
# Order must match wrapper signature:
# (a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, c_sf_ptr, alpha_ptr,
# tile_idx_ptr, mn_limit_ptr, token_id_ptr, num_tiles_ptr, global_sf_ptr,
# orig_m, m, n, k, l, stream)
compiled_gemm(
# Execute kernel
num_experts = sum(b_tensor_l_sizes)
exec_args = [
a_ptr,
b_ptr,
a_sf_ptr,
Expand All @@ -609,8 +640,8 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4(
permuted_m,
n,
k,
num_experts,
stream=stream,
)
num_experts, # l
]
compiled_gemm(*exec_args, stream=stream)

return out, out_scale if generate_sfc else None
Loading
Loading