From 038bf93825ea2a2eb1e1888df0f50da69c783bdf Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Mon, 13 Apr 2026 04:54:36 +0000 Subject: [PATCH 1/3] feat: support multi-B weight tensors (DWDP) in CuTe DSL NVFP4 MoE Extend the Blackwell NVFP4 fused MoE (gather SwiGLU + finalize) kernels and their Python wrappers to accept w1/w2 weight, weight_sf and alpha as either a single tensor or a list of up to 4 tensors split along the expert dimension. The compiled kernel is specialized per multi-B config via b_tensor_l_sizes, with kernel-side branching selecting the right B tensor from the runtime expert index. Also adds end-to-end tests verifying multi-B results match the single stacked-tensor baseline. Co-Authored-By: Claude Opus 4.6 (1M context) --- ...guous_gather_grouped_gemm_swiglu_fusion.py | 863 ++++++++++++++--- ...contiguous_grouped_gemm_finalize_fusion.py | 870 +++++++++++++++--- ...guous_gather_grouped_gemm_swiglu_fusion.py | 95 +- ...contiguous_grouped_gemm_finalize_fusion.py | 94 +- flashinfer/fused_moe/cute_dsl/fused_moe.py | 101 +- flashinfer/fused_moe/cute_dsl/tuner.py | 15 +- tests/moe/test_cute_dsl_fused_moe.py | 407 ++++++++ 7 files changed, 2037 insertions(+), 408 deletions(-) diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index f8c50c624f..f2edf7f968 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -315,6 +315,10 @@ class BlockScaledContiguousGatherGroupedGemmKernel: """This class implements contiguous grouped matrix multiplication with gather operation and SwiGLU fusion for FC1 layer computation (C = up * silu(gate), where up/gate come from interleaved GEMM result). + Supports multiple B weight tensors for DWDP (Distributed Weight Data Parallelism). + When b_tensor_l_sizes is provided, the kernel selects from multiple B tensors + based on expert index at runtime. + The computation flow: 1. GEMM: acc = alpha * (SFA * A[token_ids]) * (SFB * B) 2. SwiGLU: C = up * silu(gate), extracted from interleaved acc with granularity=64 @@ -395,6 +399,9 @@ class BlockScaledContiguousGatherGroupedGemmKernel: ... ) """ + # Maximum number of B tensors supported (must match kernel's const_expr branches) + MAX_B_TENSORS = 4 + def __init__( self, sf_vec_size: int, @@ -404,6 +411,7 @@ def __init__( topk: cutlass.Int64, raster_along_m: bool = False, enable_pdl: bool = True, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, ): """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with gather operation and SwiGLU fusion. @@ -524,6 +532,26 @@ def __init__( self.vectorized_f32 = vectorized_f32 + # Multi-B tensor configuration + if b_tensor_l_sizes is None: + self.num_b_tensors = 1 + self.b_tensor_l_sizes = None + # Offsets padded for safe indexing in kernel + self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS + else: + assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( + f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" + ) + self.num_b_tensors = len(b_tensor_l_sizes) + self.b_tensor_l_sizes = b_tensor_l_sizes + offsets = [0] + for l_size in b_tensor_l_sizes: + offsets.append(offsets[-1] + l_size) + # Pad to MAX_B_TENSORS + 1 for safe indexing + while len(offsets) < self.MAX_B_TENSORS + 1: + offsets.append(2**30) + self.b_tensor_l_offsets = tuple(offsets) + def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -724,17 +752,17 @@ def _setup_attributes(self): def __call__( self, a: cute.Tensor, - b: cute.Tensor, + b: Union[cute.Tensor, Tuple[cute.Tensor, ...]], c: cute.Tensor, sfa: cute.Tensor, - sfb: cute.Tensor, + sfb: Union[cute.Tensor, Tuple[cute.Tensor, ...]], sfc_tensor: Optional[cute.Tensor], norm_const_tensor: Optional[cute.Tensor], tile_idx_to_expert_idx: cute.Tensor, tile_idx_to_mn_limit: cute.Tensor, token_id_mapping_tensor: cute.Tensor, num_non_exiting_tiles: cute.Tensor, - alpha: cute.Tensor, + alpha: Union[cute.Tensor, Tuple[cute.Tensor, ...]], max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, @@ -802,11 +830,14 @@ def __call__( """ # Setup static attributes before smem/grid/tma computation self.a_dtype: Type[cutlass.Numeric] = a.element_type - self.b_dtype: Type[cutlass.Numeric] = b.element_type + # Handle tuple of B tensors + b_tuple = b if isinstance(b, tuple) else (b,) + sfb_tuple = sfb if isinstance(sfb, tuple) else (sfb,) + self.b_dtype: Type[cutlass.Numeric] = b_tuple[0].element_type self.c_dtype: Type[cutlass.Numeric] = c.element_type self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() - self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tuple[0]).mma_major_mode() self.c_layout = utils.LayoutEnum.from_tensor(c) # Check if input data types are compatible with MMA instruction @@ -816,10 +847,32 @@ def __call__( # Setup attributes that dependent on gemm inputs self._setup_attributes() - # Setup sfb tensor by filling B tensor to scale factor atom layout + # Setup sfb tensors by filling B tensor to scale factor atom layout # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) - sfb = cute.make_tensor(sfb.iterator, sfb_layout) + # Create layout for each B tensor (use const_expr, not loop) + sfb_layout_0 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[0].shape, self.sf_vec_size + ) + sfb_tensor_0 = cute.make_tensor(sfb_tuple[0].iterator, sfb_layout_0) + sfb_tensors = [sfb_tensor_0] + if cutlass.const_expr(self.num_b_tensors >= 2): + sfb_layout_1 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[1].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[1].iterator, sfb_layout_1)) + if cutlass.const_expr(self.num_b_tensors >= 3): + sfb_layout_2 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[2].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[2].iterator, sfb_layout_2)) + if cutlass.const_expr(self.num_b_tensors >= 4): + sfb_layout_3 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[3].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[3].iterator, sfb_layout_3)) + sfb_tuple = tuple(sfb_tensors) + # Backward compat alias + sfb = sfb_tuple[0] # Setup sfc tensor by filling C tensor to scale factor atom layout self.generate_sfc = sfc_tensor is not None and norm_const_tensor is not None @@ -851,59 +904,99 @@ def __call__( ) atom_thr_size = cute.size(tiled_mma.thr_id.shape) - # Setup TMA load for B + # Setup TMA ops (shared across all B tensors) b_op = sm100_utils.cluster_shape_to_tma_atom_B( self.cluster_shape_mn, tiled_mma.thr_id ) - b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - b_op, - b, - b_smem_layout, - self.mma_tiler, - tiled_mma, - self.cluster_layout_vmnk.shape, - ) - - # Setup TMA load for SFB sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( self.cluster_shape_mn, tiled_mma.thr_id ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) sfb_smem_layout = cute.slice_( self.sfb_smem_layout_staged, (None, None, None, 0) ) - tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, - sfb, - sfb_smem_layout, - self.mma_tiler_sfb, - tiled_mma_sfb, - self.cluster_layout_sfb_vmnk.shape, - internal_type=cutlass.Int16, - ) - # This modifies the layout to handle overlapping 256x(# of scale factors for a single column of B (nNSF)) - # logical blocks for SFB when cta_tile_shape_n=192. - if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): - x = tma_tensor_sfb.stride[0][1] - y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + # Helper to create TMA for one B tensor + def _make_tma_b(b_tensor, sfb_tensor): + atom_b, tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + atom_sfb, tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + # This modifies the layout to handle overlapping 256x(# of scale factors + # for a single column of B (nNSF)) logical blocks for SFB when + # cta_tile_shape_n=192. + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tensor_sfb.stride[0][1] + y = cute.ceil_div(tensor_sfb.shape[0][1], 4) + new_shape = ( + (tensor_sfb.shape[0][0], ((2, 2), y)), + tensor_sfb.shape[1], + tensor_sfb.shape[2], + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + (tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tensor_sfb.stride[1], + tensor_sfb.stride[2], + ) + tensor_sfb = cute.make_tensor( + tensor_sfb.iterator, cute.make_layout(new_shape, stride=new_stride) + ) + return atom_b, tensor_b, atom_sfb, tensor_sfb - new_shape = ( - (tma_tensor_sfb.shape[0][0], ((2, 2), y)), - tma_tensor_sfb.shape[1], - tma_tensor_sfb.shape[2], + # Create TMA for all B tensors (use const_expr, not loop) + atom_b_0, tensor_b_0, atom_sfb_0, tensor_sfb_0 = _make_tma_b( + b_tuple[0], sfb_tuple[0] + ) + tma_atoms_b = [atom_b_0] + tma_tensors_b = [tensor_b_0] + tma_atoms_sfb = [atom_sfb_0] + tma_tensors_sfb = [tensor_sfb_0] + if cutlass.const_expr(self.num_b_tensors >= 2): + atom_b_1, tensor_b_1, atom_sfb_1, tensor_sfb_1 = _make_tma_b( + b_tuple[1], sfb_tuple[1] ) - # Use right multiplication for ScaledBasis (3 * x instead of x * 3) - x_times_3 = 3 * x - new_stride = ( - (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), - tma_tensor_sfb.stride[1], - tma_tensor_sfb.stride[2], + tma_atoms_b.append(atom_b_1) + tma_tensors_b.append(tensor_b_1) + tma_atoms_sfb.append(atom_sfb_1) + tma_tensors_sfb.append(tensor_sfb_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + atom_b_2, tensor_b_2, atom_sfb_2, tensor_sfb_2 = _make_tma_b( + b_tuple[2], sfb_tuple[2] ) - tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) - tma_tensor_sfb = cute.make_tensor( - tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout + tma_atoms_b.append(atom_b_2) + tma_tensors_b.append(tensor_b_2) + tma_atoms_sfb.append(atom_sfb_2) + tma_tensors_sfb.append(tensor_sfb_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + atom_b_3, tensor_b_3, atom_sfb_3, tensor_sfb_3 = _make_tma_b( + b_tuple[3], sfb_tuple[3] ) + tma_atoms_b.append(atom_b_3) + tma_tensors_b.append(tensor_b_3) + tma_atoms_sfb.append(atom_sfb_3) + tma_tensors_sfb.append(tensor_sfb_3) + tma_atoms_b = tuple(tma_atoms_b) + tma_tensors_b = tuple(tma_tensors_b) + tma_atoms_sfb = tuple(tma_atoms_sfb) + tma_tensors_sfb = tuple(tma_tensors_sfb) + + # Handle alpha tuple (convert to tuple if single tensor) + alpha_tuple = alpha if isinstance(alpha, tuple) else (alpha,) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) @@ -1052,11 +1145,11 @@ class SharedStorage2cta: tiled_mma, tiled_mma_sfb, a, - tma_atom_b, - tma_tensor_b, + tma_atoms_b, # Tuple of TMA atoms for B + tma_tensors_b, # Tuple of TMA tensors for B sfa, - tma_atom_sfb, - tma_tensor_sfb, + tma_atoms_sfb, # Tuple of TMA atoms for SFB + tma_tensors_sfb, # Tuple of TMA tensors for SFB tma_atom_c, tma_tensor_c, sfc_tensor, @@ -1065,7 +1158,7 @@ class SharedStorage2cta: tile_idx_to_mn_limit, token_id_mapping_tensor, num_non_exiting_tiles, - alpha, + alpha_tuple, self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk, self.a_smem_layout_staged, @@ -1138,11 +1231,11 @@ def kernel( tiled_mma: cute.TiledMma, tiled_mma_sfb: cute.TiledMma, mA_mkl: cute.Tensor, - tma_atom_b: cute.CopyAtom, - mB_nkl: cute.Tensor, + tma_atoms_b: Tuple[cute.CopyAtom, ...], + mB_nkl_tuple: Tuple[cute.Tensor, ...], mSFA_mkl: cute.Tensor, - tma_atom_sfb: cute.CopyAtom, - mSFB_nkl: cute.Tensor, + tma_atoms_sfb: Tuple[cute.CopyAtom, ...], + mSFB_nkl_tuple: Tuple[cute.Tensor, ...], tma_atom_c: cute.CopyAtom, mC_mnl: cute.Tensor, mSFC_mnl: Optional[cute.Tensor], @@ -1151,7 +1244,7 @@ def kernel( tile_idx_to_mn_limit: cute.Tensor, token_id_mapping_tensor: cute.Tensor, num_non_exiting_tiles: cute.Tensor, - alpha: cute.Tensor, + alpha_tuple: Tuple[cute.Tensor, ...], cluster_layout_vmnk: cute.Layout, cluster_layout_sfb_vmnk: cute.Layout, a_smem_layout_staged: cute.ComposedLayout, @@ -1173,8 +1266,18 @@ def kernel( # Prefetch tma desc # if warp_idx == self.tma_b_warp_id: - cpasync.prefetch_descriptor(tma_atom_b) - cpasync.prefetch_descriptor(tma_atom_sfb) + # Prefetch TMA descriptors for all B tensors using const_expr conditions + cpasync.prefetch_descriptor(tma_atoms_b[0]) + cpasync.prefetch_descriptor(tma_atoms_sfb[0]) + if cutlass.const_expr(self.num_b_tensors >= 2): + cpasync.prefetch_descriptor(tma_atoms_b[1]) + cpasync.prefetch_descriptor(tma_atoms_sfb[1]) + if cutlass.const_expr(self.num_b_tensors >= 3): + cpasync.prefetch_descriptor(tma_atoms_b[2]) + cpasync.prefetch_descriptor(tma_atoms_sfb[2]) + if cutlass.const_expr(self.num_b_tensors >= 4): + cpasync.prefetch_descriptor(tma_atoms_b[3]) + cpasync.prefetch_descriptor(tma_atoms_sfb[3]) cpasync.prefetch_descriptor(tma_atom_c) use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 @@ -1348,10 +1451,30 @@ def kernel( cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), (None, None, None), ) - # (bN, bK, loopN, loopK, loopL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + # (bN, bK, loopN, loopK, loopL) - Use const_expr conditions for tuple indexing + gB_nkl_0 = cute.local_tile( + mB_nkl_tuple[0], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gB_nkl_1 = cute.local_tile( + mB_nkl_tuple[1], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gB_nkl_2 = cute.local_tile( + mB_nkl_tuple[2], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gB_nkl_3 = cute.local_tile( + mB_nkl_tuple[3], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) # (bM, bK, RestM, RestK, RestL) gSFA_mkl = cute.local_tile( @@ -1360,12 +1483,30 @@ def kernel( (None, None, None), ) - # (bN, bK, RestN, RestK, RestL) - gSFB_nkl = cute.local_tile( - mSFB_nkl, + # (bN, bK, RestN, RestK, RestL) - Use const_expr conditions for tuple indexing + gSFB_nkl_0 = cute.local_tile( + mSFB_nkl_tuple[0], cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None), ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gSFB_nkl_1 = cute.local_tile( + mSFB_nkl_tuple[1], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gSFB_nkl_2 = cute.local_tile( + mSFB_nkl_tuple[2], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gSFB_nkl_3 = cute.local_tile( + mSFB_nkl_tuple[3], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) gToken_ml = cute.local_tile( token_id_mapping_tensor, @@ -1384,45 +1525,112 @@ def kernel( # thr_mma = tiled_mma.get_slice(mma_tile_coord_v) thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) - # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) - tCgB = thr_mma.partition_B(gB_nkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) - const_expr conditions + tCgB_0 = thr_mma.partition_B(gB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgB_1 = thr_mma.partition_B(gB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgB_2 = thr_mma.partition_B(gB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgB_3 = thr_mma.partition_B(gB_nkl_3) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - const_expr conditions + tCgSFB_0 = thr_mma_sfb.partition_B(gSFB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgSFB_1 = thr_mma_sfb.partition_B(gSFB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgSFB_2 = thr_mma_sfb.partition_B(gSFB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgSFB_3 = thr_mma_sfb.partition_B(gSFB_nkl_3) # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) tCgC = thr_mma.partition_C(gC_mnl) # # Partition global/shared tensor for TMA load B # - # TMA load B partition_S/D b_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape ) + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + sB_grouped = cute.group_modes(sB, 0, 3) + sSFB_grouped = cute.group_modes(sSFB, 0, 3) + + # TMA partition for B tensor 0 # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), loopM, loopK, loopL) - tBsB, tBgB = cpasync.tma_partition( - tma_atom_b, + # ((atom_v, rest_v), loopN, loopK, loopL) + tBsB_0, tBgB_0 = cpasync.tma_partition( + tma_atoms_b[0], block_in_cluster_coord_vmnk[1], b_cta_layout, - cute.group_modes(sB, 0, 3), - cute.group_modes(tCgB, 0, 3), - ) - - # TMA load SFB partition_S/D - sfb_cta_layout = cute.make_layout( - cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + sB_grouped, + cute.group_modes(tCgB_0, 0, 3), ) # ((atom_v, rest_v), STAGE) # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( - tma_atom_sfb, + tBsSFB_0, tBgSFB_0 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[0], block_in_cluster_coord_sfb_vmnk[1], sfb_cta_layout, - cute.group_modes(sSFB, 0, 3), - cute.group_modes(tCgSFB, 0, 3), + sSFB_grouped, + cute.group_modes(tCgSFB_0, 0, 3), ) - tBsSFB = cute.filter_zeros(tBsSFB) - tBgSFB = cute.filter_zeros(tBgSFB) + tBsSFB_0 = cute.filter_zeros(tBsSFB_0) + tBgSFB_0 = cute.filter_zeros(tBgSFB_0) + + # TMA partition for B tensor 1 (tBsB shared memory partition is same for all, use _ to ignore) + if cutlass.const_expr(self.num_b_tensors >= 2): + _, tBgB_1 = cpasync.tma_partition( + tma_atoms_b[1], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_1, 0, 3), + ) + _, tBgSFB_1 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[1], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_1, 0, 3), + ) + tBgSFB_1 = cute.filter_zeros(tBgSFB_1) + + # TMA partition for B tensor 2 + if cutlass.const_expr(self.num_b_tensors >= 3): + _, tBgB_2 = cpasync.tma_partition( + tma_atoms_b[2], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_2, 0, 3), + ) + _, tBgSFB_2 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[2], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_2, 0, 3), + ) + tBgSFB_2 = cute.filter_zeros(tBgSFB_2) + + # TMA partition for B tensor 3 + if cutlass.const_expr(self.num_b_tensors >= 4): + _, tBgB_3 = cpasync.tma_partition( + tma_atoms_b[3], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_3, 0, 3), + ) + _, tBgSFB_3 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[3], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_3, 0, 3), + ) + tBgSFB_3 = cute.filter_zeros(tBgSFB_3) # # Partition shared/tensor memory tensor for TiledMMA_A/B/C @@ -1980,22 +2188,13 @@ def kernel( tile_info[1], tile_info[2], ) - # - # Slice to per mma tile index - # - # ((atom_v, rest_v), loopK) - tBgB_slice = tBgB[ - (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) - ] + expert_idx = mma_tile_coord_mnl[2] # Apply SFB slicing hack when cta_tile_shape_n=64 slice_n = mma_tile_coord_mnl[1] if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): slice_n = mma_tile_coord_mnl[1] // 2 - # ((atom_v, rest_v), RestK) - tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] - # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt b_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) @@ -2009,31 +2208,307 @@ def kernel( for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # noqa: B007 # Conditionally wait for B buffer empty b_pipeline.producer_acquire(b_producer_state, peek_ab_empty_status) - - tBgB_k = tBgB_slice[(None, b_producer_state.count)] - tBgSFB_k = tBgSFB_slice[(None, b_producer_state.count)] - tBsB_pipe = tBsB[(None, b_producer_state.index)] - tBsSFB_pipe = tBsSFB[(None, b_producer_state.index)] - + tBsB_pipe = tBsB_0[(None, b_producer_state.index)] + tBsSFB_pipe = tBsSFB_0[(None, b_producer_state.index)] tma_bar = b_pipeline.producer_get_barrier(b_producer_state) - # TMA load B - cute.copy( - tma_atom_b, - tBgB_k, - tBsB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=b_full_mcast_mask, - ) - - # TMA load SFB - cute.copy( - tma_atom_sfb, - tBgSFB_k, - tBsSFB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=sfb_full_mcast_mask, - ) + # Select correct B tensor based on expert_idx + if cutlass.const_expr(self.num_b_tensors == 1): + # Single B tensor - original logic + tBgB_slice = tBgB_0[ + (None, mma_tile_coord_mnl[1], None, expert_idx) + ] + tBgSFB_slice = tBgSFB_0[(None, slice_n, None, expert_idx)] + cute.copy( + tma_atoms_b[0], + tBgB_slice[(None, b_producer_state.count)], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_slice[(None, b_producer_state.count)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + # Multi-B tensor - select based on expert_idx + if cutlass.const_expr(self.num_b_tensors == 2): + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[ + ( + None, + slice_n, + b_producer_state.count, + local_l_0, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[ + ( + None, + slice_n, + b_producer_state.count, + local_l_1, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif cutlass.const_expr(self.num_b_tensors == 3): + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[ + ( + None, + slice_n, + b_producer_state.count, + local_l_0, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[ + ( + None, + slice_n, + b_producer_state.count, + local_l_1, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[ + ( + None, + slice_n, + b_producer_state.count, + local_l_2, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + # 4 B tensors + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[ + ( + None, + slice_n, + b_producer_state.count, + local_l_0, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[ + ( + None, + slice_n, + b_producer_state.count, + local_l_1, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[3]: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[ + ( + None, + slice_n, + b_producer_state.count, + local_l_2, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_3 = expert_idx - self.b_tensor_l_offsets[3] + cute.copy( + tma_atoms_b[3], + tBgB_3[ + ( + None, + mma_tile_coord_mnl[1], + b_producer_state.count, + local_l_3, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[3], + tBgSFB_3[ + ( + None, + slice_n, + b_producer_state.count, + local_l_3, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 b_producer_state.advance() @@ -2496,9 +2971,49 @@ def kernel( # # Get alpha for current group # - expert_idx = mma_tile_coord_mnl[2] - alpha_val = alpha[expert_idx] + + # Select alpha from correct tensor based on expert_idx + alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]] + if cutlass.const_expr(self.num_b_tensors == 1): + pass # Already initialized above + elif cutlass.const_expr(self.num_b_tensors == 2): + if expert_idx >= self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[1][ + expert_idx - self.b_tensor_l_offsets[1] + ] + elif cutlass.const_expr(self.num_b_tensors == 3): + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][ + expert_idx - self.b_tensor_l_offsets[1] + ] + elif expert_idx >= self.b_tensor_l_offsets[2]: + alpha_val = alpha_tuple[2][ + expert_idx - self.b_tensor_l_offsets[2] + ] + else: + # 4 B tensors + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][ + expert_idx - self.b_tensor_l_offsets[1] + ] + elif ( + expert_idx >= self.b_tensor_l_offsets[2] + and expert_idx < self.b_tensor_l_offsets[3] + ): + alpha_val = alpha_tuple[2][ + expert_idx - self.b_tensor_l_offsets[2] + ] + elif expert_idx >= self.b_tensor_l_offsets[3]: + alpha_val = alpha_tuple[3][ + expert_idx - self.b_tensor_l_offsets[3] + ] # # Slice to per mma tile index @@ -3497,12 +4012,12 @@ def can_implement( def wrapper( self, a_ptr: cute.Pointer, - b_ptr: cute.Pointer, + b_ptr_tuple: Tuple[cute.Pointer, ...], a_sf_ptr: cute.Pointer, - b_sf_ptr: cute.Pointer, + b_sf_ptr_tuple: Tuple[cute.Pointer, ...], c_ptr: cute.Pointer, c_sf_ptr: cute.Pointer, - alpha_ptr: cute.Pointer, + alpha_ptr_tuple: Tuple[cute.Pointer, ...], tile_idx_to_group_idx_ptr: cute.Pointer, tile_idx_to_mn_limit_ptr: cute.Pointer, token_id_mapping_ptr: cute.Pointer, @@ -3512,43 +4027,113 @@ def wrapper( m: cutlass.Int64, n: cutlass.Int64, k: cutlass.Int64, - l: cutlass.Int64, # noqa: E741 tile_size: cutlass.Constexpr, scaling_vector_size: cutlass.Constexpr, max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): + """Unified wrapper supporting both single-B and multi-B tensors. + + B tensors are always passed as tuples (length 1 for single-B). + L sizes are configured via b_tensor_l_sizes in __init__. + """ scale_k = k // scaling_vector_size interm_size = n // 2 num_tiles = m // tile_size + total_l = self.b_tensor_l_offsets[self.num_b_tensors] + a = cute.make_tensor( a_ptr, layout=cute.make_ordered_layout((orig_m, k, 1), order=(1, 0, 2)) ) - b = cute.make_tensor( - b_ptr, layout=cute.make_ordered_layout((n, k, l), order=(1, 0, 2)) - ) a_sf = cute.make_tensor( a_sf_ptr, layout=cute.make_ordered_layout((orig_m, scale_k, 1), order=(1, 0, 2)), ) - b_sf = cute.make_tensor( - b_sf_ptr, - layout=cute.make_ordered_layout( - (32, 4, n // 128, 4, scale_k // 4, l), order=(2, 1, 4, 0, 3, 5) - ), - ) c = cute.make_tensor( c_ptr, layout=cute.make_ordered_layout((m, interm_size, 1), order=(1, 0, 2)) ) c_sf = cute.make_tensor( c_sf_ptr, layout=cute.make_ordered_layout( - (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), l), + (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), total_l), order=(2, 1, 4, 0, 3, 5), ), ) - alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,))) + + # Create B and alpha tensors using const_expr conditions + l_0 = self.b_tensor_l_sizes[0] + alpha_0 = cute.make_tensor(alpha_ptr_tuple[0], layout=cute.make_layout((l_0,))) + b_0 = cute.make_tensor( + b_ptr_tuple[0], + layout=cute.make_ordered_layout((n, k, l_0), order=(1, 0, 2)), + ) + b_sf_0 = cute.make_tensor( + b_sf_ptr_tuple[0], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_0), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple = [b_0] + b_sf_tuple = [b_sf_0] + alpha_tuple = [alpha_0] + + if cutlass.const_expr(self.num_b_tensors >= 2): + l_1 = self.b_tensor_l_sizes[1] + alpha_1 = cute.make_tensor( + alpha_ptr_tuple[1], layout=cute.make_layout((l_1,)) + ) + b_1 = cute.make_tensor( + b_ptr_tuple[1], + layout=cute.make_ordered_layout((n, k, l_1), order=(1, 0, 2)), + ) + b_sf_1 = cute.make_tensor( + b_sf_ptr_tuple[1], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_1), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_1) + b_sf_tuple.append(b_sf_1) + alpha_tuple.append(alpha_1) + + if cutlass.const_expr(self.num_b_tensors >= 3): + l_2 = self.b_tensor_l_sizes[2] + alpha_2 = cute.make_tensor( + alpha_ptr_tuple[2], layout=cute.make_layout((l_2,)) + ) + b_2 = cute.make_tensor( + b_ptr_tuple[2], + layout=cute.make_ordered_layout((n, k, l_2), order=(1, 0, 2)), + ) + b_sf_2 = cute.make_tensor( + b_sf_ptr_tuple[2], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_2), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_2) + b_sf_tuple.append(b_sf_2) + alpha_tuple.append(alpha_2) + + if cutlass.const_expr(self.num_b_tensors >= 4): + l_3 = self.b_tensor_l_sizes[3] + alpha_3 = cute.make_tensor( + alpha_ptr_tuple[3], layout=cute.make_layout((l_3,)) + ) + b_3 = cute.make_tensor( + b_ptr_tuple[3], + layout=cute.make_ordered_layout((n, k, l_3), order=(1, 0, 2)), + ) + b_sf_3 = cute.make_tensor( + b_sf_ptr_tuple[3], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_3), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_3) + b_sf_tuple.append(b_sf_3) + alpha_tuple.append(alpha_3) tile_idx_to_group_idx = cute.make_tensor( tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,)) @@ -3566,17 +4151,17 @@ def wrapper( return self( a, - b, + tuple(b_tuple), c, a_sf, - b_sf, + tuple(b_sf_tuple), c_sf, global_sf, tile_idx_to_group_idx, tile_idx_to_mn_limit, token_id_mapping, num_non_exiting_tiles, - alpha, + tuple(alpha_tuple), max_active_clusters=max_active_clusters, stream=stream, epilogue_op=epilogue_op, diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index e07fab4eb6..df9c78dcbb 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -26,7 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import Tuple, Type, Union +from typing import Optional, Tuple, Type, Union import cuda.bindings.driver as cuda @@ -357,6 +357,9 @@ class Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel: ... ) """ + # Maximum number of B tensors supported (must match kernel's const_expr branches) + MAX_B_TENSORS = 4 + def __init__( self, sf_vec_size: int, @@ -365,6 +368,7 @@ def __init__( use_blkred: bool = False, raster_along_m: bool = False, enable_pdl: bool = True, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, ): """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel. @@ -448,6 +452,26 @@ def __init__( # TMEM offset for final accumulator self.tmem_final_offset = 384 + # Multi-B tensor configuration + if b_tensor_l_sizes is None: + self.num_b_tensors = 1 + self.b_tensor_l_sizes = None + # Offsets padded for safe indexing in kernel + self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS + else: + assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( + f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" + ) + self.num_b_tensors = len(b_tensor_l_sizes) + self.b_tensor_l_sizes = b_tensor_l_sizes + offsets = [0] + for l_size in b_tensor_l_sizes: + offsets.append(offsets[-1] + l_size) + # Pad to MAX_B_TENSORS + 1 for safe indexing + while len(offsets) < self.MAX_B_TENSORS + 1: + offsets.append(2**30) + self.b_tensor_l_offsets = tuple(offsets) + def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -632,14 +656,14 @@ def _setup_attributes(self): def __call__( self, a: cute.Tensor, - b: cute.Tensor, + b: Union[cute.Tensor, Tuple[cute.Tensor, ...]], out: cute.Tensor, sfa: cute.Tensor, - sfb: cute.Tensor, + sfb: Union[cute.Tensor, Tuple[cute.Tensor, ...]], tile_idx_to_expert_idx: cute.Tensor, num_non_exiting_tiles: cute.Tensor, tile_idx_to_mn_limit: cute.Tensor, - alpha: cute.Tensor, + alpha: Union[cute.Tensor, Tuple[cute.Tensor, ...]], max_active_clusters: cutlass.Constexpr, stream: cuda.CUstream, permuted_idx_to_expanded_idx: cute.Tensor, @@ -655,21 +679,21 @@ def __call__( :param a: Input tensor A :type a: cute.Tensor - :param b: Input tensor B - :type b: cute.Tensor + :param b: Input tensor B (single or tuple for multi-B DWDP) + :type b: Union[cute.Tensor, Tuple[cute.Tensor, ...]] :param out: Finalized output tensor (shape [seq_len, n]) :type out: cute.Tensor :param sfa: Scale factor tensor A :type sfa: cute.Tensor - :param sfb: Scale factor tensor B - :type sfb: cute.Tensor + :param sfb: Scale factor tensor B (single or tuple for multi-B DWDP) + :type sfb: Union[cute.Tensor, Tuple[cute.Tensor, ...]] :param tile_idx_to_expert_idx: Mapping from tile index to expert ID, shape (permuted_m/cta_tile_m,) where cta_tile_m is the CTA tile M size :type tile_idx_to_expert_idx: cute.Tensor :param num_non_exiting_tiles: Number of valid tiles (valid_m/cta_tile_m), shape (1,) :type num_non_exiting_tiles: cute.Tensor - :param alpha: Alpha tensor for each group - :type alpha: cute.Tensor + :param alpha: Alpha tensor for each group (single or tuple for multi-B DWDP) + :type alpha: Union[cute.Tensor, Tuple[cute.Tensor, ...]] :param max_active_clusters: Maximum number of active clusters :type max_active_clusters: cutlass.Constexpr :param stream: CUDA stream for asynchronous execution @@ -683,13 +707,16 @@ def __call__( :raises TypeError: If input data types are incompatible with the MMA instruction. """ # Setup static attributes before smem/grid/tma computation + # Handle tuple of B tensors + b_tuple = b if isinstance(b, tuple) else (b,) + sfb_tuple = sfb if isinstance(sfb, tuple) else (sfb,) self.a_dtype: Type[cutlass.Numeric] = a.element_type - self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.b_dtype: Type[cutlass.Numeric] = b_tuple[0].element_type self.out_dtype: Type[cutlass.Numeric] = out.element_type self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type self.final_scale_dtype: Type[cutlass.Numeric] = token_final_scales.element_type self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() - self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tuple[0]).mma_major_mode() self.gemm_output_layout = utils.LayoutEnum.ROW_MAJOR self.topK = token_final_scales.shape[1] @@ -704,9 +731,30 @@ def __call__( sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(a.shape, self.sf_vec_size) sfa = cute.make_tensor(sfa.iterator, sfa_layout) + # Setup sfb tensors by filling B tensor to scale factor atom layout # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) - sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) - sfb = cute.make_tensor(sfb.iterator, sfb_layout) + # Create layout for each B tensor (use const_expr, not loop) + sfb_layout_0 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[0].shape, self.sf_vec_size + ) + sfb_tensor_0 = cute.make_tensor(sfb_tuple[0].iterator, sfb_layout_0) + sfb_tensors = [sfb_tensor_0] + if cutlass.const_expr(self.num_b_tensors >= 2): + sfb_layout_1 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[1].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[1].iterator, sfb_layout_1)) + if cutlass.const_expr(self.num_b_tensors >= 3): + sfb_layout_2 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[2].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[2].iterator, sfb_layout_2)) + if cutlass.const_expr(self.num_b_tensors >= 4): + sfb_layout_3 = blockscaled_utils.tile_atom_to_shape_SF( + b_tuple[3].shape, self.sf_vec_size + ) + sfb_tensors.append(cute.make_tensor(sfb_tuple[3].iterator, sfb_layout_3)) + sfb_tuple = tuple(sfb_tensors) tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( self.a_dtype, @@ -743,20 +791,6 @@ def __call__( self.cluster_layout_vmnk.shape, ) - # Setup TMA load for B - b_op = sm100_utils.cluster_shape_to_tma_atom_B( - self.cluster_shape_mn, tiled_mma.thr_id - ) - b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - b_op, - b, - b_smem_layout, - self.mma_tiler, - tiled_mma, - self.cluster_layout_vmnk.shape, - ) - # Setup TMA load for SFA sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( self.cluster_shape_mn, tiled_mma.thr_id @@ -774,43 +808,99 @@ def __call__( internal_type=cutlass.Int16, ) - # Setup TMA load for SFB + # Setup TMA ops for B/SFB (shared across all B tensors) + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( self.cluster_shape_mn, tiled_mma.thr_id ) sfb_smem_layout = cute.slice_( self.sfb_smem_layout_staged, (None, None, None, 0) ) - tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, - sfb, - sfb_smem_layout, - self.mma_tiler_sfb, - tiled_mma_sfb, - self.cluster_layout_sfb_vmnk.shape, - internal_type=cutlass.Int16, - ) - - if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): - x = tma_tensor_sfb.stride[0][1] - y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) - new_shape = ( - (tma_tensor_sfb.shape[0][0], ((2, 2), y)), - tma_tensor_sfb.shape[1], - tma_tensor_sfb.shape[2], + # Helper to create TMA for one B tensor + def _make_tma_b(b_tensor, sfb_tensor): + atom_b, tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + atom_sfb, tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + # This modifies the layout to handle overlapping 256x(# of scale factors + # for a single column of B (nNSF)) logical blocks for SFB when + # cta_tile_shape_n=192. + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tensor_sfb.stride[0][1] + y = cute.ceil_div(tensor_sfb.shape[0][1], 4) + new_shape = ( + (tensor_sfb.shape[0][0], ((2, 2), y)), + tensor_sfb.shape[1], + tensor_sfb.shape[2], + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + (tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tensor_sfb.stride[1], + tensor_sfb.stride[2], + ) + tensor_sfb = cute.make_tensor( + tensor_sfb.iterator, cute.make_layout(new_shape, stride=new_stride) + ) + return atom_b, tensor_b, atom_sfb, tensor_sfb + + # Create TMA for all B tensors (use const_expr, not loop) + atom_b_0, tensor_b_0, atom_sfb_0, tensor_sfb_0 = _make_tma_b( + b_tuple[0], sfb_tuple[0] + ) + tma_atoms_b = [atom_b_0] + tma_tensors_b = [tensor_b_0] + tma_atoms_sfb = [atom_sfb_0] + tma_tensors_sfb = [tensor_sfb_0] + if cutlass.const_expr(self.num_b_tensors >= 2): + atom_b_1, tensor_b_1, atom_sfb_1, tensor_sfb_1 = _make_tma_b( + b_tuple[1], sfb_tuple[1] ) - # Use right multiplication for ScaledBasis (3 * x instead of x * 3) - x_times_3 = 3 * x - new_stride = ( - (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), - tma_tensor_sfb.stride[1], - tma_tensor_sfb.stride[2], + tma_atoms_b.append(atom_b_1) + tma_tensors_b.append(tensor_b_1) + tma_atoms_sfb.append(atom_sfb_1) + tma_tensors_sfb.append(tensor_sfb_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + atom_b_2, tensor_b_2, atom_sfb_2, tensor_sfb_2 = _make_tma_b( + b_tuple[2], sfb_tuple[2] ) - tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) - tma_tensor_sfb = cute.make_tensor( - tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout + tma_atoms_b.append(atom_b_2) + tma_tensors_b.append(tensor_b_2) + tma_atoms_sfb.append(atom_sfb_2) + tma_tensors_sfb.append(tensor_sfb_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + atom_b_3, tensor_b_3, atom_sfb_3, tensor_sfb_3 = _make_tma_b( + b_tuple[3], sfb_tuple[3] ) + tma_atoms_b.append(atom_b_3) + tma_tensors_b.append(tensor_b_3) + tma_atoms_sfb.append(atom_sfb_3) + tma_tensors_sfb.append(tensor_sfb_3) + tma_atoms_b = tuple(tma_atoms_b) + tma_tensors_b = tuple(tma_tensors_b) + tma_atoms_sfb = tuple(tma_atoms_sfb) + tma_tensors_sfb = tuple(tma_tensors_sfb) + + # Handle alpha tuple (convert to tuple if single tensor) + alpha_tuple = alpha if isinstance(alpha, tuple) else (alpha,) a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) @@ -821,7 +911,7 @@ def __call__( ) * atom_thr_size self.tile_sched_params, grid = self._compute_grid( - (a.shape[0], b.shape[0], a.shape[2]), + (a.shape[0], b_tuple[0].shape[0], a.shape[2]), self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters, @@ -921,17 +1011,17 @@ class SharedStorage: tiled_mma_sfb, tma_atom_a, tma_tensor_a, - tma_atom_b, - tma_tensor_b, + tma_atoms_b, # Tuple of TMA atoms for B + tma_tensors_b, # Tuple of TMA tensors for B tma_atom_sfa, tma_tensor_sfa, - tma_atom_sfb, - tma_tensor_sfb, + tma_atoms_sfb, # Tuple of TMA atoms for SFB + tma_tensors_sfb, # Tuple of TMA tensors for SFB out, tile_idx_to_expert_idx, num_non_exiting_tiles, tile_idx_to_mn_limit, - alpha, + alpha_tuple, permuted_idx_to_expanded_idx, token_final_scales, self.cluster_layout_vmnk, @@ -1008,17 +1098,17 @@ def kernel( tiled_mma_sfb: cute.TiledMma, tma_atom_a: cute.CopyAtom, mA_mkl: cute.Tensor, - tma_atom_b: cute.CopyAtom, - mB_nkl: cute.Tensor, + tma_atoms_b: Tuple[cute.CopyAtom, ...], + mB_nkl_tuple: Tuple[cute.Tensor, ...], tma_atom_sfa: cute.CopyAtom, mSFA_mkl: cute.Tensor, - tma_atom_sfb: cute.CopyAtom, - mSFB_nkl: cute.Tensor, + tma_atoms_sfb: Tuple[cute.CopyAtom, ...], + mSFB_nkl_tuple: Tuple[cute.Tensor, ...], out: cute.Tensor, tile_idx_to_expert_idx: cute.Tensor, num_non_exiting_tiles: cute.Tensor, tile_idx_to_mn_limit: cute.Tensor, - alpha: cute.Tensor, + alpha_tuple: Tuple[cute.Tensor, ...], permuted_idx_to_expanded_idx: cute.Tensor, token_final_scales: cute.Tensor, cluster_layout_vmnk: cute.Layout, @@ -1045,9 +1135,19 @@ def kernel( # if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_atom_a) - cpasync.prefetch_descriptor(tma_atom_b) cpasync.prefetch_descriptor(tma_atom_sfa) - cpasync.prefetch_descriptor(tma_atom_sfb) + # Prefetch TMA descriptors for all B tensors using const_expr conditions + cpasync.prefetch_descriptor(tma_atoms_b[0]) + cpasync.prefetch_descriptor(tma_atoms_sfb[0]) + if cutlass.const_expr(self.num_b_tensors >= 2): + cpasync.prefetch_descriptor(tma_atoms_b[1]) + cpasync.prefetch_descriptor(tma_atoms_sfb[1]) + if cutlass.const_expr(self.num_b_tensors >= 3): + cpasync.prefetch_descriptor(tma_atoms_b[2]) + cpasync.prefetch_descriptor(tma_atoms_sfb[2]) + if cutlass.const_expr(self.num_b_tensors >= 4): + cpasync.prefetch_descriptor(tma_atoms_b[3]) + cpasync.prefetch_descriptor(tma_atoms_sfb[3]) use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 @@ -1189,22 +1289,60 @@ def kernel( gA_mkl = cute.local_tile( mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) - # (bN, bK, loopN, loopK, loopL) - gB_nkl = cute.local_tile( - mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + # (bN, bK, loopN, loopK, loopL) - Use const_expr conditions for tuple indexing + gB_nkl_0 = cute.local_tile( + mB_nkl_tuple[0], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gB_nkl_1 = cute.local_tile( + mB_nkl_tuple[1], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gB_nkl_2 = cute.local_tile( + mB_nkl_tuple[2], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gB_nkl_3 = cute.local_tile( + mB_nkl_tuple[3], + cute.slice_(self.mma_tiler, (0, None, None)), + (None, None, None), + ) # (bM, bK, RestM, RestK, RestL) gSFA_mkl = cute.local_tile( mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) ) - # (bN, bK, RestN, RestK, RestL) - gSFB_nkl = cute.local_tile( - mSFB_nkl, + # (bN, bK, RestN, RestK, RestL) - Use const_expr conditions for tuple indexing + gSFB_nkl_0 = cute.local_tile( + mSFB_nkl_tuple[0], cute.slice_(self.mma_tiler_sfb, (0, None, None)), (None, None, None), ) + if cutlass.const_expr(self.num_b_tensors >= 2): + gSFB_nkl_1 = cute.local_tile( + mSFB_nkl_tuple[1], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 3): + gSFB_nkl_2 = cute.local_tile( + mSFB_nkl_tuple[2], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + if cutlass.const_expr(self.num_b_tensors >= 4): + gSFB_nkl_3 = cute.local_tile( + mSFB_nkl_tuple[3], + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) k_tile_cnt = cutlass.Int32(cute.size(gA_mkl, mode=[3])) @@ -1215,12 +1353,24 @@ def kernel( thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) tCgA = thr_mma.partition_A(gA_mkl) - # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) - tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) - const_expr conditions + tCgB_0 = thr_mma.partition_B(gB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgB_1 = thr_mma.partition_B(gB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgB_2 = thr_mma.partition_B(gB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgB_3 = thr_mma.partition_B(gB_nkl_3) # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) tCgSFA = thr_mma.partition_A(gSFA_mkl) - # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) - const_expr conditions + tCgSFB_0 = thr_mma_sfb.partition_B(gSFB_nkl_0) + if cutlass.const_expr(self.num_b_tensors >= 2): + tCgSFB_1 = thr_mma_sfb.partition_B(gSFB_nkl_1) + if cutlass.const_expr(self.num_b_tensors >= 3): + tCgSFB_2 = thr_mma_sfb.partition_B(gSFB_nkl_2) + if cutlass.const_expr(self.num_b_tensors >= 4): + tCgSFB_3 = thr_mma_sfb.partition_B(gSFB_nkl_3) # # Partition global/shared tensor for TMA load A/B @@ -1242,15 +1392,87 @@ def kernel( b_cta_layout = cute.make_layout( cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape ) + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + sB_grouped = cute.group_modes(sB, 0, 3) + sSFB_grouped = cute.group_modes(sSFB, 0, 3) + + # TMA partition for B tensor 0 # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), loopM, loopK, loopL) - tBsB, tBgB = cpasync.tma_partition( - tma_atom_b, + # ((atom_v, rest_v), loopN, loopK, loopL) + tBsB_0, tBgB_0 = cpasync.tma_partition( + tma_atoms_b[0], block_in_cluster_coord_vmnk[1], b_cta_layout, - cute.group_modes(sB, 0, 3), - cute.group_modes(tCgB, 0, 3), + sB_grouped, + cute.group_modes(tCgB_0, 0, 3), ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB_0, tBgSFB_0 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[0], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_0, 0, 3), + ) + tBsSFB_0 = cute.filter_zeros(tBsSFB_0) + tBgSFB_0 = cute.filter_zeros(tBgSFB_0) + + # TMA partition for B tensor 1 (tBsB shared memory partition is same for all, use _ to ignore) + if cutlass.const_expr(self.num_b_tensors >= 2): + _, tBgB_1 = cpasync.tma_partition( + tma_atoms_b[1], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_1, 0, 3), + ) + _, tBgSFB_1 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[1], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_1, 0, 3), + ) + tBgSFB_1 = cute.filter_zeros(tBgSFB_1) + + # TMA partition for B tensor 2 + if cutlass.const_expr(self.num_b_tensors >= 3): + _, tBgB_2 = cpasync.tma_partition( + tma_atoms_b[2], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_2, 0, 3), + ) + _, tBgSFB_2 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[2], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_2, 0, 3), + ) + tBgSFB_2 = cute.filter_zeros(tBgSFB_2) + + # TMA partition for B tensor 3 + if cutlass.const_expr(self.num_b_tensors >= 4): + _, tBgB_3 = cpasync.tma_partition( + tma_atoms_b[3], + block_in_cluster_coord_vmnk[1], + b_cta_layout, + sB_grouped, + cute.group_modes(tCgB_3, 0, 3), + ) + _, tBgSFB_3 = cute.nvgpu.cpasync.tma_partition( + tma_atoms_sfb[3], + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + sSFB_grouped, + cute.group_modes(tCgSFB_3, 0, 3), + ) + tBgSFB_3 = cute.filter_zeros(tBgSFB_3) # TMA load SFA partition_S/D sfa_cta_layout = a_cta_layout @@ -1268,22 +1490,6 @@ def kernel( tAsSFA = cute.filter_zeros(tAsSFA) tAgSFA = cute.filter_zeros(tAgSFA) - # TMA load SFB partition_S/D - sfb_cta_layout = cute.make_layout( - cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape - ) - # ((atom_v, rest_v), STAGE) - # ((atom_v, rest_v), RestN, RestK, RestL) - tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( - tma_atom_sfb, - block_in_cluster_coord_sfb_vmnk[1], - sfb_cta_layout, - cute.group_modes(sSFB, 0, 3), - cute.group_modes(tCgSFB, 0, 3), - ) - tBsSFB = cute.filter_zeros(tBsSFB) - tBgSFB = cute.filter_zeros(tBgSFB) - # # Partition shared/tensor memory tensor for TiledMMA_A/B/C # @@ -1479,15 +1685,13 @@ def kernel( tile_info[1], tile_info[2], ) + expert_idx = mma_tile_coord_mnl[2] + # # Slice to per mma tile index # # ((atom_v, rest_v), loopK) tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, 0)] - # ((atom_v, rest_v), loopK) - tBgB_slice = tBgB[ - (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) - ] # ((atom_v, rest_v), RestK) tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, 0)] @@ -1496,9 +1700,6 @@ def kernel( if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): slice_n = mma_tile_coord_mnl[1] // 2 - # ((atom_v, rest_v), RestK) - tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] - # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt ab_producer_state.reset_count() peek_ab_empty_status = cutlass.Boolean(1) @@ -1511,13 +1712,11 @@ def kernel( # for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): # noqa: B007 tAgA_k = tAgA_slice[(None, ab_producer_state.count)] - tBgB_k = tBgB_slice[(None, ab_producer_state.count)] tAgSFA_k = tAgSFA_slice[(None, ab_producer_state.count)] - tBgSFB_k = tBgSFB_slice[(None, ab_producer_state.count)] tAsA_pipe = tAsA[(None, ab_producer_state.index)] - tBsB_pipe = tBsB[(None, ab_producer_state.index)] tAsSFA_pipe = tAsSFA[(None, ab_producer_state.index)] - tBsSFB_pipe = tBsSFB[(None, ab_producer_state.index)] + tBsB_pipe = tBsB_0[(None, ab_producer_state.index)] + tBsSFB_pipe = tBsSFB_0[(None, ab_producer_state.index)] tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) @@ -1526,7 +1725,7 @@ def kernel( ab_producer_state, peek_ab_empty_status ) - # TMA load A/B + # TMA load A cute.copy( tma_atom_a, tAgA_k, @@ -1534,14 +1733,8 @@ def kernel( tma_bar_ptr=tma_bar, mcast_mask=a_full_mcast_mask, ) - cute.copy( - tma_atom_b, - tBgB_k, - tBsB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=b_full_mcast_mask, - ) + # TMA load SFA cute.copy( tma_atom_sfa, tAgSFA_k, @@ -1549,13 +1742,304 @@ def kernel( tma_bar_ptr=tma_bar, mcast_mask=sfa_full_mcast_mask, ) - cute.copy( - tma_atom_sfb, - tBgSFB_k, - tBsSFB_pipe, - tma_bar_ptr=tma_bar, - mcast_mask=sfb_full_mcast_mask, - ) + + # Select correct B tensor based on expert_idx + if cutlass.const_expr(self.num_b_tensors == 1): + # Single B tensor - original logic + tBgB_slice = tBgB_0[ + (None, mma_tile_coord_mnl[1], None, expert_idx) + ] + tBgSFB_slice = tBgSFB_0[(None, slice_n, None, expert_idx)] + cute.copy( + tma_atoms_b[0], + tBgB_slice[(None, ab_producer_state.count)], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_slice[(None, ab_producer_state.count)], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + # Multi-B tensor - select based on expert_idx + if cutlass.const_expr(self.num_b_tensors == 2): + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[ + ( + None, + slice_n, + ab_producer_state.count, + local_l_0, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[ + ( + None, + slice_n, + ab_producer_state.count, + local_l_1, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif cutlass.const_expr(self.num_b_tensors == 3): + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[ + ( + None, + slice_n, + ab_producer_state.count, + local_l_0, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[ + ( + None, + slice_n, + ab_producer_state.count, + local_l_1, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[ + ( + None, + slice_n, + ab_producer_state.count, + local_l_2, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + # 4 B tensors + if expert_idx < self.b_tensor_l_offsets[1]: + local_l_0 = expert_idx - self.b_tensor_l_offsets[0] + cute.copy( + tma_atoms_b[0], + tBgB_0[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_0, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[0], + tBgSFB_0[ + ( + None, + slice_n, + ab_producer_state.count, + local_l_0, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[2]: + local_l_1 = expert_idx - self.b_tensor_l_offsets[1] + cute.copy( + tma_atoms_b[1], + tBgB_1[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_1, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[1], + tBgSFB_1[ + ( + None, + slice_n, + ab_producer_state.count, + local_l_1, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + elif expert_idx < self.b_tensor_l_offsets[3]: + local_l_2 = expert_idx - self.b_tensor_l_offsets[2] + cute.copy( + tma_atoms_b[2], + tBgB_2[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_2, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[2], + tBgSFB_2[ + ( + None, + slice_n, + ab_producer_state.count, + local_l_2, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + else: + local_l_3 = expert_idx - self.b_tensor_l_offsets[3] + cute.copy( + tma_atoms_b[3], + tBgB_3[ + ( + None, + mma_tile_coord_mnl[1], + ab_producer_state.count, + local_l_3, + ) + ], + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atoms_sfb[3], + tBgSFB_3[ + ( + None, + slice_n, + ab_producer_state.count, + local_l_3, + ) + ], + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 ab_producer_state.advance() @@ -1903,7 +2387,48 @@ def kernel( # expert_idx = mma_tile_coord_mnl[2] - alpha_val = alpha[expert_idx] + + # Select alpha from correct tensor based on expert_idx + alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]] + if cutlass.const_expr(self.num_b_tensors == 1): + pass # Already initialized above + elif cutlass.const_expr(self.num_b_tensors == 2): + if expert_idx >= self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[1][ + expert_idx - self.b_tensor_l_offsets[1] + ] + elif cutlass.const_expr(self.num_b_tensors == 3): + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][ + expert_idx - self.b_tensor_l_offsets[1] + ] + elif expert_idx >= self.b_tensor_l_offsets[2]: + alpha_val = alpha_tuple[2][ + expert_idx - self.b_tensor_l_offsets[2] + ] + else: + # 4 B tensors + if ( + expert_idx >= self.b_tensor_l_offsets[1] + and expert_idx < self.b_tensor_l_offsets[2] + ): + alpha_val = alpha_tuple[1][ + expert_idx - self.b_tensor_l_offsets[1] + ] + elif ( + expert_idx >= self.b_tensor_l_offsets[2] + and expert_idx < self.b_tensor_l_offsets[3] + ): + alpha_val = alpha_tuple[2][ + expert_idx - self.b_tensor_l_offsets[2] + ] + elif expert_idx >= self.b_tensor_l_offsets[3]: + alpha_val = alpha_tuple[3][ + expert_idx - self.b_tensor_l_offsets[3] + ] tile_m_start = tile_info[0] * self.cta_tile_shape_mnk[0] permuted_row = tile_m_start + epi_tidx @@ -2622,11 +3147,11 @@ def can_implement( def wrapper( self, a_ptr: cute.Pointer, - b_ptr: cute.Pointer, + b_ptr_tuple: Tuple[cute.Pointer, ...], a_sf_ptr: cute.Pointer, - b_sf_ptr: cute.Pointer, + b_sf_ptr_tuple: Tuple[cute.Pointer, ...], c_ptr: cute.Pointer, - alpha_ptr: cute.Pointer, + alpha_ptr_tuple: Tuple[cute.Pointer, ...], tile_idx_to_group_idx_ptr: cute.Pointer, tile_idx_to_mn_limit_ptr: cute.Pointer, permuted_idx_to_expanded_idx_ptr: cute.Pointer, @@ -2635,7 +3160,6 @@ def wrapper( m: cutlass.Int64, n: cutlass.Int64, k: cutlass.Int64, - l: cutlass.Int64, # noqa: E741 num_tokens: cutlass.Int64, top_k: cutlass.Int64, tile_size: cutlass.Constexpr, @@ -2644,30 +3168,90 @@ def wrapper( stream: cuda.CUstream, epilogue_op: cutlass.Constexpr = lambda x: x, ): + """Unified wrapper supporting both single-B and multi-B tensors. + + B tensors are always passed as tuples (length 1 for single-B). + L sizes are configured via b_tensor_l_sizes in __init__. + """ scale_k = k // scaling_vector_size num_tiles = m // tile_size + a = cute.make_tensor( a_ptr, layout=cute.make_ordered_layout((m, k, 1), order=(1, 0, 2)) ) - b = cute.make_tensor( - b_ptr, layout=cute.make_ordered_layout((n, k, l), order=(1, 0, 2)) - ) a_sf = cute.make_tensor( a_sf_ptr, layout=cute.make_ordered_layout( (32, 4, m // 128, 4, scale_k // 4, 1), order=(2, 1, 4, 0, 3, 5) ), ) - b_sf = cute.make_tensor( - b_sf_ptr, - layout=cute.make_ordered_layout( - (32, 4, n // 128, 4, scale_k // 4, l), order=(2, 1, 4, 0, 3, 5) - ), - ) c = cute.make_tensor( c_ptr, layout=cute.make_ordered_layout((num_tokens, n, 1), order=(1, 0, 2)) ) - alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,))) + + # Create B and alpha tensors using const_expr conditions + l_0 = self.b_tensor_l_sizes[0] + alpha_0 = cute.make_tensor(alpha_ptr_tuple[0], layout=cute.make_layout((l_0,))) + b_0 = cute.make_tensor( + b_ptr_tuple[0], layout=cute.make_ordered_layout((n, k, l_0), order=(1, 0, 2)) + ) + b_sf_0 = cute.make_tensor( + b_sf_ptr_tuple[0], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_0), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple = [b_0] + b_sf_tuple = [b_sf_0] + alpha_tuple = [alpha_0] + + if cutlass.const_expr(self.num_b_tensors >= 2): + l_1 = self.b_tensor_l_sizes[1] + alpha_1 = cute.make_tensor(alpha_ptr_tuple[1], layout=cute.make_layout((l_1,))) + b_1 = cute.make_tensor( + b_ptr_tuple[1], layout=cute.make_ordered_layout((n, k, l_1), order=(1, 0, 2)) + ) + b_sf_1 = cute.make_tensor( + b_sf_ptr_tuple[1], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_1), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_1) + b_sf_tuple.append(b_sf_1) + alpha_tuple.append(alpha_1) + + if cutlass.const_expr(self.num_b_tensors >= 3): + l_2 = self.b_tensor_l_sizes[2] + alpha_2 = cute.make_tensor(alpha_ptr_tuple[2], layout=cute.make_layout((l_2,))) + b_2 = cute.make_tensor( + b_ptr_tuple[2], layout=cute.make_ordered_layout((n, k, l_2), order=(1, 0, 2)) + ) + b_sf_2 = cute.make_tensor( + b_sf_ptr_tuple[2], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_2), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_2) + b_sf_tuple.append(b_sf_2) + alpha_tuple.append(alpha_2) + + if cutlass.const_expr(self.num_b_tensors >= 4): + l_3 = self.b_tensor_l_sizes[3] + alpha_3 = cute.make_tensor(alpha_ptr_tuple[3], layout=cute.make_layout((l_3,))) + b_3 = cute.make_tensor( + b_ptr_tuple[3], layout=cute.make_ordered_layout((n, k, l_3), order=(1, 0, 2)) + ) + b_sf_3 = cute.make_tensor( + b_sf_ptr_tuple[3], + layout=cute.make_ordered_layout( + (32, 4, n // 128, 4, scale_k // 4, l_3), order=(2, 1, 4, 0, 3, 5) + ), + ) + b_tuple.append(b_3) + b_sf_tuple.append(b_sf_3) + alpha_tuple.append(alpha_3) tile_idx_to_group_idx = cute.make_tensor( tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,)) @@ -2688,14 +3272,14 @@ def wrapper( return self( a, - b, + tuple(b_tuple), c, a_sf, - b_sf, + tuple(b_sf_tuple), tile_idx_to_group_idx, num_non_exiting_tiles, tile_idx_to_mn_limit, - alpha, + tuple(alpha_tuple), max_active_clusters=max_active_clusters, stream=stream, permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, diff --git a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index cf40bd0136..6b27c0b763 100644 --- a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -43,7 +43,7 @@ - Gather: Uses LDGSTS to gather A directly using token_id_mapping, no moe_permute needed """ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import cutlass import cutlass.cute as cute @@ -191,15 +191,14 @@ def _get_compiled_gather_kernel( permuted_m: int, n: int, # This is 2*intermediate_size k: int, - num_experts: int, # Tensor pointers (runtime parameters - NOT in cache key) a_ptr, - b_ptr, + b_ptr, # tuple of pointers a_sf_ptr, - b_sf_ptr, + b_sf_ptr, # tuple of pointers c_ptr, c_sf_ptr, - alpha_ptr, + alpha_ptr, # tuple of pointers tile_idx_ptr, mn_limit_ptr, token_id_ptr, @@ -221,6 +220,7 @@ def _get_compiled_gather_kernel( vectorized_f32: bool, raster_along_m: bool, enable_pdl: bool = True, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, ): """Get or compile the gather grouped GEMM with SwiGLU kernel. @@ -234,10 +234,14 @@ def _get_compiled_gather_kernel( This matches TRT-LLM's approach where the same compiled kernel can be reused for different problem sizes, significantly reducing JIT compilation overhead during autotuning. + + Supports multiple B weight tensors via b_tensor_l_sizes parameter. + When b_tensor_l_sizes is provided, b_ptr/b_sf_ptr/alpha_ptr are tuples. """ global _gather_kernel_cache # Cache key includes dtype and tactic parameters, NOT problem dimensions + # Also includes b_tensor_l_sizes since kernel is specialized per multi-B config cache_key = ( ab_dtype, sf_dtype, @@ -250,6 +254,7 @@ def _get_compiled_gather_kernel( vectorized_f32, raster_along_m, enable_pdl, + b_tensor_l_sizes, ) if cache_key not in _gather_kernel_cache: @@ -262,16 +267,16 @@ def _get_compiled_gather_kernel( topk=topk, raster_along_m=raster_along_m, enable_pdl=enable_pdl, + b_tensor_l_sizes=b_tensor_l_sizes, ) # Compile with runtime parameters - they can vary across calls # Order must match wrapper signature: - # (a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, c_sf_ptr, alpha_ptr, + # (a_ptr, b_ptr_tuple, a_sf_ptr, b_sf_ptr_tuple, c_ptr, c_sf_ptr, alpha_ptr_tuple, # tile_idx_to_group_idx_ptr, tile_idx_to_mn_limit_ptr, token_id_mapping_ptr, - # num_non_exiting_tiles_ptr, global_sf_ptr, orig_m, m, n, k, l, + # num_non_exiting_tiles_ptr, norm_const_ptr, orig_m, m, n, k, # tile_size, scaling_vector_size, max_active_clusters, stream) - compiled_gemm = cute.compile( - gemm.wrapper, + compile_args = [ a_ptr, b_ptr, a_sf_ptr, @@ -288,7 +293,11 @@ def _get_compiled_gather_kernel( permuted_m, n, k, - num_experts, + ] + + compiled_gemm = cute.compile( + gemm.wrapper, + *compile_args, tile_size=tile_size, scaling_vector_size=sf_vec_size, max_active_clusters=max_active_clusters, @@ -302,10 +311,10 @@ def _get_compiled_gather_kernel( def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( a: torch.Tensor, - b: torch.Tensor, + b: Union[torch.Tensor, List[torch.Tensor]], a_scale: torch.Tensor, - b_scale: torch.Tensor, - alpha: torch.Tensor, + b_scale: Union[torch.Tensor, List[torch.Tensor]], + alpha: Union[torch.Tensor, List[torch.Tensor]], tile_idx_to_expert_idx: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, token_id_mapping: torch.Tensor, @@ -406,14 +415,19 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( ... topk=topk, ... ) # out shape: (valid_m, intermediate_dim) """ + # Normalize to lists for multi-B support + b_list = [b] if isinstance(b, torch.Tensor) else b + b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else b_scale + alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else alpha + # Validate inputs assert a.device.type == "cuda", "Input tensors must be on CUDA device" - assert b.device.type == "cuda", "Input tensors must be on CUDA device" + assert b_list[0].device.type == "cuda", "Input tensors must be on CUDA device" # Get dimensions seq_len = a.shape[0] - num_experts = b.shape[0] - n = b.shape[1] # This is 2*intermediate_size + num_experts = sum(bi.size(0) for bi in b_list) + n = b_list[0].shape[1] # This is 2*intermediate_size k = a.shape[1] if ab_dtype == "float4_e2m1fn": k = k * 2 # FP4 is packed 2 elements per byte @@ -500,19 +514,16 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( # Get tile_size from mma_tiler_mn tile_size = mma_tiler_mn[0] + # Compute b_tensor_l_sizes for multi-B support + b_tensor_l_sizes = tuple(bi.size(0) for bi in b_list) + # Create raw pointers (TRT-LLM style) - allows same compiled kernel for different sizes a_ptr = make_ptr( ab_dtype_cutlass, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) - b_ptr = make_ptr( - ab_dtype_cutlass, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) a_sf_ptr = make_ptr( sf_dtype_cutlass, a_scale.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 ) - b_sf_ptr = make_ptr( - sf_dtype_cutlass, b_scale.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) c_ptr = make_ptr( c_dtype_cutlass, out.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) @@ -531,7 +542,24 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( c_sf_ptr = None norm_const_ptr = None - alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(), cute.AddressSpace.gmem) + # Create pointer tuples for B tensors + b_ptr = tuple( + make_ptr( + ab_dtype_cutlass, bi.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + for bi in b_list + ) + b_sf_ptr = tuple( + make_ptr( + sf_dtype_cutlass, bsi.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + for bsi in b_scale_list + ) + alpha_ptr = tuple( + make_ptr(cutlass.Float32, ai.data_ptr(), cute.AddressSpace.gmem) + for ai in alpha_list + ) + tile_idx_ptr = make_ptr( cutlass.Int32, tile_idx_to_expert_idx.data_ptr(), cute.AddressSpace.gmem ) @@ -549,15 +577,12 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( torch_stream = torch.cuda.current_stream() stream = cuda.CUstream(torch_stream.cuda_stream) - # Get or compile the kernel (cached by dtype and tactic parameters) + # Get or compile the kernel compiled_gemm = _get_compiled_gather_kernel( - # Runtime parameters (problem dimensions) orig_m=seq_len, permuted_m=permuted_m, n=n, k=k, - num_experts=num_experts, - # Tensor pointers (order must match wrapper signature) a_ptr=a_ptr, b_ptr=b_ptr, a_sf_ptr=a_sf_ptr, @@ -572,11 +597,9 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( norm_const_ptr=norm_const_ptr, max_active_clusters=max_active_clusters, stream=stream, - # Dtype parameters (compile-time, in cache key) ab_dtype=ab_dtype, sf_dtype=sf_dtype, c_dtype=c_dtype, - # Tactic parameters (compile-time, cached) sf_vec_size=sf_vec_size, tile_size=tile_size, topk=topk, @@ -585,14 +608,11 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( vectorized_f32=vectorized_f32, raster_along_m=raster_along_m, enable_pdl=enable_pdl, + b_tensor_l_sizes=b_tensor_l_sizes, ) - # Execute kernel with runtime parameters - # Order must match wrapper signature: - # (a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, c_sf_ptr, alpha_ptr, - # tile_idx_ptr, mn_limit_ptr, token_id_ptr, num_tiles_ptr, global_sf_ptr, - # orig_m, m, n, k, l, stream) - compiled_gemm( + # Execute kernel + exec_args = [ a_ptr, b_ptr, a_sf_ptr, @@ -609,8 +629,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( permuted_m, n, k, - num_experts, - stream=stream, - ) + ] + compiled_gemm(*exec_args, stream=stream) return out, out_scale if generate_sfc else None diff --git a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index 17cdecce20..be4dad8876 100644 --- a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -40,7 +40,7 @@ - Support for SM100 (Blackwell) architecture """ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import cutlass import cutlass.cute as cute @@ -165,15 +165,14 @@ def _get_compiled_finalize_kernel( permuted_m: int, n: int, k: int, - num_experts: int, topk: int, # Tensor pointers (runtime parameters - NOT in cache key) a_ptr, - b_ptr, + b_ptr, # tuple of pointers a_sf_ptr, - b_sf_ptr, + b_sf_ptr, # tuple of pointers c_ptr, - alpha_ptr, + alpha_ptr, # tuple of pointers tile_idx_ptr, mn_limit_ptr, permuted_idx_ptr, @@ -188,6 +187,7 @@ def _get_compiled_finalize_kernel( cluster_shape_mn: Tuple[int, int], raster_along_m: bool, enable_pdl: bool = True, + b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, ): """Get or compile the grouped GEMM with finalize fusion kernel. @@ -197,10 +197,14 @@ def _get_compiled_finalize_kernel( This matches TRT-LLM's approach where the same compiled kernel can be reused for different problem sizes, significantly reducing JIT compilation overhead during autotuning. + + Supports multiple B weight tensors via b_tensor_l_sizes parameter. + When b_tensor_l_sizes is provided, b_ptr/b_sf_ptr/alpha_ptr are tuples. """ global _finalize_kernel_cache # Cache key only includes tactic parameters, NOT problem dimensions + # Also includes b_tensor_l_sizes since kernel is specialized per multi-B config cache_key = ( sf_vec_size, tile_size, @@ -208,6 +212,7 @@ def _get_compiled_finalize_kernel( cluster_shape_mn, raster_along_m, enable_pdl, + b_tensor_l_sizes, ) if cache_key not in _finalize_kernel_cache: @@ -219,17 +224,17 @@ def _get_compiled_finalize_kernel( use_blkred=True, raster_along_m=raster_along_m, enable_pdl=enable_pdl, + b_tensor_l_sizes=b_tensor_l_sizes, ) # Compile with runtime parameters - they can vary across calls # Order must match wrapper signature: - # (a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, alpha_ptr, + # (a_ptr, b_ptr_tuple, a_sf_ptr, b_sf_ptr_tuple, c_ptr, alpha_ptr_tuple, # tile_idx_to_group_idx_ptr, tile_idx_to_mn_limit_ptr, # permuted_idx_to_expanded_idx_ptr, num_non_exiting_tiles_ptr, - # token_final_scales_ptr, m, n, k, l, num_tokens, top_k, + # token_final_scales_ptr, m, n, k, num_tokens, top_k, # tile_size, scaling_vector_size, max_active_clusters, stream) - compiled_gemm = cute.compile( - gemm.wrapper, + compile_args = [ a_ptr, b_ptr, a_sf_ptr, @@ -244,9 +249,13 @@ def _get_compiled_finalize_kernel( permuted_m, n, k, - num_experts, seq_len, topk, + ] + + compiled_gemm = cute.compile( + gemm.wrapper, + *compile_args, tile_size=tile_size, scaling_vector_size=sf_vec_size, max_active_clusters=max_active_clusters, @@ -260,10 +269,10 @@ def _get_compiled_finalize_kernel( def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( a: torch.Tensor, - b: torch.Tensor, + b: Union[torch.Tensor, List[torch.Tensor]], a_scale: torch.Tensor, - b_scale: torch.Tensor, - alpha: torch.Tensor, + b_scale: Union[torch.Tensor, List[torch.Tensor]], + alpha: Union[torch.Tensor, List[torch.Tensor]], tile_idx_to_expert_idx: torch.Tensor, num_non_exiting_tiles: torch.Tensor, tile_idx_to_mn_limit: torch.Tensor, @@ -366,14 +375,19 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( ... token_final_scales=final_scales, ... ) # out shape: (seq_len, hidden_dim) """ + # Normalize to lists for multi-B support + b_list = [b] if isinstance(b, torch.Tensor) else b + b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else b_scale + alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else alpha + # Validate inputs assert a.device.type == "cuda", "Input tensors must be on CUDA device" - assert b.device.type == "cuda", "Input tensors must be on CUDA device" + assert b_list[0].device.type == "cuda", "Input tensors must be on CUDA device" # Get dimensions permuted_m = a.shape[0] - num_experts = b.shape[0] - n = b.shape[1] + num_experts = sum(bi.size(0) for bi in b_list) + n = b_list[0].shape[1] k = a.shape[1] if ab_dtype == "float4_e2m1fn": k = k * 2 # FP4 is packed 2 elements per byte @@ -439,24 +453,38 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( # Get tile_size from mma_tiler_mn tile_size = mma_tiler_mn[0] + # Compute b_tensor_l_sizes for multi-B support + b_tensor_l_sizes = tuple(bi.size(0) for bi in b_list) + # Create raw pointers (TRT-LLM style) - allows same compiled kernel for different sizes a_ptr = make_ptr( ab_dtype_cutlass, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) - b_ptr = make_ptr( - ab_dtype_cutlass, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 - ) a_sf_ptr = make_ptr( sf_dtype_cutlass, a_scale.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 ) - b_sf_ptr = make_ptr( - sf_dtype_cutlass, b_scale.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 - ) c_ptr = make_ptr( out_dtype_cutlass, out.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 ) - alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(), cute.AddressSpace.gmem) + # Create pointer tuples for B tensors + b_ptr = tuple( + make_ptr( + ab_dtype_cutlass, bi.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + for bi in b_list + ) + b_sf_ptr = tuple( + make_ptr( + sf_dtype_cutlass, bsi.data_ptr(), cute.AddressSpace.gmem, assumed_align=16 + ) + for bsi in b_scale_list + ) + alpha_ptr = tuple( + make_ptr(cutlass.Float32, ai.data_ptr(), cute.AddressSpace.gmem) + for ai in alpha_list + ) + tile_idx_ptr = make_ptr( cutlass.Int32, tile_idx_to_expert_idx.data_ptr(), cute.AddressSpace.gmem ) @@ -488,16 +516,13 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( torch_stream = torch.cuda.current_stream() stream = cuda.CUstream(torch_stream.cuda_stream) - # Get or compile the kernel (cached by tactic parameters only) + # Get or compile the kernel compiled_gemm = _get_compiled_finalize_kernel( - # Runtime parameters (problem dimensions) seq_len=seq_len, permuted_m=permuted_m, n=n, k=k, - num_experts=num_experts, topk=topk, - # Tensor pointers (order must match wrapper signature) a_ptr=a_ptr, b_ptr=b_ptr, a_sf_ptr=a_sf_ptr, @@ -511,21 +536,17 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( token_scales_ptr=token_scales_ptr, max_active_clusters=max_active_clusters, stream=stream, - # Tactic parameters (compile-time, cached) sf_vec_size=sf_vec_size, tile_size=tile_size, mma_tiler_mn=mma_tiler_mn, cluster_shape_mn=cluster_shape_mn, raster_along_m=raster_along_m, enable_pdl=enable_pdl, + b_tensor_l_sizes=b_tensor_l_sizes, ) - # Execute kernel with runtime parameters - # Order must match wrapper signature: - # (a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, alpha_ptr, tile_idx_ptr, - # mn_limit_ptr, permuted_idx_ptr, num_tiles_ptr, token_scales_ptr, - # m, n, k, l, num_tokens, top_k, stream) - compiled_gemm( + # Execute kernel + exec_args = [ a_ptr, b_ptr, a_sf_ptr, @@ -540,10 +561,9 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( permuted_m, n, k, - num_experts, seq_len, topk, - stream=stream, - ) + ] + compiled_gemm(*exec_args, stream=stream) return out diff --git a/flashinfer/fused_moe/cute_dsl/fused_moe.py b/flashinfer/fused_moe/cute_dsl/fused_moe.py index 8ed6a8ba72..0e3815c498 100644 --- a/flashinfer/fused_moe/cute_dsl/fused_moe.py +++ b/flashinfer/fused_moe/cute_dsl/fused_moe.py @@ -49,7 +49,7 @@ >>> g.replay() """ -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -104,16 +104,16 @@ def _moe_core_impl( # Routing token_selected_experts: torch.Tensor, token_final_scales: torch.Tensor, - # GEMM1 weights - w1_weight: torch.Tensor, - w1_weight_sf: torch.Tensor, - w1_alpha: torch.Tensor, + # GEMM1 weights (single tensor or list for multi-B/DWDP) + w1_weight: Union[torch.Tensor, List[torch.Tensor]], + w1_weight_sf: Union[torch.Tensor, List[torch.Tensor]], + w1_alpha: Union[torch.Tensor, List[torch.Tensor]], # GEMM2 intermediate scale fc2_input_scale: torch.Tensor, - # GEMM2 weights - w2_weight: torch.Tensor, - w2_weight_sf: torch.Tensor, - w2_alpha: torch.Tensor, + # GEMM2 weights (single tensor or list for multi-B/DWDP) + w2_weight: Union[torch.Tensor, List[torch.Tensor]], + w2_weight_sf: Union[torch.Tensor, List[torch.Tensor]], + w2_alpha: Union[torch.Tensor, List[torch.Tensor]], # MoE config num_experts: int, top_k: int, @@ -182,7 +182,8 @@ def _moe_core_impl( Output tensor [num_tokens, hidden_size]. """ num_tokens = token_selected_experts.size(0) - hidden_size = w2_weight.size(1) + _w2 = w2_weight[0] if isinstance(w2_weight, list) else w2_weight + hidden_size = _w2.size(1) # Allocate output if not provided. The caller (wrapper or functional # API) should pass a [:num_tokens] slice of the pre-allocated buffer @@ -462,13 +463,13 @@ def _forward_with_tactic( x_sf: torch.Tensor, token_selected_experts: torch.Tensor, token_final_scales: torch.Tensor, - w1_weight: torch.Tensor, - w1_weight_sf: torch.Tensor, - w1_alpha: torch.Tensor, + w1_weight: Union[torch.Tensor, List[torch.Tensor]], + w1_weight_sf: Union[torch.Tensor, List[torch.Tensor]], + w1_alpha: Union[torch.Tensor, List[torch.Tensor]], fc2_input_scale: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_sf: torch.Tensor, - w2_alpha: torch.Tensor, + w2_weight: Union[torch.Tensor, List[torch.Tensor]], + w2_weight_sf: Union[torch.Tensor, List[torch.Tensor]], + w2_alpha: Union[torch.Tensor, List[torch.Tensor]], num_experts: int, top_k: int, num_local_experts: int, @@ -538,13 +539,13 @@ def run( x_sf: torch.Tensor, token_selected_experts: torch.Tensor, token_final_scales: torch.Tensor, - w1_weight: torch.Tensor, - w1_weight_sf: torch.Tensor, - w1_alpha: torch.Tensor, + w1_weight: Union[torch.Tensor, List[torch.Tensor]], + w1_weight_sf: Union[torch.Tensor, List[torch.Tensor]], + w1_alpha: Union[torch.Tensor, List[torch.Tensor]], fc2_input_scale: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_sf: torch.Tensor, - w2_alpha: torch.Tensor, + w2_weight: Union[torch.Tensor, List[torch.Tensor]], + w2_weight_sf: Union[torch.Tensor, List[torch.Tensor]], + w2_alpha: Union[torch.Tensor, List[torch.Tensor]], tactic: Optional[Tuple] = None, ) -> torch.Tensor: """Run MoE computation. @@ -557,13 +558,15 @@ def run( x_sf: Scale factors for x. token_selected_experts: Expert assignments [num_tokens, top_k]. token_final_scales: Routing weights [num_tokens, top_k]. - w1_weight: GEMM1 weights (gate + up fused). - w1_weight_sf: Scale factors for w1_weight. - w1_alpha: Per-expert global scale for GEMM1. + w1_weight: GEMM1 weights (gate + up fused). Single tensor OR list of + tensors for multi-B / DWDP (up to 4 tensors); when a list, the + expert dim is split across entries. + w1_weight_sf: Scale factors for w1_weight (same Tensor-or-list convention). + w1_alpha: Per-expert global scale for GEMM1 (same Tensor-or-list convention). fc2_input_scale: Global scale for GEMM2 input quantization. - w2_weight: GEMM2 weights (down projection). - w2_weight_sf: Scale factors for w2_weight. - w2_alpha: Per-expert global scale for GEMM2. + w2_weight: GEMM2 weights (down projection) (same Tensor-or-list convention). + w2_weight_sf: Scale factors for w2_weight (same Tensor-or-list convention). + w2_alpha: Per-expert global scale for GEMM2 (same Tensor-or-list convention). tactic: Tactic tuple or None for auto-selection. Returns: @@ -634,13 +637,13 @@ def _cute_dsl_fused_moe_nvfp4_impl( x_sf: torch.Tensor, token_selected_experts: torch.Tensor, token_final_scales: torch.Tensor, - w1_weight: torch.Tensor, - w1_weight_sf: torch.Tensor, - w1_alpha: torch.Tensor, + w1_weight: Union[torch.Tensor, List[torch.Tensor]], + w1_weight_sf: Union[torch.Tensor, List[torch.Tensor]], + w1_alpha: Union[torch.Tensor, List[torch.Tensor]], fc2_input_scale: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_sf: torch.Tensor, - w2_alpha: torch.Tensor, + w2_weight: Union[torch.Tensor, List[torch.Tensor]], + w2_weight_sf: Union[torch.Tensor, List[torch.Tensor]], + w2_alpha: Union[torch.Tensor, List[torch.Tensor]], num_experts: int, top_k: int, num_local_experts: int, @@ -693,13 +696,13 @@ def cute_dsl_fused_moe_nvfp4( x_sf: torch.Tensor, token_selected_experts: torch.Tensor, token_final_scales: torch.Tensor, - w1_weight: torch.Tensor, - w1_weight_sf: torch.Tensor, - w1_alpha: torch.Tensor, + w1_weight: Union[torch.Tensor, List[torch.Tensor]], + w1_weight_sf: Union[torch.Tensor, List[torch.Tensor]], + w1_alpha: Union[torch.Tensor, List[torch.Tensor]], fc2_input_scale: torch.Tensor, - w2_weight: torch.Tensor, - w2_weight_sf: torch.Tensor, - w2_alpha: torch.Tensor, + w2_weight: Union[torch.Tensor, List[torch.Tensor]], + w2_weight_sf: Union[torch.Tensor, List[torch.Tensor]], + w2_alpha: Union[torch.Tensor, List[torch.Tensor]], num_experts: int, top_k: int, num_local_experts: Optional[int] = None, @@ -727,13 +730,16 @@ def cute_dsl_fused_moe_nvfp4( x_sf: Scale factors for x. token_selected_experts: Expert assignments [num_tokens, top_k]. token_final_scales: Routing weights [num_tokens, top_k]. - w1_weight: GEMM1 weights (gate + up fused). - w1_weight_sf: Scale factors for w1_weight. - w1_alpha: Per-expert global scale for GEMM1. + w1_weight: GEMM1 weights (gate + up fused). Single tensor OR list of + tensors for multi-B / DWDP (Distributed Weight Data Parallelism), + up to 4 tensors. When a list, the expert dimension is split across + entries (sum of shape[0] = total local experts). + w1_weight_sf: Scale factors for w1_weight. Same Tensor-or-list convention. + w1_alpha: Per-expert global scale for GEMM1. Same Tensor-or-list convention. fc2_input_scale: Global scale for GEMM2 input quantization. - w2_weight: GEMM2 weights (down projection). - w2_weight_sf: Scale factors for w2_weight. - w2_alpha: Per-expert global scale for GEMM2. + w2_weight: GEMM2 weights (down projection). Same Tensor-or-list convention. + w2_weight_sf: Scale factors for w2_weight. Same Tensor-or-list convention. + w2_alpha: Per-expert global scale for GEMM2. Same Tensor-or-list convention. num_experts: Total number of experts. top_k: Number of experts per token. num_local_experts: Local experts for EP. Default: num_experts. @@ -750,7 +756,8 @@ def cute_dsl_fused_moe_nvfp4( num_local_experts = num_experts num_tokens = token_selected_experts.size(0) - hidden_size = w2_weight.size(1) + _w2 = w2_weight[0] if isinstance(w2_weight, list) else w2_weight + hidden_size = _w2.size(1) if moe_output is None: moe_output = torch.empty( diff --git a/flashinfer/fused_moe/cute_dsl/tuner.py b/flashinfer/fused_moe/cute_dsl/tuner.py index 0cc8628ed9..38c92e47da 100644 --- a/flashinfer/fused_moe/cute_dsl/tuner.py +++ b/flashinfer/fused_moe/cute_dsl/tuner.py @@ -348,15 +348,22 @@ def get_valid_tactics( # type: ignore[override] # Extract problem dimensions from inputs: # 0: x (num_tokens, hidden_size//2) - # 4: w1_weight (num_local_experts, 2*intermediate_size, hidden_size//2) - # 8: w2_weight (num_local_experts, hidden_size, intermediate_size//2) + # 4: w1_weight — Tensor OR List[Tensor] (multi-B/DWDP): + # (num_local_experts, 2*intermediate_size, hidden_size//2) + # 8: w2_weight — Tensor OR List[Tensor]: + # (num_local_experts, hidden_size, intermediate_size//2) + # For multi-B, num_local_experts is the sum across all B tensors. x = inputs[0] w1_weight = inputs[4] num_tokens = x.shape[0] hidden_size = x.shape[1] * 2 # FP4 packed - num_local_experts = w1_weight.shape[0] - intermediate_size = w1_weight.shape[1] // 2 # gate+up fused + if isinstance(w1_weight, (list, tuple)): + num_local_experts = sum(t.shape[0] for t in w1_weight) + intermediate_size = w1_weight[0].shape[1] // 2 # gate+up fused + else: + num_local_experts = w1_weight.shape[0] + intermediate_size = w1_weight.shape[1] // 2 # gate+up fused # Fixed dtypes/layouts for NVFP4 MoE ab_dtype = cutlass.Float4E2M1FN diff --git a/tests/moe/test_cute_dsl_fused_moe.py b/tests/moe/test_cute_dsl_fused_moe.py index 55d2497fba..e764cac3b2 100644 --- a/tests/moe/test_cute_dsl_fused_moe.py +++ b/tests/moe/test_cute_dsl_fused_moe.py @@ -1053,5 +1053,412 @@ def test_all_tactics_accuracy( ) +# ============================================================================= +# Test Class: Multi-B Tensor (DWDP) Support +# ============================================================================= + + +@cute_dsl_available +@sm100_required +class TestMultiBTensor: + """Tests for multi-B tensor (DWDP) support. + + These tests verify that splitting expert weights into multiple tensors + produces the same results as a single stacked tensor. + """ + + @pytest.mark.parametrize( + "num_tokens, hidden_size, intermediate_size, num_experts, top_k", + [ + (128, 512, 512, 8, 2), + (256, 512, 512, 8, 4), + ], + ) + def test_multi_b_fused_moe_2_tensors( + self, num_tokens, hidden_size, intermediate_size, num_experts, top_k + ): + """Test end-to-end MoE with weights split into 2 tensors.""" + from flashinfer.cute_dsl import cute_dsl_fused_moe_nvfp4 + + tensors = create_moe_tensors( + num_tokens, hidden_size, intermediate_size, num_experts, num_experts, top_k + ) + + # Run with single tensor (baseline) + single_output = cute_dsl_fused_moe_nvfp4( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=tensors["w1_weight"], + w1_weight_sf=tensors["w1_weight_sf"], + w1_alpha=tensors["w1_alpha"], + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=tensors["w2_weight"], + w2_weight_sf=tensors["w2_weight_sf"], + w2_alpha=tensors["w2_alpha"], + num_experts=num_experts, + top_k=top_k, + ) + + # Split weights into 2 tensors along expert dimension. + # - weight: (num_experts, ...) — split on dim 0 (outermost in row-major) + # - weight_sf: MMA layout (32, 4, m_tiles, 4, k_tiles, num_experts). + # Physical storage from convert_sf_to_mma_layout is + # (num_experts, m_tiles, k_tiles, 32, 4, 4), i.e. num_experts is the + # outermost physical dim (largest stride). Slicing [..., :split] yields + # a strided view whose data is a contiguous prefix of the original + # memory — the kernel takes the raw .data_ptr() and imposes its own + # layout, so we must NOT call .contiguous() (that would force a + # row-major copy making num_experts the innermost dim, breaking the + # kernel's expected stride pattern). + # - alpha: (num_experts,) — split on dim 0 + split = num_experts // 2 + w1_list = [tensors["w1_weight"][:split], tensors["w1_weight"][split:]] + w1_sf_list = [ + tensors["w1_weight_sf"][..., :split], + tensors["w1_weight_sf"][..., split:], + ] + w1_alpha_list = [tensors["w1_alpha"][:split], tensors["w1_alpha"][split:]] + w2_list = [tensors["w2_weight"][:split], tensors["w2_weight"][split:]] + w2_sf_list = [ + tensors["w2_weight_sf"][..., :split], + tensors["w2_weight_sf"][..., split:], + ] + w2_alpha_list = [tensors["w2_alpha"][:split], tensors["w2_alpha"][split:]] + + # Run with multi-B (2 tensors) + multi_output = cute_dsl_fused_moe_nvfp4( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=w1_list, + w1_weight_sf=w1_sf_list, + w1_alpha=w1_alpha_list, + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=w2_list, + w2_weight_sf=w2_sf_list, + w2_alpha=w2_alpha_list, + num_experts=num_experts, + top_k=top_k, + ) + + # Compare outputs + passed, percent_within, atol = check_accuracy(multi_output, single_output) + assert passed, ( + f"Multi-B (2 tensors) output mismatch: " + f"{percent_within * 100:.2f}% within tolerance (atol={atol:.4f})" + ) + + @pytest.mark.parametrize( + "num_tokens, hidden_size, intermediate_size, num_experts, top_k", + [ + (128, 512, 512, 8, 2), + ], + ) + def test_multi_b_uneven_split( + self, num_tokens, hidden_size, intermediate_size, num_experts, top_k + ): + """Test multi-B with uneven expert split (e.g., 3+5 for 8 experts).""" + from flashinfer.cute_dsl import cute_dsl_fused_moe_nvfp4 + + tensors = create_moe_tensors( + num_tokens, hidden_size, intermediate_size, num_experts, num_experts, top_k + ) + + # Run with single tensor (baseline) + single_output = cute_dsl_fused_moe_nvfp4( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=tensors["w1_weight"], + w1_weight_sf=tensors["w1_weight_sf"], + w1_alpha=tensors["w1_alpha"], + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=tensors["w2_weight"], + w2_weight_sf=tensors["w2_weight_sf"], + w2_alpha=tensors["w2_alpha"], + num_experts=num_experts, + top_k=top_k, + ) + + # Split weights unevenly: 3 + 5. SF is sliced without .contiguous() to + # preserve the MMA-layout stride pattern (see 2-tensor test for details). + s = 3 + w1_list = [tensors["w1_weight"][:s], tensors["w1_weight"][s:]] + w1_sf_list = [ + tensors["w1_weight_sf"][..., :s], + tensors["w1_weight_sf"][..., s:], + ] + w1_alpha_list = [tensors["w1_alpha"][:s], tensors["w1_alpha"][s:]] + w2_list = [tensors["w2_weight"][:s], tensors["w2_weight"][s:]] + w2_sf_list = [ + tensors["w2_weight_sf"][..., :s], + tensors["w2_weight_sf"][..., s:], + ] + w2_alpha_list = [tensors["w2_alpha"][:s], tensors["w2_alpha"][s:]] + + # Run with multi-B (uneven split) + multi_output = cute_dsl_fused_moe_nvfp4( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=w1_list, + w1_weight_sf=w1_sf_list, + w1_alpha=w1_alpha_list, + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=w2_list, + w2_weight_sf=w2_sf_list, + w2_alpha=w2_alpha_list, + num_experts=num_experts, + top_k=top_k, + ) + + # Compare outputs + passed, percent_within, atol = check_accuracy(multi_output, single_output) + assert passed, ( + f"Multi-B (uneven split 3+5) output mismatch: " + f"{percent_within * 100:.2f}% within tolerance (atol={atol:.4f})" + ) + + @pytest.mark.parametrize( + "num_tokens, hidden_size, intermediate_size, num_experts, top_k", + [ + (128, 512, 512, 8, 2), + ], + ) + def test_single_b_list_backward_compat( + self, num_tokens, hidden_size, intermediate_size, num_experts, top_k + ): + """Test that passing a single tensor wrapped in a list produces identical results.""" + from flashinfer.cute_dsl import cute_dsl_fused_moe_nvfp4 + + tensors = create_moe_tensors( + num_tokens, hidden_size, intermediate_size, num_experts, num_experts, top_k + ) + + # Run with single tensor + single_output = cute_dsl_fused_moe_nvfp4( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=tensors["w1_weight"], + w1_weight_sf=tensors["w1_weight_sf"], + w1_alpha=tensors["w1_alpha"], + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=tensors["w2_weight"], + w2_weight_sf=tensors["w2_weight_sf"], + w2_alpha=tensors["w2_alpha"], + num_experts=num_experts, + top_k=top_k, + ) + + # Run with single tensor wrapped in a list (should use same code path) + list_output = cute_dsl_fused_moe_nvfp4( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=[tensors["w1_weight"]], + w1_weight_sf=[tensors["w1_weight_sf"]], + w1_alpha=[tensors["w1_alpha"]], + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=[tensors["w2_weight"]], + w2_weight_sf=[tensors["w2_weight_sf"]], + w2_alpha=[tensors["w2_alpha"]], + num_experts=num_experts, + top_k=top_k, + ) + + # Compare outputs - should be identical + passed, percent_within, atol = check_accuracy(list_output, single_output) + assert passed, ( + f"Single-B list backward compat mismatch: " + f"{percent_within * 100:.2f}% within tolerance (atol={atol:.4f})" + ) + + @pytest.mark.parametrize("num_b_tensors", [3, 4]) + def test_multi_b_n_tensors(self, num_b_tensors): + """Test multi-B with 3 or 4 tensors (exercises all const_expr branches). + + The kernel has MAX_B_TENSORS=4 and specialized dispatch code for each + count (1/2/3/4), so covering 3 and 4 here exercises the remaining paths. + """ + from flashinfer.cute_dsl import cute_dsl_fused_moe_nvfp4 + + num_tokens, hidden_size, intermediate_size = 128, 512, 512 + # 12 experts for 3-way split (4+4+4), 8 experts for 4-way (2+2+2+2) + num_experts = 12 if num_b_tensors == 3 else 8 + top_k = 2 + + tensors = create_moe_tensors( + num_tokens, hidden_size, intermediate_size, num_experts, num_experts, top_k + ) + + # Baseline: single tensor + single_output = cute_dsl_fused_moe_nvfp4( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=tensors["w1_weight"], + w1_weight_sf=tensors["w1_weight_sf"], + w1_alpha=tensors["w1_alpha"], + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=tensors["w2_weight"], + w2_weight_sf=tensors["w2_weight_sf"], + w2_alpha=tensors["w2_alpha"], + num_experts=num_experts, + top_k=top_k, + ) + + # Even split into num_b_tensors pieces. SF sliced without .contiguous() + # to preserve the MMA-layout stride pattern. + assert num_experts % num_b_tensors == 0 + piece = num_experts // num_b_tensors + offsets = [i * piece for i in range(num_b_tensors + 1)] + + def split(t, on_last_dim): + return [ + (t[..., offsets[i]:offsets[i + 1]] if on_last_dim + else t[offsets[i]:offsets[i + 1]]) + for i in range(num_b_tensors) + ] + + multi_output = cute_dsl_fused_moe_nvfp4( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=split(tensors["w1_weight"], on_last_dim=False), + w1_weight_sf=split(tensors["w1_weight_sf"], on_last_dim=True), + w1_alpha=split(tensors["w1_alpha"], on_last_dim=False), + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=split(tensors["w2_weight"], on_last_dim=False), + w2_weight_sf=split(tensors["w2_weight_sf"], on_last_dim=True), + w2_alpha=split(tensors["w2_alpha"], on_last_dim=False), + num_experts=num_experts, + top_k=top_k, + ) + + passed, percent_within, atol = check_accuracy(multi_output, single_output) + assert passed, ( + f"Multi-B ({num_b_tensors} tensors) output mismatch: " + f"{percent_within * 100:.2f}% within tolerance (atol={atol:.4f})" + ) + + def test_multi_b_wrapper_api(self): + """Test CuteDslMoEWrapper with multi-B input.""" + from flashinfer.cute_dsl import CuteDslMoEWrapper + + num_tokens, hidden_size, intermediate_size = 128, 512, 512 + num_experts, top_k = 8, 2 + tensors = create_moe_tensors( + num_tokens, hidden_size, intermediate_size, num_experts, num_experts, top_k + ) + + moe = CuteDslMoEWrapper( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + use_cuda_graph=False, + ) + + # Baseline + single_output = moe.run( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=tensors["w1_weight"], + w1_weight_sf=tensors["w1_weight_sf"], + w1_alpha=tensors["w1_alpha"], + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=tensors["w2_weight"], + w2_weight_sf=tensors["w2_weight_sf"], + w2_alpha=tensors["w2_alpha"], + ) + + split = num_experts // 2 + multi_output = moe.run( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=[tensors["w1_weight"][:split], tensors["w1_weight"][split:]], + w1_weight_sf=[ + tensors["w1_weight_sf"][..., :split], + tensors["w1_weight_sf"][..., split:], + ], + w1_alpha=[tensors["w1_alpha"][:split], tensors["w1_alpha"][split:]], + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=[tensors["w2_weight"][:split], tensors["w2_weight"][split:]], + w2_weight_sf=[ + tensors["w2_weight_sf"][..., :split], + tensors["w2_weight_sf"][..., split:], + ], + w2_alpha=[tensors["w2_alpha"][:split], tensors["w2_alpha"][split:]], + ) + + passed, percent_within, atol = check_accuracy(multi_output, single_output) + assert passed, ( + f"Wrapper API multi-B mismatch: " + f"{percent_within * 100:.2f}% within tolerance (atol={atol:.4f})" + ) + + def test_multi_b_with_autotune(self): + """Test multi-B inside autotune() context. + + Regression test for tuner.get_valid_tactics: it must handle list-typed + weights without crashing on .shape[0]. + """ + from flashinfer.autotuner import autotune + from flashinfer.cute_dsl import cute_dsl_fused_moe_nvfp4 + + num_tokens, hidden_size, intermediate_size = 128, 512, 512 + num_experts, top_k = 8, 2 + tensors = create_moe_tensors( + num_tokens, hidden_size, intermediate_size, num_experts, num_experts, top_k + ) + + split = num_experts // 2 + w1 = [tensors["w1_weight"][:split], tensors["w1_weight"][split:]] + w1_sf = [ + tensors["w1_weight_sf"][..., :split], + tensors["w1_weight_sf"][..., split:], + ] + w1_a = [tensors["w1_alpha"][:split], tensors["w1_alpha"][split:]] + w2 = [tensors["w2_weight"][:split], tensors["w2_weight"][split:]] + w2_sf = [ + tensors["w2_weight_sf"][..., :split], + tensors["w2_weight_sf"][..., split:], + ] + w2_a = [tensors["w2_alpha"][:split], tensors["w2_alpha"][split:]] + + # Should not crash in tuning mode (i.e., get_valid_tactics handles lists). + with autotune(True): + output = cute_dsl_fused_moe_nvfp4( + x=tensors["x"], + x_sf=tensors["x_sf"], + token_selected_experts=tensors["token_selected_experts"], + token_final_scales=tensors["token_final_scales"], + w1_weight=w1, + w1_weight_sf=w1_sf, + w1_alpha=w1_a, + fc2_input_scale=tensors["fc2_input_scale"], + w2_weight=w2, + w2_weight_sf=w2_sf, + w2_alpha=w2_a, + num_experts=num_experts, + top_k=top_k, + ) + assert output.shape == (num_tokens, hidden_size) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 9edbe9be6d643f5fb7fa54ad98de45b5f9872c8b Mon Sep 17 00:00:00 2001 From: yhyang201 Date: Tue, 14 Apr 2026 10:30:40 +0000 Subject: [PATCH 2/3] fix: address PR review comments for DWDP multi-B support - Safe alpha indexing with pre-initialization before const_expr branches - NoneType guard: raise ValueError when b_tensor_l_sizes=None - Input validation for multi-B weight lists (empty, max 4, length match) - Fix test imports to use top-level flashinfer module Co-Authored-By: Claude Opus 4.6 (1M context) --- ...guous_gather_grouped_gemm_swiglu_fusion.py | 75 ++++++++++--------- ...contiguous_grouped_gemm_finalize_fusion.py | 75 ++++++++++--------- ...guous_gather_grouped_gemm_swiglu_fusion.py | 14 +++- ...contiguous_grouped_gemm_finalize_fusion.py | 14 +++- tests/moe/test_cute_dsl_fused_moe.py | 12 +-- 5 files changed, 110 insertions(+), 80 deletions(-) diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index f2edf7f968..b8a2447087 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -533,24 +533,25 @@ def __init__( self.vectorized_f32 = vectorized_f32 # Multi-B tensor configuration + # b_tensor_l_sizes is required — the Python wrapper layer always provides it + # as a tuple (even for single-B, e.g. (256,)). if b_tensor_l_sizes is None: - self.num_b_tensors = 1 - self.b_tensor_l_sizes = None - # Offsets padded for safe indexing in kernel - self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS - else: - assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( - f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" + raise ValueError( + "b_tensor_l_sizes is required. Pass a tuple with the number of " + "experts per tensor, e.g. (num_experts,) for single-B." ) - self.num_b_tensors = len(b_tensor_l_sizes) - self.b_tensor_l_sizes = b_tensor_l_sizes - offsets = [0] - for l_size in b_tensor_l_sizes: - offsets.append(offsets[-1] + l_size) - # Pad to MAX_B_TENSORS + 1 for safe indexing - while len(offsets) < self.MAX_B_TENSORS + 1: - offsets.append(2**30) - self.b_tensor_l_offsets = tuple(offsets) + assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( + f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" + ) + self.num_b_tensors = len(b_tensor_l_sizes) + self.b_tensor_l_sizes = b_tensor_l_sizes + offsets = [0] + for l_size in b_tensor_l_sizes: + offsets.append(offsets[-1] + l_size) + # Pad to MAX_B_TENSORS + 1 for safe indexing + while len(offsets) < self.MAX_B_TENSORS + 1: + offsets.append(2**30) + self.b_tensor_l_offsets = tuple(offsets) def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -2973,44 +2974,50 @@ def kernel( # expert_idx = mma_tile_coord_mnl[2] - # Select alpha from correct tensor based on expert_idx - alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]] + # Select alpha from correct tensor based on expert_idx. + # Pre-initialize alpha_val for CuTe DSL type tracking — the DSL + # requires variables to have a consistent type before entering + # dynamic (runtime) if branches. Index 0 is always in-bounds. + alpha_val = alpha_tuple[0][0] if cutlass.const_expr(self.num_b_tensors == 1): - pass # Already initialized above + alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]] elif cutlass.const_expr(self.num_b_tensors == 2): - if expert_idx >= self.b_tensor_l_offsets[1]: + if expert_idx < self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[0][ + expert_idx - self.b_tensor_l_offsets[0] + ] + else: alpha_val = alpha_tuple[1][ expert_idx - self.b_tensor_l_offsets[1] ] elif cutlass.const_expr(self.num_b_tensors == 3): - if ( - expert_idx >= self.b_tensor_l_offsets[1] - and expert_idx < self.b_tensor_l_offsets[2] - ): + if expert_idx < self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[0][ + expert_idx - self.b_tensor_l_offsets[0] + ] + elif expert_idx < self.b_tensor_l_offsets[2]: alpha_val = alpha_tuple[1][ expert_idx - self.b_tensor_l_offsets[1] ] - elif expert_idx >= self.b_tensor_l_offsets[2]: + else: alpha_val = alpha_tuple[2][ expert_idx - self.b_tensor_l_offsets[2] ] else: # 4 B tensors - if ( - expert_idx >= self.b_tensor_l_offsets[1] - and expert_idx < self.b_tensor_l_offsets[2] - ): + if expert_idx < self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[0][ + expert_idx - self.b_tensor_l_offsets[0] + ] + elif expert_idx < self.b_tensor_l_offsets[2]: alpha_val = alpha_tuple[1][ expert_idx - self.b_tensor_l_offsets[1] ] - elif ( - expert_idx >= self.b_tensor_l_offsets[2] - and expert_idx < self.b_tensor_l_offsets[3] - ): + elif expert_idx < self.b_tensor_l_offsets[3]: alpha_val = alpha_tuple[2][ expert_idx - self.b_tensor_l_offsets[2] ] - elif expert_idx >= self.b_tensor_l_offsets[3]: + else: alpha_val = alpha_tuple[3][ expert_idx - self.b_tensor_l_offsets[3] ] diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index df9c78dcbb..b9588305c0 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -453,24 +453,25 @@ def __init__( self.tmem_final_offset = 384 # Multi-B tensor configuration + # b_tensor_l_sizes is required — the Python wrapper layer always provides it + # as a tuple (even for single-B, e.g. (256,)). if b_tensor_l_sizes is None: - self.num_b_tensors = 1 - self.b_tensor_l_sizes = None - # Offsets padded for safe indexing in kernel - self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS - else: - assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( - f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" + raise ValueError( + "b_tensor_l_sizes is required. Pass a tuple with the number of " + "experts per tensor, e.g. (num_experts,) for single-B." ) - self.num_b_tensors = len(b_tensor_l_sizes) - self.b_tensor_l_sizes = b_tensor_l_sizes - offsets = [0] - for l_size in b_tensor_l_sizes: - offsets.append(offsets[-1] + l_size) - # Pad to MAX_B_TENSORS + 1 for safe indexing - while len(offsets) < self.MAX_B_TENSORS + 1: - offsets.append(2**30) - self.b_tensor_l_offsets = tuple(offsets) + assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( + f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" + ) + self.num_b_tensors = len(b_tensor_l_sizes) + self.b_tensor_l_sizes = b_tensor_l_sizes + offsets = [0] + for l_size in b_tensor_l_sizes: + offsets.append(offsets[-1] + l_size) + # Pad to MAX_B_TENSORS + 1 for safe indexing + while len(offsets) < self.MAX_B_TENSORS + 1: + offsets.append(2**30) + self.b_tensor_l_offsets = tuple(offsets) def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -2388,44 +2389,50 @@ def kernel( expert_idx = mma_tile_coord_mnl[2] - # Select alpha from correct tensor based on expert_idx - alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]] + # Select alpha from correct tensor based on expert_idx. + # Pre-initialize alpha_val for CuTe DSL type tracking — the DSL + # requires variables to have a consistent type before entering + # dynamic (runtime) if branches. Index 0 is always in-bounds. + alpha_val = alpha_tuple[0][0] if cutlass.const_expr(self.num_b_tensors == 1): - pass # Already initialized above + alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]] elif cutlass.const_expr(self.num_b_tensors == 2): - if expert_idx >= self.b_tensor_l_offsets[1]: + if expert_idx < self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[0][ + expert_idx - self.b_tensor_l_offsets[0] + ] + else: alpha_val = alpha_tuple[1][ expert_idx - self.b_tensor_l_offsets[1] ] elif cutlass.const_expr(self.num_b_tensors == 3): - if ( - expert_idx >= self.b_tensor_l_offsets[1] - and expert_idx < self.b_tensor_l_offsets[2] - ): + if expert_idx < self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[0][ + expert_idx - self.b_tensor_l_offsets[0] + ] + elif expert_idx < self.b_tensor_l_offsets[2]: alpha_val = alpha_tuple[1][ expert_idx - self.b_tensor_l_offsets[1] ] - elif expert_idx >= self.b_tensor_l_offsets[2]: + else: alpha_val = alpha_tuple[2][ expert_idx - self.b_tensor_l_offsets[2] ] else: # 4 B tensors - if ( - expert_idx >= self.b_tensor_l_offsets[1] - and expert_idx < self.b_tensor_l_offsets[2] - ): + if expert_idx < self.b_tensor_l_offsets[1]: + alpha_val = alpha_tuple[0][ + expert_idx - self.b_tensor_l_offsets[0] + ] + elif expert_idx < self.b_tensor_l_offsets[2]: alpha_val = alpha_tuple[1][ expert_idx - self.b_tensor_l_offsets[1] ] - elif ( - expert_idx >= self.b_tensor_l_offsets[2] - and expert_idx < self.b_tensor_l_offsets[3] - ): + elif expert_idx < self.b_tensor_l_offsets[3]: alpha_val = alpha_tuple[2][ expert_idx - self.b_tensor_l_offsets[2] ] - elif expert_idx >= self.b_tensor_l_offsets[3]: + else: alpha_val = alpha_tuple[3][ expert_idx - self.b_tensor_l_offsets[3] ] diff --git a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index 6b27c0b763..54dda11f2b 100644 --- a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -416,9 +416,17 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( ... ) # out shape: (valid_m, intermediate_dim) """ # Normalize to lists for multi-B support - b_list = [b] if isinstance(b, torch.Tensor) else b - b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else b_scale - alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else alpha + b_list = [b] if isinstance(b, torch.Tensor) else list(b) + b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else list(b_scale) + alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else list(alpha) + + # Validate multi-B inputs + assert len(b_list) > 0, "Weight tensor list must not be empty" + assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}" + assert len(b_list) == len(b_scale_list) == len(alpha_list), ( + f"b, b_scale, alpha lists must have same length: " + f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}" + ) # Validate inputs assert a.device.type == "cuda", "Input tensors must be on CUDA device" diff --git a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py index be4dad8876..75ee2ade5f 100644 --- a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py @@ -376,9 +376,17 @@ def blockscaled_contiguous_grouped_gemm_finalize_fusion_nvfp4( ... ) # out shape: (seq_len, hidden_dim) """ # Normalize to lists for multi-B support - b_list = [b] if isinstance(b, torch.Tensor) else b - b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else b_scale - alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else alpha + b_list = [b] if isinstance(b, torch.Tensor) else list(b) + b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else list(b_scale) + alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else list(alpha) + + # Validate multi-B inputs + assert len(b_list) > 0, "Weight tensor list must not be empty" + assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}" + assert len(b_list) == len(b_scale_list) == len(alpha_list), ( + f"b, b_scale, alpha lists must have same length: " + f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}" + ) # Validate inputs assert a.device.type == "cuda", "Input tensors must be on CUDA device" diff --git a/tests/moe/test_cute_dsl_fused_moe.py b/tests/moe/test_cute_dsl_fused_moe.py index e764cac3b2..452b1e8838 100644 --- a/tests/moe/test_cute_dsl_fused_moe.py +++ b/tests/moe/test_cute_dsl_fused_moe.py @@ -1078,7 +1078,7 @@ def test_multi_b_fused_moe_2_tensors( self, num_tokens, hidden_size, intermediate_size, num_experts, top_k ): """Test end-to-end MoE with weights split into 2 tensors.""" - from flashinfer.cute_dsl import cute_dsl_fused_moe_nvfp4 + from flashinfer import cute_dsl_fused_moe_nvfp4 tensors = create_moe_tensors( num_tokens, hidden_size, intermediate_size, num_experts, num_experts, top_k @@ -1161,7 +1161,7 @@ def test_multi_b_uneven_split( self, num_tokens, hidden_size, intermediate_size, num_experts, top_k ): """Test multi-B with uneven expert split (e.g., 3+5 for 8 experts).""" - from flashinfer.cute_dsl import cute_dsl_fused_moe_nvfp4 + from flashinfer import cute_dsl_fused_moe_nvfp4 tensors = create_moe_tensors( num_tokens, hidden_size, intermediate_size, num_experts, num_experts, top_k @@ -1234,7 +1234,7 @@ def test_single_b_list_backward_compat( self, num_tokens, hidden_size, intermediate_size, num_experts, top_k ): """Test that passing a single tensor wrapped in a list produces identical results.""" - from flashinfer.cute_dsl import cute_dsl_fused_moe_nvfp4 + from flashinfer import cute_dsl_fused_moe_nvfp4 tensors = create_moe_tensors( num_tokens, hidden_size, intermediate_size, num_experts, num_experts, top_k @@ -1288,7 +1288,7 @@ def test_multi_b_n_tensors(self, num_b_tensors): The kernel has MAX_B_TENSORS=4 and specialized dispatch code for each count (1/2/3/4), so covering 3 and 4 here exercises the remaining paths. """ - from flashinfer.cute_dsl import cute_dsl_fused_moe_nvfp4 + from flashinfer import cute_dsl_fused_moe_nvfp4 num_tokens, hidden_size, intermediate_size = 128, 512, 512 # 12 experts for 3-way split (4+4+4), 8 experts for 4-way (2+2+2+2) @@ -1353,7 +1353,7 @@ def split(t, on_last_dim): def test_multi_b_wrapper_api(self): """Test CuteDslMoEWrapper with multi-B input.""" - from flashinfer.cute_dsl import CuteDslMoEWrapper + from flashinfer import CuteDslMoEWrapper num_tokens, hidden_size, intermediate_size = 128, 512, 512 num_experts, top_k = 8, 2 @@ -1418,7 +1418,7 @@ def test_multi_b_with_autotune(self): weights without crashing on .shape[0]. """ from flashinfer.autotuner import autotune - from flashinfer.cute_dsl import cute_dsl_fused_moe_nvfp4 + from flashinfer import cute_dsl_fused_moe_nvfp4 num_tokens, hidden_size, intermediate_size = 128, 512, 512 num_experts, top_k = 8, 2 From 5054cb10bbd400da4db4435cc0351ac5221d9fca Mon Sep 17 00:00:00 2001 From: yyh Date: Fri, 17 Apr 2026 20:59:27 +0800 Subject: [PATCH 3/3] fix: restore l parameter in wrapper for backward compat when b_tensor_l_sizes is None Co-Authored-By: Claude Opus 4.6 --- ...guous_gather_grouped_gemm_swiglu_fusion.py | 52 +++++++++++-------- ...guous_gather_grouped_gemm_swiglu_fusion.py | 6 ++- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index b8a2447087..43e97225e5 100644 --- a/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -410,8 +410,8 @@ def __init__( vectorized_f32: bool, topk: cutlass.Int64, raster_along_m: bool = False, - enable_pdl: bool = True, b_tensor_l_sizes: Optional[Tuple[int, ...]] = None, + enable_pdl: bool = True, ): """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with gather operation and SwiGLU fusion. @@ -533,25 +533,24 @@ def __init__( self.vectorized_f32 = vectorized_f32 # Multi-B tensor configuration - # b_tensor_l_sizes is required — the Python wrapper layer always provides it - # as a tuple (even for single-B, e.g. (256,)). if b_tensor_l_sizes is None: - raise ValueError( - "b_tensor_l_sizes is required. Pass a tuple with the number of " - "experts per tensor, e.g. (num_experts,) for single-B." + self.num_b_tensors = 1 + self.b_tensor_l_sizes = None + # Offsets padded for safe indexing in kernel + self.b_tensor_l_offsets = (0,) + (2**30,) * self.MAX_B_TENSORS + else: + assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( + f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" ) - assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ( - f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}" - ) - self.num_b_tensors = len(b_tensor_l_sizes) - self.b_tensor_l_sizes = b_tensor_l_sizes - offsets = [0] - for l_size in b_tensor_l_sizes: - offsets.append(offsets[-1] + l_size) - # Pad to MAX_B_TENSORS + 1 for safe indexing - while len(offsets) < self.MAX_B_TENSORS + 1: - offsets.append(2**30) - self.b_tensor_l_offsets = tuple(offsets) + self.num_b_tensors = len(b_tensor_l_sizes) + self.b_tensor_l_sizes = b_tensor_l_sizes + offsets = [0] + for l_size in b_tensor_l_sizes: + offsets.append(offsets[-1] + l_size) + # Pad to MAX_B_TENSORS + 1 for safe indexing + while len(offsets) < self.MAX_B_TENSORS + 1: + offsets.append(2**30) + self.b_tensor_l_offsets = tuple(offsets) def _setup_attributes(self): """Set up configurations that are dependent on GEMM inputs @@ -4034,6 +4033,7 @@ def wrapper( m: cutlass.Int64, n: cutlass.Int64, k: cutlass.Int64, + l: cutlass.Int64, # noqa: E741 tile_size: cutlass.Constexpr, scaling_vector_size: cutlass.Constexpr, max_active_clusters: cutlass.Constexpr, @@ -4043,12 +4043,19 @@ def wrapper( """Unified wrapper supporting both single-B and multi-B tensors. B tensors are always passed as tuples (length 1 for single-B). - L sizes are configured via b_tensor_l_sizes in __init__. + When b_tensor_l_sizes is provided, L sizes come from b_tensor_l_sizes; + otherwise falls back to the l parameter (backward compatible single-B). """ scale_k = k // scaling_vector_size interm_size = n // 2 num_tiles = m // tile_size - total_l = self.b_tensor_l_offsets[self.num_b_tensors] + # When b_tensor_l_sizes is provided, total_l comes from the precomputed offsets + # and l is ignored. Callers must ensure l == sum(b_tensor_l_sizes). + # When b_tensor_l_sizes is None (single-B backward compat), l is used directly. + if cutlass.const_expr(self.b_tensor_l_sizes is not None): + total_l = self.b_tensor_l_offsets[self.num_b_tensors] + else: + total_l = l a = cute.make_tensor( a_ptr, layout=cute.make_ordered_layout((orig_m, k, 1), order=(1, 0, 2)) @@ -4069,7 +4076,10 @@ def wrapper( ) # Create B and alpha tensors using const_expr conditions - l_0 = self.b_tensor_l_sizes[0] + if cutlass.const_expr(self.b_tensor_l_sizes is not None): + l_0 = self.b_tensor_l_sizes[0] + else: + l_0 = l alpha_0 = cute.make_tensor(alpha_ptr_tuple[0], layout=cute.make_layout((l_0,))) b_0 = cute.make_tensor( b_ptr_tuple[0], diff --git a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py index 54dda11f2b..95e6453c37 100644 --- a/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py +++ b/flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -274,8 +274,9 @@ def _get_compiled_gather_kernel( # Order must match wrapper signature: # (a_ptr, b_ptr_tuple, a_sf_ptr, b_sf_ptr_tuple, c_ptr, c_sf_ptr, alpha_ptr_tuple, # tile_idx_to_group_idx_ptr, tile_idx_to_mn_limit_ptr, token_id_mapping_ptr, - # num_non_exiting_tiles_ptr, norm_const_ptr, orig_m, m, n, k, + # num_non_exiting_tiles_ptr, norm_const_ptr, orig_m, m, n, k, l, # tile_size, scaling_vector_size, max_active_clusters, stream) + num_experts = sum(b_tensor_l_sizes) compile_args = [ a_ptr, b_ptr, @@ -293,6 +294,7 @@ def _get_compiled_gather_kernel( permuted_m, n, k, + num_experts, ] compiled_gemm = cute.compile( @@ -620,6 +622,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( ) # Execute kernel + num_experts = sum(b_tensor_l_sizes) exec_args = [ a_ptr, b_ptr, @@ -637,6 +640,7 @@ def blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion_nvfp4( permuted_m, n, k, + num_experts, # l ] compiled_gemm(*exec_args, stream=stream)