diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 311db068b50..c15d2ac0817 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1439,7 +1439,7 @@ repos: additional_dependencies: - tomli # add ignore words list - args: ["-L", "Mor,ans,thirdparty", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"] + args: ["-L", "Mor,ans,thirdparty,subtiles", "--skip", "ATTRIBUTIONS-*.md,*.svg", "--skip", "security_scanning/*"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.4 hooks: diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index 703dcc430a5..b346219f5bc 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -21,6 +21,14 @@ class GroupedGemmInputsHelper: + """Base helper class for grouped GEMM input preparation and tuning. + + Subclasses should override IDX_SHAPE_INFER to specify which input tensor + to use for shape inference in tuning. + """ + # Input tensor index for shape inference - subclass can override + IDX_A = 0 + IDX_SHAPE_INFER = IDX_A # Default: use a tensor for shape inference def __init__(self, num_experts: int, top_k: int, num_local_experts: int, local_expert_offset: int, tile_size: int): @@ -63,10 +71,11 @@ def map_to_tuning_buckets(self, x: int) -> int: last_positive_power_of_2(self.infer_num_tokens(x))) def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int: - return self.infer_num_tokens(input_shapes[0][0]) + return self.infer_num_tokens(input_shapes[self.IDX_SHAPE_INFER][0]) def infer_shape_max_num_tiles(self, input_shapes: List[torch.Size]) -> int: - return input_shapes[0][0] // self.tile_size + """Infer max_num_tiles from the shape inference tensor (IDX_SHAPE_INFER).""" + return input_shapes[self.IDX_SHAPE_INFER][0] // self.tile_size def infer_shape_max_num_permuted_tokens( self, input_shapes: List[torch.Size]) -> int: @@ -187,6 +196,123 @@ def inputs_pre_hook_finalize_fusion( return a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales +class GatherGroupedGemmInputsHelper(GroupedGemmInputsHelper): + """Helper class for gather-based grouped GEMM input preparation. + + This subclass handles inputs where: + - a tensor contains original (non-permuted) activations + - permuted_idx_to_expanded_idx specifies the gather pattern + - Shape inference uses permuted_idx_to_expanded_idx size instead of a size + + Input tensor layout: + 0: a - Original input activation (not permuted) + 1: b - Weight tensor + 2: a_sf - Scale factor for a + 3: b_sf - Scale factor for b + 4: alpha - Per-expert scaling factor + 5: tile_idx_to_group_idx - Tile to expert mapping + 6: tile_idx_to_mn_limit - Tile M/N limits + 7: permuted_idx_to_expanded_idx - Token permutation mapping + 8: num_non_exiting_tiles - Number of valid tiles + 9: global_sf - Global scale factor + """ + # Override: use permuted_idx_to_expanded_idx for shape inference + IDX_PERMUTED_IDX_TO_EXPANDED_IDX = 7 + IDX_SHAPE_INFER = IDX_PERMUTED_IDX_TO_EXPANDED_IDX + + def generate_permuted_idx_to_expanded_idx( + self, num_tokens: int, num_tokens_per_expert: List[int], + max_num_permuted_tokens: int) -> List[int]: + """Generate permuted_idx_to_expanded_idx for gather operation. + + Maps permuted index to expanded index (token_idx * top_k + topk_idx). 
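+
+        Illustrative example (hypothetical values, chosen only for exposition;
+        in practice the tuner supplies these arguments):
+
+            >>> helper = GatherGroupedGemmInputsHelper(
+            ...     num_experts=2, top_k=2, num_local_experts=2,
+            ...     local_expert_offset=0, tile_size=2)
+            >>> helper.generate_permuted_idx_to_expanded_idx(
+            ...     num_tokens=2, num_tokens_per_expert=[2, 2],
+            ...     max_num_permuted_tokens=4)
+            [0, 2, 1, 3]
+
+        Tokens are enumerated column-major per expert: expert 0 receives top-k
+        slot 0 of tokens 0 and 1 (expanded indices 0 and 2), and expert 1
+        receives slot 1 of each token (expanded indices 1 and 3).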
+ + Args: + num_tokens: Total number of input tokens + num_tokens_per_expert: List of token counts per expert + max_num_permuted_tokens: Target size of the output list + + Returns: + List of expanded IDs with length = max_num_permuted_tokens, + where permuted_idx_to_expanded_idx[permuted_idx] = expanded_idx + Padding tokens are marked with pad_val + Note: In kernel, use expanded_idx // top_k to get original token_idx + """ + permuted_idx_to_expanded_idx = [] + colmajor_expanded_idx = 0 + for i, curr_num_tokens in enumerate(num_tokens_per_expert): + curr_num_tiles = (curr_num_tokens + self.tile_size - + 1) // self.tile_size + for j in range(curr_num_tiles * self.tile_size): + if j < curr_num_tokens: + token_idx = colmajor_expanded_idx % num_tokens + topk_idx = colmajor_expanded_idx // num_tokens + expanded_idx = token_idx * self.top_k + topk_idx + permuted_idx_to_expanded_idx.append(expanded_idx) + colmajor_expanded_idx += 1 + else: + permuted_idx_to_expanded_idx.append( + self.pad_val) # Padding token + # Pad to max_num_permuted_tokens + while len(permuted_idx_to_expanded_idx) < max_num_permuted_tokens: + permuted_idx_to_expanded_idx.append(self.pad_val) + return permuted_idx_to_expanded_idx + + def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: + """Pre-hook for gather-based SwiGLU fusion kernel. + + Generates: + - tile_idx_to_group_idx + - tile_idx_to_mn_limit + - permuted_idx_to_expanded_idx (for gather operation) + - num_non_exiting_tiles + """ + a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, \ + permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf = inputs + # Verify permuted_idx_to_expanded_idx index matches the class constant + assert inputs[ + self. + IDX_PERMUTED_IDX_TO_EXPANDED_IDX] is permuted_idx_to_expanded_idx + + max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0) + max_num_tiles = max_num_permuted_tokens // self.tile_size + + num_tokens = self.infer_num_tokens(max_num_permuted_tokens) + num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens) + tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx( + num_tokens_per_expert) + tile_idx_to_mn_limit_list = self.generate_tile_idx_to_mn_limit( + num_tokens_per_expert) + permuted_idx_to_expanded_idx_list = self.generate_permuted_idx_to_expanded_idx( + num_tokens, num_tokens_per_expert, max_num_permuted_tokens) + num_non_exiting_tiles_val = len(tile_idx_to_group_idx_list) + num_padding_tiles_val = max_num_tiles - num_non_exiting_tiles_val + assert num_non_exiting_tiles_val > 0 + assert num_padding_tiles_val >= 0 + assert len(tile_idx_to_mn_limit_list) == num_non_exiting_tiles_val + assert len(permuted_idx_to_expanded_idx_list) == max_num_permuted_tokens + + tile_idx_to_group_idx = torch.tensor( + tile_idx_to_group_idx_list + [self.pad_val] * num_padding_tiles_val, + dtype=tile_idx_to_group_idx.dtype, + device=tile_idx_to_group_idx.device) + tile_idx_to_mn_limit = torch.tensor( + tile_idx_to_mn_limit_list + [self.pad_val] * num_padding_tiles_val, + dtype=tile_idx_to_mn_limit.dtype, + device=tile_idx_to_mn_limit.device) + permuted_idx_to_expanded_idx = torch.tensor( + permuted_idx_to_expanded_idx_list, + dtype=permuted_idx_to_expanded_idx.dtype, + device=permuted_idx_to_expanded_idx.device) + num_non_exiting_tiles = torch.tensor( + [num_non_exiting_tiles_val], + dtype=num_non_exiting_tiles.dtype, + device=num_non_exiting_tiles.device) + return (a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, + tile_idx_to_mn_limit, 
permuted_idx_to_expanded_idx, + num_non_exiting_tiles, global_sf) + + class FusedMoEInputsHelper: def __init__(self, num_experts: int, top_k: int, num_local_experts: int, @@ -217,6 +343,8 @@ def inputs_pre_hook(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]: import cutlass import cutlass.cute as cute + from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion import \ + BlockScaledContiguousGatherGroupedGemmKernel from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm import \ Sm100BlockScaledContiguousGroupedGemmKernel from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_finalize_fusion import \ @@ -1612,6 +1740,390 @@ def _( device=input_scale.device) return output, output_scale + class Sm100BlockScaledContiguousGatherGroupedGemmSwigluFusionRunner( + TunableRunner): + kernel_class = BlockScaledContiguousGatherGroupedGemmKernel + kernel_cache = dict() + tuning_config_cache = dict() + + def __init__(self, + num_experts: int, + top_k: int, + num_local_experts: int, + local_expert_offset: int, + tile_size: int, + scaling_vector_size: int = 16): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.num_local_experts = num_local_experts + self.local_expert_offset = local_expert_offset + if tile_size not in [128, 256]: + raise ValueError( + f"Tile size {tile_size} is not supported, it only supports 128 and 256." + ) + self.tile_size = tile_size + self.scaling_vector_size = scaling_vector_size + + if get_sm_version() != 100 and get_sm_version() != 103: + raise ValueError( + f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100 and SM 103" + ) + + def unique_id(self): + return ( + self.num_experts, + self.top_k, + self.num_local_experts, + self.local_expert_offset, + self.tile_size, + self.scaling_vector_size, + ) + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + **kwargs, + ) -> List[Tuple[int, int]]: + a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, *_ = inputs + # m is the permuted size from permuted_idx_to_expanded_idx, not from a + m = permuted_idx_to_expanded_idx.size(0) + k = a.size(1) * 2 + l, n = b.size(0), b.size(1) + + if self.tile_size == 128: + mma_tiler_mn_candidates = [(128, 128), (128, 256)] + cluster_shape_mn_candidates = [(1, 1)] + elif self.tile_size == 256: + mma_tiler_mn_candidates = [(256, 128), (256, 256)] + cluster_shape_mn_candidates = [(2, 1)] + else: + raise ValueError(f"Tile size {self.tile_size} is not supported") + + valid_tactics = [] + for mma_tiler_mn, cluster_shape_mn in itertools.product( + mma_tiler_mn_candidates, cluster_shape_mn_candidates): + if self.__class__.kernel_class.can_implement( + ab_dtype=cutlass.Float4E2M1FN, + sf_dtype=cutlass.Float8E4M3FN, + sf_vec_size=self.scaling_vector_size, + acc_dtype=cutlass.Float32, + c_dtype=cutlass.Float4E2M1FN, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + m=m, + n=n, + k=k, + l=l, + a_major="k", + b_major="k", + c_major="n", + m_aligned=self.tile_size, + ): + valid_tactics.append((mma_tiler_mn, cluster_shape_mn)) + + return valid_tactics + + def get_tuning_config(self) -> TuningConfig: + key = self.unique_id() + if key not in self.__class__.tuning_config_cache: + helper = GatherGroupedGemmInputsHelper(self.num_experts, + self.top_k, + self.num_local_experts, + self.local_expert_offset, + self.tile_size) + self.__class__.tuning_config_cache[key] = 
TuningConfig( + # Use permuted_idx_to_expanded_idx (IDX_SHAPE_INFER) for tuning + dynamic_tensor_specs=(DynamicTensorSpec( + GatherGroupedGemmInputsHelper.IDX_SHAPE_INFER, 0, + helper.gen_tuning_buckets, + helper.map_to_tuning_buckets), ), + constraint_specs=(ConstraintSpec( + 0, 0, helper.infer_shape_num_tokens), + ConstraintSpec( + 2, 0, helper.infer_shape_num_tokens), + ConstraintSpec( + 5, 0, + helper.infer_shape_max_num_tiles), + ConstraintSpec( + 6, 0, + helper.infer_shape_max_num_tiles)), + inputs_pre_hook=helper.inputs_pre_hook, + use_cuda_graph=True, + ) + return self.__class__.tuning_config_cache[key] + + def forward(self, inputs: List[torch.Tensor], + tactic: Optional[tuple]) -> torch.Tensor: + a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf = inputs + # Verify permuted_idx_to_expanded_idx index matches the class constant + assert inputs[ + GatherGroupedGemmInputsHelper. + IDX_PERMUTED_IDX_TO_EXPANDED_IDX] is permuted_idx_to_expanded_idx + assert a.dtype == torch.float4_e2m1fn_x2 + assert a.dim() == 2 + assert b.dtype == torch.float4_e2m1fn_x2 + assert b.dim() == 3 + assert a_sf.dtype == torch.uint8 + assert a_sf.dim() == 2 + assert b_sf.dtype == torch.uint8 + assert b_sf.dim() == 3 + assert alpha.dtype == torch.float32 + assert alpha.dim() == 1 + + # a.size(0) is orig_m (original input size before gather) + # permuted_idx_to_expanded_idx.size(0) is m (permuted size after gather) + orig_m, k = a.size(0), a.size(1) * 2 + m = permuted_idx_to_expanded_idx.size(0) + l, n = b.size(0), b.size(1) + scale_k = k // self.scaling_vector_size + interm_size = n // 2 + assert m % self.tile_size == 0 + assert k % (self.scaling_vector_size * 4) == 0 + assert n % (self.scaling_vector_size * 4 * 2) == 0 + assert b.size(2) * 2 == k + assert a_sf.size(0) == orig_m + assert a_sf.size(1) == scale_k + assert b_sf.size(0) == l + assert b_sf.size(1) == n + assert b_sf.size(2) == scale_k + assert alpha.size(0) == l + + num_tiles = m // self.tile_size + assert tile_idx_to_group_idx.dtype == torch.int32 + assert tile_idx_to_group_idx.size() == (num_tiles, ) + assert tile_idx_to_mn_limit.dtype == torch.int32 + assert tile_idx_to_mn_limit.size() == (num_tiles, ) + assert permuted_idx_to_expanded_idx.dtype == torch.int32 + assert permuted_idx_to_expanded_idx.size() == (m, ) + assert num_non_exiting_tiles.dtype == torch.int32 + assert num_non_exiting_tiles.numel() == 1 + assert global_sf.dtype == torch.float32 + assert global_sf.numel() == 1 + + c = torch.empty(m, interm_size // 2, dtype=a.dtype, device=a.device) + c_sf = torch.empty(m * interm_size // self.scaling_vector_size, + dtype=a_sf.dtype, + device=a_sf.device) + + a_ptr = make_ptr(cutlass.Float4E2M1FN, + a.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32) + b_ptr = make_ptr(cutlass.Float4E2M1FN, + b.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32) + a_sf_ptr = make_ptr(cutlass.Float8E4M3FN, + a_sf.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16) + b_sf_ptr = make_ptr(cutlass.Float8E4M3FN, + b_sf.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16) + alpha_ptr = make_ptr(cutlass.Float32, alpha.data_ptr(), + cute.AddressSpace.gmem) + tile_idx_to_group_idx_ptr = make_ptr( + cutlass.Int32, tile_idx_to_group_idx.data_ptr(), + cute.AddressSpace.gmem) + tile_idx_to_mn_limit_ptr = make_ptr(cutlass.Int32, + tile_idx_to_mn_limit.data_ptr(), + cute.AddressSpace.gmem) + permuted_idx_to_expanded_idx_ptr = make_ptr( + cutlass.Int32, 
permuted_idx_to_expanded_idx.data_ptr(), + cute.AddressSpace.gmem) + num_non_exiting_tiles_ptr = make_ptr( + cutlass.Int32, num_non_exiting_tiles.data_ptr(), + cute.AddressSpace.gmem) + global_sf_ptr = make_ptr(cutlass.Float32, global_sf.data_ptr(), + cute.AddressSpace.gmem) + c_ptr = make_ptr(cutlass.Float4E2M1FN, + c.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=32) + c_sf_ptr = make_ptr(cutlass.Float8E4M3FN, + c_sf.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16) + + torch_stream = torch.cuda.current_stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + + if isinstance(tactic, tuple): + mma_tiler_mn, cluster_shape_mn = tactic + else: + mma_tiler_mn, cluster_shape_mn = (self.tile_size, + 128), (self.tile_size // 128, + 1) + + cache_key = (self.scaling_vector_size, self.tile_size, self.top_k, + mma_tiler_mn, cluster_shape_mn) + if cache_key not in self.__class__.kernel_cache: + gemm = self.__class__.kernel_class( + sf_vec_size=self.scaling_vector_size, + acc_dtype=cutlass.Float32, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + vectorized_f32=True, + topk=self.top_k, + ) + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1]) + + compiled_gemm = cute.compile( + gemm.wrapper, + a_ptr, + b_ptr, + a_sf_ptr, + b_sf_ptr, + c_ptr, + c_sf_ptr, + alpha_ptr, + tile_idx_to_group_idx_ptr, + tile_idx_to_mn_limit_ptr, + permuted_idx_to_expanded_idx_ptr, + num_non_exiting_tiles_ptr, + global_sf_ptr, + orig_m, + m, + n, + k, + l, + tile_size=self.tile_size, + scaling_vector_size=self.scaling_vector_size, + max_active_clusters=max_active_clusters, + stream=stream, + ) + self.__class__.kernel_cache[cache_key] = compiled_gemm + else: + compiled_gemm = self.__class__.kernel_cache[cache_key] + + compiled_gemm( + a_ptr, + b_ptr, + a_sf_ptr, + b_sf_ptr, + c_ptr, + c_sf_ptr, + alpha_ptr, + tile_idx_to_group_idx_ptr, + tile_idx_to_mn_limit_ptr, + permuted_idx_to_expanded_idx_ptr, + num_non_exiting_tiles_ptr, + global_sf_ptr, + orig_m, + m, + n, + k, + l, + stream=stream, + ) + return c, c_sf + + @torch.library.custom_op( + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell", + mutates_args=(), + device_types="cuda") + def cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( + input: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, + tile_idx_to_group_idx: torch.Tensor, + tile_idx_to_mn_limit: torch.Tensor, + permuted_idx_to_expanded_idx: torch.Tensor, + num_non_exiting_tiles: torch.Tensor, + global_sf: torch.Tensor, + num_experts: int, + top_k: int, + num_local_experts: int, + local_expert_offset: int, + tile_size: int, + scaling_vector_size: int = 16, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tuner = AutoTuner.get() + + runner = Sm100BlockScaledContiguousGatherGroupedGemmSwigluFusionRunner( + num_experts, top_k, num_local_experts, local_expert_offset, + tile_size, scaling_vector_size) + inputs = [ + input, weight, input_scale, weight_scale, alpha, + tile_idx_to_group_idx, tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx, num_non_exiting_tiles, global_sf + ] + + _, best_tactic = tuner.choose_one( + "trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell", + [runner], + runner.get_tuning_config(), + inputs, + ) + output = runner(inputs, tactic=best_tactic) + return output + + @torch.library.register_fake( + 
"trtllm::cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell") + def _( + input: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, + tile_idx_to_group_idx: torch.Tensor, + tile_idx_to_mn_limit: torch.Tensor, + permuted_idx_to_expanded_idx: torch.Tensor, + num_non_exiting_tiles: torch.Tensor, + global_sf: torch.Tensor, + num_experts: int, + top_k: int, + num_local_experts: int, + local_expert_offset: int, + tile_size: int, + scaling_vector_size: int = 16, + ) -> Tuple[torch.Tensor, torch.Tensor]: + m = permuted_idx_to_expanded_idx.size(0) + n = weight.size(1) + interm_size = n // 2 + output = torch.empty(m, + interm_size // 2, + dtype=input.dtype, + device=input.device) + output_scale = torch.empty(m * interm_size // scaling_vector_size, + dtype=input_scale.dtype, + device=input_scale.device) + return output, output_scale + + class FusedMoEInputsHelper: + + def __init__(self, num_experts: int, top_k: int, num_local_experts: int, + local_expert_offset: int): + self.num_experts = num_experts + self.top_k = top_k + self.num_local_experts = num_local_experts + self.local_expert_offset = local_expert_offset + + def infer_shape_num_tokens(self, input_shapes: List[torch.Size]) -> int: + return input_shapes[0][0] + + def inputs_pre_hook(self, + inputs: List[torch.Tensor]) -> List[torch.Tensor]: + x, x_sf, token_selected_experts, token_final_scales, *others = inputs + num_tokens = token_selected_experts.size(0) + new_token_final_scales, new_token_selected_experts = torch.randn( + num_tokens, + self.num_experts, + device=token_selected_experts.device).topk(self.top_k, dim=-1) + new_token_selected_experts = new_token_selected_experts.to( + token_selected_experts.dtype) + new_token_final_scales = new_token_final_scales.softmax(dim=-1).to( + token_final_scales.dtype) + return x, x_sf, new_token_selected_experts, new_token_final_scales, *others + class Sm100BlockScaledFusedMoERunner(TunableRunner): tuning_config_cache = dict() diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py new file mode 100644 index 00000000000..5b9d06bb178 --- /dev/null +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py @@ -0,0 +1,3029 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Optional, Tuple, Type, Union + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass._mlir.dialects import math, nvvm +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.cute.typing import Float32 +from cutlass.cutlass_dsl import T, dsl_user_op + +from .custom_pipeline import PipelineCpAsyncUmma +from .utils import is_power_of_2 + + +@dsl_user_op +def fmin( + a: Union[float, Float32], b: Union[float, Float32], *, nan=False, loc=None, ip=None +) -> Float32: + return Float32( + nvvm.fmin( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + nan=nan, + loc=loc, + ip=ip, + ) + ) + + +def sigmoid_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]: + """ + Compute the sigmoid of the input tensor. + """ + return cute.arch.rcp_approx(1.0 + cute.math.exp(-a, fastmath=fastmath)) + + +def silu_f32(a: Union[float, Float32], fastmath: bool = False) -> Union[float, Float32]: + """ + Compute the silu of the input tensor. + """ + return a * sigmoid_f32(a, fastmath=fastmath) + + +""" +High-performance persistent blockscaled contiguous grouped dense GEMM with gather and SwiGLU fusion +(C = up * silu(gate), where up and gate come from interleaved weight matrix B) +example for the NVIDIA Blackwell architecture using CUTE DSL. + +This kernel performs FC1 layer computation with SwiGLU activation fusion: +1. GEMM: acc = alpha * (SFA * A[token_ids]) * (SFB * B) +2. SwiGLU: C = up * silu(gate), where up/gate are extracted from interleaved acc (granularity=64) +3. Optional Quant: When c_dtype is Float4E2M1FN, generates scale factor C and quantizes output + +- Matrix A is MxKx1, A can be row-major("K"), ValidM is composed of valid m in different groups +- Matrix B is NxKxL, B can be column-major("K"), L is grouped dimension (number of experts) + - B weights are interleaved: [up_0:64, gate_64:128, up_128:192, gate_192:256, ...] +- Matrix C is Mx(N/2)x1, C can be row-major("N"), N is halved due to SwiGLU fusion +- Matrix SFA layout is filled internally according to A shape and BlockScaledBasicChunk, + which has M×ceil_div(K, sf_vec_size)×1 elements +- Matrix SFB layout is filled internally according to B shape and BlockScaledBasicChunk, + which has N×ceil_div(K, sf_vec_size)×L elements +- Token ID mapping tensor enables gather operation for A and SFA + +Matrix A/C Memory Layout Diagrams: + + ``` + Group 0 Group 1 Group 2 + -+---------+---------+---------+ + | | | | + K| ValidM0 | ValidM1 | ValidM2 | + | | | | + -+---------+---------+---------+ + |<- ValidM ->| + ``` + Note: the Group(L) dimension will be flatted into M dimension, and the rest Group(L) size is 1. + each ValidM will be aligned to 256 or 128. The alignment is determined by the mma_tiler_mn parameter. 
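+    For example (illustrative numbers): with alignment 128, a group whose
+    ValidM is 300 occupies ceil(300 / 128) * 128 = 384 rows, and a following
+    group whose ValidM is 70 occupies the next 128 rows; rows that hold no
+    valid token are padding and are never written back to C.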
+ For NVFP4, 2CTA, the alignment is 256. For NVFP4, 1CTA, the alignment is 128. + +This GEMM kernel supports the following features: + - Utilizes LDGSTS (Load Global to Shared with Swizzle) for A and SFA with gather operation + - Utilizes Tensor Memory Access (TMA) for B and SFB matrices + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. SCHEDULER warp (warp 10): Dispatches tile information to all consumer warps via tile_info_pipeline. +2. LDGSTS A/SFA warps (warps 4-7): + - Load A matrix from global memory (GMEM) to shared memory (SMEM) using LDGSTS instructions with gather. + - Load SFA (scale factor A) from GMEM to SMEM using LDGSTS instructions. + - Uses token_id_mapping to perform permutation/gather during load. +3. TMA B/SFB warp (warp 9): + - Load B and SFB matrices from GMEM to SMEM using TMA operations with multicast. +4. MMA warp (warp 8): + - Load scale factor A/B from shared memory (SMEM) to tensor memory (TMEM) using tcgen05.cp instruction. + - Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +5. EPILOGUE warps (warps 0-3): + - Load two accumulator subtiles (up and gate) from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Apply alpha scaling: up_scaled = alpha * up, gate_scaled = alpha * gate + - Compute SwiGLU activation: output = up_scaled * silu(gate_scaled), where silu(x) = x * sigmoid(x) + - If c_dtype is Float4E2M1FN: generate scale factor C (SFC) and quantize output + - Type convert output to c_dtype. + - Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations. + +SM100 tcgen05.mma.kind.block_scale instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Read scalefactor A from TMEM +- Read scalefactor B from TMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +Constraints: +* Supported input data types: mxf8, mxf4, nvf4 + see detailed valid dtype combinations in below Sm100BlockScaledPersistentDenseGemmKernel class documentation +* A/B tensor must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4) +* Mma tiler M must be 128 or 256(use_2cta_instrs) +* Mma tiler N must be 64/128/192/256 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively. 
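+
+Reference epilogue math (an illustrative scalar sketch; the kernel applies the
+same formula to vectorized f32 accumulator fragments, taking up/gate from the
+even/odd 64-column slices of the interleaved accumulator):
+
+    up_scaled   = alpha * acc_up
+    gate_scaled = alpha * acc_gate
+    out         = up_scaled * gate_scaled * sigmoid(gate_scaled)  # up * silu(gate)
+    # where sigmoid(x) = 1 / (1 + exp(-x))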
+ +CUDA Graph Support: +* For CUDA graph support, the tile_idx_to_expert_idx, token_id_mapping, A/C matrices, + and scale factor A can be padded to a larger size + (e.g., permuted_m = m*topK + num_local_experts*(256-1), + example: 4096*8 + (256/32)*255 = 34808) +* Use create_tensors() with permuted_m parameter to automatically pad: + - tile_idx_to_expert_idx: padded for invalid tiles (set to -2e9 for padding tiles) + - token_id_mapping: padded to permuted_m size (invalid tokens set to -1) + - A matrix: padded to permuted_m rows (padding rows contain dummy data) + - C matrix: padded to permuted_m rows (output buffer for cuda_graph) + - Scale factor A: padded to match A matrix dimensions +* Kernel handling of padding: + - Scheduler warp checks if tile_idx >= num_non_exiting_tiles to exit + - Only valid tiles (tile_idx < num_non_exiting_tiles) are written to tile_info pipeline + - LDGSTS warps use token_id_mapping predicates to skip invalid tokens (token_id == -1) + - When no more valid tiles exist, outer loop exits and calls producer_tail() + - Consumer warps process only valid tiles from pipeline + - No deadlock or synchronization issues +* Consumer warps check initial tile against num_non_exiting_tiles and set + is_valid_tile=False if tile_idx >= num_non_exiting_tiles +* Only rows within (aligned_groupm[0]+aligned_groupm[1]+...) contain valid data +* Padding rows in C matrix will not be written by the kernel +""" + + +class BlockScaledContiguousGatherGroupedGemmKernel: + """This class implements contiguous grouped matrix multiplication with gather operation and SwiGLU fusion + for FC1 layer computation (C = up * silu(gate), where up/gate come from interleaved GEMM result). + + The computation flow: + 1. GEMM: acc = alpha * (SFA * A[token_ids]) * (SFB * B) + 2. SwiGLU: C = up * silu(gate), extracted from interleaved acc with granularity=64 + 3. Optional Quant: When c_dtype is Float4E2M1FN, generates SFC and quantizes output + + Note: Output C has N/2 columns since pairs of (up, gate) are combined by SwiGLU. + + Key Features: + - Uses LDGSTS instructions for loading A and SFA matrices with gather/permutation capability + - Uses TMA (Tensor Memory Access) for loading B and SFB matrices with multicast + - Token ID mapping enables efficient gather operation during A/SFA load + - SwiGLU activation fusion in epilogue (up * silu(gate) with interleaved weights) + - Optional quantization fusion for Float4E2M1FN output with scale factor generation + - Warp specialization: Scheduler (warp 10), LDGSTS A/SFA (warps 4-7), TMA B/SFB (warp 9), + MMA (warp 8), Epilogue (warps 0-3) + + :param sf_vec_size: Scalefactor vector size (16 for NVF4, 32 for MXF4/MXF8). + :type sf_vec_size: int + :param acc_dtype: Data type of the accumulator (e.g., cutlass.Float32). + :type acc_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N). + Note: use_2cta_instrs is automatically inferred from mma_tiler_mn[0] + (True when M=256, False when M=128). + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + :param vectorized_f32: Whether to use vectorized f32x2 operations for better performance. 
+ :type vectorized_f32: bool + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported combinations of A/B data types, SF data typs and SF vector size: + - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16 + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float32 + - Float16/BFloat16 + - Float8E4M3FN/Float8E5M2 + # {$nv-internal-release begin} + # Note: Float4E2M1FN output includes SFC generation and quantization support for internal testing. + - Float4E2M1FN (with scale factor generation) + # {$nv-internal-release end} + + :note: Constraints: + - MMA tiler M must be 128 or 256 (use_2cta_instrs) + - MMA tiler N must be 64/128/192/256 + - Cluster shape M must be multiple of 2 if Mma tiler M is 256 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors + + Example: + >>> # Note: use_2cta_instrs is auto-inferred from mma_tiler_mn[0] + >>> # (True when M=256, False when M=128) + >>> gemm = BlockScaledContiguousGatherGroupedGemmKernel( + ... sf_vec_size=16, + ... acc_dtype=cutlass.Float32, + ... mma_tiler_mn=(256, 128), # use_2cta_instrs=True since M=256 + ... cluster_shape_mn=(2, 1), + ... vectorized_f32=True, + ... ) + >>> gemm( + ... a=a_tensor, + ... b=b_tensor, + ... c=c_tensor, + ... sfa=sfa_tensor, + ... sfb=sfb_tensor, + ... sfc_tensor=None, + ... norm_const_tensor=None, + ... tile_idx_to_expert_idx=tile_idx_to_expert_idx, + ... tile_idx_to_mn_limit=tile_idx_to_mn_limit, + ... token_id_mapping_tensor=token_id_mapping_tensor, + ... num_non_exiting_tiles=num_non_exiting_tiles, + ... alpha=alpha, + ... max_active_clusters=max_active_clusters, + ... stream=stream, + ... ) + """ + + def __init__( + self, + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + vectorized_f32: bool, + topk: int, + ): + """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with + gather operation and SwiGLU fusion. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Automatically inferred from mma_tiler_mn[0] + (True when M=256, False when M=128). + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + 3. Scale Factor Configuration: + - sf_vec_size: Vector size for block-scaled quantization. + + 4. Performance Optimization: + - vectorized_f32: Enable vectorized f32x2 operations. + + 5. MoE Configuration: + - topk: Number of experts selected per token (used for token ID mapping). + + :param sf_vec_size: Vector size for scale factors (16 for NVF4, 32 for MXF4/MXF8). + :type sf_vec_size: int + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + use_2cta_instrs is automatically set based on M (True if M=256, False if M=128). + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. 
+ :type cluster_shape_mn: Tuple[int, int] + :param vectorized_f32: Enable vectorized f32x2 operations for better performance. + :type vectorized_f32: bool + :param topk: Number of experts selected per token (used for token ID mapping). + :type topk: int + """ + + self.sf_vec_size = sf_vec_size + self.topk = topk + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + self.epilog_warp_id = (0, 1, 2, 3) + self.ldgsts_a_warp_id = ( + 4, + 5, + 6, + 7, + ) + self.mma_warp_id = 8 + self.tma_b_warp_id = 9 + self.sched_warp_id = 10 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + self.mma_warp_id, + *self.ldgsts_a_warp_id, + self.tma_b_warp_id, + *self.epilog_warp_id, + self.sched_warp_id, + ) + ) + self.threads_wo_sched = self.threads_per_warp * len( + ( + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_b_warp_id, + *self.ldgsts_a_warp_id, + ) + ) + + # Set barrier for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=self.threads_per_cta, + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=32 * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + barrier_id=4, + num_threads=self.threads_per_warp, + ) + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + self.vectorized_f32 = vectorized_f32 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + + self.mma_inst_shape_mn = ( + self.mma_tiler[0], + self.mma_tiler[1], + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mn_sfb = ( + self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mn[1], 128), + ) + + # Configure tiled mma + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.mma_tiler_sfa = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1], + 
mma_inst_shape_k * mma_inst_tile_k // 16, + ) + + self.mma_tiler_sfb = ( + self.mma_inst_shape_mn_sfb[0], + self.mma_inst_shape_mn_sfb[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + self.cta_tile_shape_mnk_sfa = ( + self.mma_tiler_sfa[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_sfa[1], + self.mma_tiler_sfa[2], + ) + + self.mma_tiler_c = ( + self.mma_inst_shape_mn[0], + self.mma_inst_shape_mn[1] // 2, + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_tile_shape_mnk_c = ( + self.mma_tiler_c[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler_c[1], + self.mma_tiler_c[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = (128, 64) + self.epi_tile_cnt = ( + self.cta_tile_shape_mnk_c[0] // self.epi_tile[0], + self.cta_tile_shape_mnk_c[1] // self.epi_tile[1], + ) + + # Setup A/B/C/Scale stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_tile_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.num_smem_capacity, + self.occupancy, + ) + + # Compute A/B/C/Scale shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = 512 + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + sfc_tensor: Optional[cute.Tensor], + norm_const_tensor: Optional[cute.Tensor], + tile_idx_to_expert_idx: cute.Tensor, + tile_idx_to_mn_limit: cute.Tensor, + token_id_mapping_tensor: cute.Tensor, + num_non_exiting_tiles: cute.Tensor, + alpha: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the contiguous grouped GEMM with gather operation and SwiGLU fusion. + + This method performs FC1 layer computation: + 1. GEMM: acc = alpha * (SFA * A[token_ids]) * (SFB * B) + 2. SwiGLU: C = up * silu(gate), where up/gate are extracted from interleaved acc (granularity=64) + 3. 
Optional Quant: When c_dtype is Float4E2M1FN, generates SFC and quantizes output + + Data loading: + - A and SFA are loaded using LDGSTS instructions with token-based gather + - B and SFB are loaded using TMA instructions with multicast + - B weights are interleaved: [up_0:64, gate_64:128, up_128:192, gate_192:256, ...] + + Execution steps: + 1. Setup static attributes before smem/grid computation + 2. Setup TMA load/store atoms for B, SFB, and C (no TMA for A/SFA) + 3. Compute grid size with regard to hardware constraints + 4. Define shared storage for kernel + 5. Launch the kernel synchronously with warp specialization: + - Scheduler warp: Dispatches tile information + - LDGSTS warps: Load A and SFA with gather + - TMA warp: Load B and SFB with multicast + - MMA warp: Perform matrix multiply-accumulate + - Epilogue warps: Apply SwiGLU activation, optional quantization, and store results + + :param a: Input tensor A (MxKx1), will be gathered using token_id_mapping + :type a: cute.Tensor + :param b: Input tensor B (NxKxL), L is the number of experts/groups, weights are interleaved for SwiGLU + :type b: cute.Tensor + :param c: Output tensor C (Mx(N/2)x1), N is halved due to SwiGLU fusion + :type c: cute.Tensor + :param sfa: Scale factor tensor A, will be gathered using token_id_mapping + :type sfa: cute.Tensor + :param sfb: Scale factor tensor B + :type sfb: cute.Tensor + :param sfc_tensor: Scale factor tensor C for quantized output (None if not quantizing) + :type sfc_tensor: Optional[cute.Tensor] + :param norm_const_tensor: Normalization constant for scale factor generation + (None if not quantizing) + :type norm_const_tensor: Optional[cute.Tensor] + :param tile_idx_to_expert_idx: Mapping from tile index to expert ID, + shape (permuted_m/cta_tile_m,) where cta_tile_m is the CTA tile M size + :type tile_idx_to_expert_idx: cute.Tensor + :param tile_idx_to_mn_limit: Mapping from tile index to M-N dimension limit + for boundary checking, shape (permuted_m/cta_tile_m,) + :type tile_idx_to_mn_limit: cute.Tensor + :param token_id_mapping_tensor: Token ID mapping for gather operation, shape (permuted_m,) + :type token_id_mapping_tensor: cute.Tensor + :param num_non_exiting_tiles: Number of valid tiles to process (valid_m/cta_tile_m), shape (1,) + :type num_non_exiting_tiles: cute.Tensor + :param alpha: Alpha tensor for each group + :type alpha: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. 
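+
+        Shape relationship (illustrative numbers): with a 1024xKx1 permuted A,
+        a 512xKxL interleaved B, and a CTA tile M of 128, C is 1024x256x1, and
+        tile_idx_to_expert_idx / tile_idx_to_mn_limit each hold
+        1024 / 128 = 8 entries.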
+ """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(b.shape, self.sf_vec_size) + sfb = cute.make_tensor(sfb.iterator, sfb_layout) + + self.generate_sfc = sfc_tensor is not None and norm_const_tensor is not None + if cutlass.const_expr(self.generate_sfc): + sfc_layout = blockscaled_utils.tile_atom_to_shape_SF(c.shape, self.sf_vec_size) + sfc_tensor = cute.make_tensor(sfc_tensor.iterator, sfc_layout) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mn, + ) + + # For 2CTA blockscaled kernels, SFB needs to be replicated across peer CTAs. # {$nv-internal-release} + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mn_sfb, + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # {$nv-internal-release begin} + # This modifies the layout to handle overlapping 256x(# of scale factors for a single column of B (nNSF)) + # logical blocks for SFB when cta_tile_shape_n=192. 
+ # {$nv-internal-release end} + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + x = tma_tensor_sfb.stride[0][1] + y = cute.ceil_div(tma_tensor_sfb.shape[0][1], 4) + + new_shape = ( + (tma_tensor_sfb.shape[0][0], ((2, 2), y)), + tma_tensor_sfb.shape[1], + tma_tensor_sfb.shape[2], + ) + # Use right multiplication for ScaledBasis (3 * x instead of x * 3) + x_times_3 = 3 * x + new_stride = ( + (tma_tensor_sfb.stride[0][0], ((x, x), x_times_3)), + tma_tensor_sfb.stride[1], + tma_tensor_sfb.stride[2], + ) + tma_tensor_sfb_new_layout = cute.make_layout(new_shape, stride=new_stride) + tma_tensor_sfb = cute.make_tensor(tma_tensor_sfb.iterator, tma_tensor_sfb_new_layout) + + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = (b_copy_size + sfb_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c, self.cta_tile_shape_mnk_c, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + # (bidx, bidy, bidz, valid, mn_limit) + sInfo: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, 5 * self.num_tile_stage], + # 1 byte alignment + 1, + ] + a_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + b_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tile_info_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_tile_stage * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged)], + self.buffer_align_bytes, + ] + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tiled_mma_sfb, + a, + tma_atom_b, + tma_tensor_b, + sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + sfc_tensor, + norm_const_tensor, + tile_idx_to_expert_idx, + tile_idx_to_mn_limit, + token_id_mapping_tensor, + num_non_exiting_tiles, + alpha, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + 
self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to + partition smem memory (source) and tensor memory (destination). + + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + mSFC_mnl: Optional[cute.Tensor], + norm_const_tensor: Optional[cute.Tensor], + tile_idx_to_expert_idx: cute.Tensor, + tile_idx_to_mn_limit: cute.Tensor, + token_id_mapping_tensor: cute.Tensor, + num_non_exiting_tiles: cute.Tensor, + alpha: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. 
+ """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_b_warp_id: + # cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + # cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # Pipeline Init: Initialize A pipeline for LDGSTS operations + # Producer: 4 warps (warps 4-7) with 128 threads total for LDGSTS operations + # Consumer: MMA warp for consuming A/SFA data + a_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 128 + * cute.size( + cluster_layout_vmnk, mode=[0] + ), # 4 warps * 32 threads per warp = 128 threads + ) + + a_pipeline = PipelineCpAsyncUmma.create( + barrier_storage=storage.a_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=a_pipeline_producer_group, + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + enable_cp_async=(not self.use_2cta_instrs), + ) + + # Pipeline Init: Initialize B pipeline for TMA operations + # Using PipelineTmaUmma for B/SFB since they use TMA load with multicast support + # Producer: TMA B/SFB warp (warp 9) - 1 warp issuing TMA operations + # Consumer: MMA warp for consuming B/SFB data + b_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_b + b_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + b_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.b_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=b_pipeline_producer_group, + consumer_group=b_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, # Total bytes loaded by TMA (B + SFB) + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Pipeline Init: Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Pipeline Init: Tensor memory dealloc barrier init + tile_info_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + tile_info_pipeline_consumer_group = 
pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_wo_sched, + ) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.tile_info_mbar_ptr.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + storage.tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=storage.tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C/Scale + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + # (bidx, bidy, bidz, valid) + info_layout = cute.make_layout((5, self.num_tile_stage), stride=(1, 5)) + sInfo = storage.sInfo.get_tensor(info_layout) + + # + # Compute multicast mask for A/B buffer full + # + # a_full_mcast_mask = None + b_full_mcast_mask = None + # sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_b_mcast or use_2cta_instrs): + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), (None, None, None) + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(self.cta_tile_shape_mnk_sfa, (None, 0, None)), (None, None, None) + ) + + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.mma_tiler_sfb, (0, None, None)), + (None, None, None), + ) + + gToken_ml = cute.local_tile( + token_id_mapping_tensor, cute.slice_(self.cta_tile_shape_mnk, (None, 0, 0)), (None,) + ) + + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler_c, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load B + # + # TMA load B partition_S/D + 
b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMA load SFB partition_S/D + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + self.cta_sync_barrier.arrive_and_wait() + + # + # Specialized Schedule warp + # + if warp_idx == self.sched_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + tile_info_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_tile_stage + ) + + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape) + if mma_tile_coord_m < num_non_exiting_tiles[0]: + tile_info_pipeline.producer_acquire(tile_info_producer_state) + cur_tile_coord = work_tile.tile_idx + expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m] + mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m] + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0] + sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1] + sInfo[(2, tile_info_producer_state.index)] = expert_idx + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32( + work_tile.is_valid_tile + ) + sInfo[(4, tile_info_producer_state.index)] = mn_limit + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + self.sched_sync_barrier.arrive_and_wait() + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + tile_info_pipeline.producer_acquire(tile_info_producer_state) + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = work_tile.tile_idx[0] + sInfo[(1, tile_info_producer_state.index)] = work_tile.tile_idx[1] + sInfo[(2, tile_info_producer_state.index)] = -1 + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(0) + sInfo[(4, tile_info_producer_state.index)] = -1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.sched_sync_barrier.arrive_and_wait() + 
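
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): the 5-word record that the
# scheduler warp publishes through sInfo above, modeled in plain Python.
# TileInfo/consume are hypothetical names; the kernel itself works on raw
# shared-memory words guarded by tile_info_pipeline.
# ---------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Iterable, Iterator

@dataclass
class TileInfo:
    tile_m: int      # sInfo[0]: CTA tile index along M
    tile_n: int      # sInfo[1]: CTA tile index along N
    expert_idx: int  # sInfo[2]: expert (group) index, -1 in the sentinel
    is_valid: int    # sInfo[3]: 1 = real work, 0 = stop all consumer warps
    mn_limit: int    # sInfo[4]: number of valid rows in this tile

def consume(records: Iterable[TileInfo]) -> Iterator[TileInfo]:
    """Consumers keep pulling records until the invalid sentinel shows up."""
    for r in records:
        if r.is_valid != 1:
            break
        yield r

if __name__ == "__main__":
    work = [TileInfo(0, 0, 3, 1, 96), TileInfo(1, 0, 3, 1, 64),
            TileInfo(0, 0, -1, 0, -1)]            # last record is the sentinel
    assert len(list(consume(work))) == 2
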
tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # + # Specialized LDGSTS A/SFA warps (warps 4-7) + # These warps use LDGSTS instructions to load A and SFA from global to shared memory + # with gather/permutation capability enabled by token_id_mapping + # + if warp_idx <= self.ldgsts_a_warp_id[-1] and warp_idx >= self.ldgsts_a_warp_id[0]: + # cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Setup LDGSTS copy atoms for A and SFA + # A: 8x LDGSTS.128 per thread with swizzle_128B for A matrix (32 elements per thread) + # SFA: 4x LDGSTS.32 per thread with 512-element block swizzling for scale factor A (4 elements per thread) + # + a_atom_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + mA_mkl.element_type, + num_bits_per_copy=128, + ) + a_thread_layout = cute.make_layout((16, 8), stride=(8, 1)) + a_value_layout = cute.make_layout((1, 32), stride=(32, 1)) + a_tiled_copy = cute.make_tiled_copy_tv( + a_atom_copy, + a_thread_layout, + a_value_layout, + ) + + sfa_atom_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mSFA_mkl.element_type, + num_bits_per_copy=32, + ) + tidx_in_warpgroup = tidx % 128 + + sA_tiled = cute.make_tensor( + sA.iterator, + layout=cute.make_layout( + (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2], self.num_ab_stage), + stride=( + self.cta_tile_shape_mnk[2], + 1, + self.cta_tile_shape_mnk[0] * self.cta_tile_shape_mnk[2], + ), + ), + ) + a_thr_copy = a_tiled_copy.get_slice(tidx_in_warpgroup) + tAsA_tiled = a_thr_copy.partition_D(sA_tiled) + + a_token_offset_tensor = cute.make_rmem_tensor( + cute.make_layout((8,)), + cutlass.Int32, + ) + a_predicate_tensor = cute.make_rmem_tensor( + cute.make_layout((8,)), + cutlass.Boolean, + ) + sfa_token_offset_tensor = cute.make_rmem_tensor( + cute.make_layout((1,)), + cutlass.Int32, + ) + sfa_predicate_tensor = cute.make_rmem_tensor( + cute.make_layout((1,)), + cutlass.Boolean, + ) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + a_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # Get the first tile info + tile_info = cute.make_rmem_tensor((5,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(5, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + # Get tile coord from tile scheduler + # cur_tile_coord = work_tile.tile_idx + + # Load token IDs for gather operation + # For A matrix: each thread loads 8 token offsets (for 8 LDGSTS.128 operations) + # For SFA matrix: each thread loads 1 token offset (for 4 LDGSTS.32 operations) + gToken_ml_tile = gToken_ml[(None, tile_info[0])] + for i in range(8): + token_ml_tile_offset = (tidx_in_warpgroup // 8) + i * 16 + a_token_offset_tensor[i] = 
gToken_ml_tile[token_ml_tile_offset] + a_predicate_tensor[i] = ( + cutlass.Boolean(1) + if tile_info[0] * self.cta_tile_shape_mnk[0] + token_ml_tile_offset + < tile_info[4] + else cutlass.Boolean(0) + ) + a_token_offset_tensor[i] = ( + a_token_offset_tensor[i] // self.topk + if tile_info[0] * self.cta_tile_shape_mnk[0] + token_ml_tile_offset + < tile_info[4] + else 0 + ) + + token_ml_tile_offset = ( + 8 * (tidx_in_warpgroup // 32) + + 32 * ((tidx_in_warpgroup % 32) // 8) + + (tidx_in_warpgroup % 8) + ) + sfa_token_offset_tensor[0] = gToken_ml_tile[token_ml_tile_offset] // self.topk + sfa_predicate_tensor[0] = ( + cutlass.Boolean(1) + if tile_info[0] * self.cta_tile_shape_mnk[0] + token_ml_tile_offset + < tile_info[4] + else cutlass.Boolean(0) + ) + relative_sfa_token_offset = sfa_token_offset_tensor[0] + + tAgA = gA_mkl[(None, None, 0, None, 0)] + A_gmem_thread_offset = cute.assume((tidx_in_warpgroup % 8) * 32, divby=32) + tAgSFA = gSFA_mkl[(relative_sfa_token_offset, None, 0, None, 0)] + + tAsSFA = sSFA[ + ( + ( + ( + ( + 8 * (tidx_in_warpgroup // 32) + (tidx_in_warpgroup % 8), + (tidx_in_warpgroup % 32) // 8, + ), + None, + ), + None, + ), + None, + None, + None, + ) + ] + + # Peek (try_wait) SCALE buffer empty + a_producer_state.reset_count() + peek_a_empty_status = cutlass.Boolean(1) + if a_producer_state.count < k_tile_cnt: + peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state) + + # + # Load A and SFA with LDGSTS and gather/permutation + # Each K-tile iteration loads one K-tile of A and SFA from GMEM to SMEM + # using LDGSTS instructions with token-based gather addressing + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status) + + tAgA_ktile = tAgA[(None, None, a_producer_state.count)] + tAsA_ktile = tAsA_tiled[(None, None, None, a_producer_state.index)] + + tAgSFA_ktile = tAgSFA[(None, a_producer_state.count)] + tAsSFA_ktile = tAsSFA[ + ( + None, + None, + None, + None, + a_producer_state.index, + ) + ] + + for i in range(8): + # + # Load A matrix: 8x LDGSTS.128 per thread with swizzle_128B + # Each LDGSTS.128 loads 32 elements (128 bits) from GMEM to SMEM + # Global memory address is computed using token offset for gather operation + # Predicate mask guards against invalid token IDs (padding tokens marked as -1) + # + A_gmem_slice_offset = A_gmem_thread_offset + cute.assume( + a_token_offset_tensor[i] * tAgA_ktile.layout[0].stride, divby=32 + ) + A_gmem_slice_offset = cute.assume(A_gmem_slice_offset, divby=32) + tAgA_slice_ptr = tAgA_ktile.iterator + A_gmem_slice_offset + tAgA_slice = cute.make_tensor( + tAgA_slice_ptr, layout=cute.make_layout((32,)) + ) + + tAsA_slice = cute.make_tensor( + tAsA_ktile[(None, i, None)].iterator, layout=cute.make_layout((32,)) + ) + a_predicate_slice = cute.make_rmem_tensor( + cute.make_layout((1,)), cutlass.Boolean + ) + a_predicate_slice[0] = a_predicate_tensor[i] + + cute.copy_atom_call( + a_atom_copy, tAgA_slice, tAsA_slice, pred=a_predicate_slice + ) + + for i in range(4): + # + # Load SFA: 4x LDGSTS.32 per thread with 512-element block swizzling + # Each LDGSTS.32 loads 4 scale factor elements (32 bits) from GMEM to SMEM + # Uses same token offset as A matrix for consistent gather operation + # + swizzled_iterator = (tidx_in_warpgroup % 32) // 8 ^ i + tAgSFA_slice_ptr = tAgSFA_ktile.iterator + 4 * swizzled_iterator + tAgSFA_slice = cute.make_tensor( + tAgSFA_slice_ptr, layout=cute.make_layout((4,)) + ) + + 
tAsSFA_slice_ptr = tAsSFA_ktile.iterator + 512 * swizzled_iterator + tAsSFA_slice = cute.make_tensor(tAsSFA_slice_ptr, cute.make_layout((4,))) + + cute.copy_atom_call( + sfa_atom_copy, tAgSFA_slice, tAsSFA_slice, pred=sfa_predicate_tensor + ) + + # Signal the completion of async + if cutlass.const_expr(self.use_2cta_instrs): + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + a_pipeline.producer_commit(a_producer_state) + + # Peek (try_wait) A buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + a_producer_state.advance() + peek_a_empty_status = cutlass.Boolean(1) + if a_producer_state.count < k_tile_cnt: + peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(5, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait A pipeline buffer empty + # + a_pipeline.producer_tail(a_producer_state) + + # + # Specialized TMA B/SFB load warp (warp 9) + # This warp uses TMA instructions to load B and SFB from global to shared memory + # with multicast support to reduce L2 memory traffic + # + if warp_idx == self.tma_b_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + b_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # Get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # ((atom_v, rest_v), RestK) + # tAgSFA_slice = tAgSFA[(None, mma_tile_coord_mnl[0], None, 0)] + + # Apply SFB slicing hack when cta_tile_shape_n=64 # {$nv-internal-release} + slice_n = mma_tile_coord_mnl[1] + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + slice_n = mma_tile_coord_mnl[1] // 2 + + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[(None, slice_n, None, mma_tile_coord_mnl[2])] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + b_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if b_producer_state.count < k_tile_cnt: + peek_ab_empty_status = b_pipeline.producer_try_acquire(b_producer_state) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Conditionally wait for B buffer 
empty + b_pipeline.producer_acquire(b_producer_state, peek_ab_empty_status) + + tBgB_k = tBgB_slice[(None, b_producer_state.count)] + tBgSFB_k = tBgSFB_slice[(None, b_producer_state.count)] + tBsB_pipe = tBsB[(None, b_producer_state.index)] + tBsSFB_pipe = tBsSFB[(None, b_producer_state.index)] + + tma_bar = b_pipeline.producer_get_barrier(b_producer_state) + + # TMA load B + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + + # TMA load SFB + cute.copy( + tma_atom_sfb, + tBgSFB_k, + tBsSFB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=sfb_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + b_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if b_producer_state.count < k_tile_cnt: + peek_ab_empty_status = b_pipeline.producer_try_acquire(b_producer_state) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + # + # Wait A/B buffer empty + # + b_pipeline.producer_tail(b_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + + # Partition for S2T copy of SFA/SFB + # + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + a_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + b_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + 
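
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): how the accumulator, SFA and
# SFB regions are laid out back-to-back in tensor memory above. col_offset()
# stands in for tcgen05.find_tmem_tensor_col_offset, and the column counts in
# the assert are made-up example numbers.
# ---------------------------------------------------------------------------
def tmem_offsets(acc_cols: int, sfa_cols: int) -> dict:
    acc_base = 0                      # accumulator starts at the allocated base
    sfa_base = acc_base + acc_cols    # SFA begins right after the accumulator
    sfb_base = sfa_base + sfa_cols    # SFB begins right after SFA
    return {"acc": acc_base, "sfa": sfa_base, "sfb": sfb_base}

assert tmem_offsets(acc_cols=128, sfa_cols=4) == {"acc": 0, "sfa": 128, "sfb": 132}
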
tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # Get the first tile info from pipeline (scheduler has filtered out tiles >= num_non_exiting_tiles) + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + while is_valid_tile: + # Peek (try_wait) AB buffer full for k_tile = 0 + a_consumer_state.reset_count() + b_consumer_state.reset_count() + peek_a_full_status = cutlass.Boolean(1) + peek_b_full_status = cutlass.Boolean(1) + if a_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state) + if b_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_b_full_status = b_pipeline.consumer_try_wait(b_consumer_state) + + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Apply TMEM pointer offset hack when cta_tile_shape_n=192 or + # cta_tile_shape_n=64 # {$nv-internal-release} + + tCtSFB_mma = tCtSFB + if cutlass.const_expr(self.cta_tile_shape_mnk[1] == 192): + # If this is an ODD tile, shift the TMEM start address for + # cta_tile_shape_n=192 case by two words + # (ignores first 64 columns of SFB) + offset = ( + cutlass.Int32(2) if mma_tile_coord_mnl[1] % 2 == 1 else cutlass.Int32(0) + ) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + elif cutlass.const_expr(self.cta_tile_shape_mnk[1] == 64): + # Move in increments of 64 columns of SFB + offset = cutlass.Int32((mma_tile_coord_mnl[1] % 2) * 2) + shifted_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA) + + offset, + dtype=self.sf_dtype, + ) + tCtSFB_mma = cute.make_tensor(shifted_ptr, tCtSFB_layout) + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + # + # Mma mainloop + # + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + for k_tile in cutlass.range(k_tile_cnt): + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + + if is_leader_cta: + # Conditionally wait for AB buffer full + a_pipeline.consumer_wait(a_consumer_state, peek_a_full_status) + b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + b_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + + for 
kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + b_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB_mma[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + a_pipeline.consumer_release(a_consumer_state) + b_pipeline.consumer_release(b_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + a_consumer_state.advance() + b_consumer_state.advance() + peek_a_full_status = cutlass.Boolean(1) + if a_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_a_full_status = a_pipeline.consumer_try_wait(a_consumer_state) + + peek_b_full_status = cutlass.Boolean(1) + if b_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_b_full_status = b_pipeline.consumer_try_wait(b_consumer_state) + + # + # Async arrive accumulator buffer full(each kblock) + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1 + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized epilogue warps + # + if warp_idx <= self.epilog_warp_id[-1]: + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc_up, + tTR_rAcc_gate, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC_partitioned = None + tTR_rC = cute.make_rmem_tensor(tTR_rAcc_up.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, epi_tile, sC) + + if cutlass.const_expr(self.generate_sfc): + norm_const = norm_const_tensor[0] + # (EPI_TILE_M, EPI_TILE_N, RestM, RestN, RestL) + gSFC_mnl = cute.local_tile(mSFC_mnl, epi_tile, (None, None, None)) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, RestM, RestN, RestL) + tCgSFC_mnl = thr_copy_t2r.partition_D(gSFC_mnl) + tCgSFC_mnl = cute.filter_zeros(tCgSFC_mnl) + # (T2R, T2R_M, T2R_N) + tCrSFC = cute.make_rmem_tensor( + 
tCgSFC_mnl[(None, None, None, 0, 0, 0)].layout, self.sf_dtype + ) + tCrSFC_pvscale = cute.make_rmem_tensor_like(tCrSFC, cutlass.Float32) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + c_pipeline = None + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # Get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + num_prev_subtiles = cutlass.Int32(0) + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Get alpha for current group + # + + expert_idx = mma_tile_coord_mnl[2] + alpha_val = alpha[expert_idx] + + # + # Slice to per mma tile index + # + bSG_gC = None + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + 0, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_consumer_state.index)] + + if cutlass.const_expr(self.generate_sfc): + # (T2R, T2R_M, T2R_N, RestM, RestN) + tCgSFC_mn = tCgSFC_mnl[ + ( + None, + None, + None, + None, + None, + 0, + ) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Process accumulator subtiles with SwiGLU fusion and store to global memory + # Each iteration processes a pair of subtiles (up, gate) and computes + # up * silu(gate) + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + + for subtile_idx in cutlass.range(0, subtile_cnt, 2): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn_up = tTR_tAcc[(None, None, None, subtile_idx)] + tTR_tAcc_mn_gate = tTR_tAcc[(None, None, None, subtile_idx + 1)] + + cute.copy(tiled_copy_t2r, tTR_tAcc_mn_up, tTR_rAcc_up) + cute.copy(tiled_copy_t2r, tTR_tAcc_mn_gate, tTR_rAcc_gate) + + acc_vec_up = tTR_rAcc_up.load() + acc_vec_gate = tTR_rAcc_gate.load() + + # + # SwiGLU activation: output = up * silu(gate) + # where silu(x) = x * sigmoid(x) + # up and gate are extracted from interleaved accumulator subtiles + # + tCompute = cute.make_rmem_tensor(acc_vec_gate.shape, self.acc_dtype) + if cutlass.const_expr(self.vectorized_f32): + # SwiGLU Packed Version: uses f32x2 packed operations 
for better performance + # Computes: output = (alpha * up) * silu(alpha * gate) + # where silu(x) = x * sigmoid(x) = x / (1 + exp(-x)) + LOG2_E = cutlass.Float32(1.4426950408889634) + for i in cutlass.range_constexpr(0, cute.size(tTR_rAcc_up), 2): + acc_vec_up_alpha = cute.arch.mul_packed_f32x2( + (acc_vec_up[i], acc_vec_up[i + 1]), + (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)), + ) + acc_vec_gate_alpha = cute.arch.mul_packed_f32x2( + (acc_vec_gate[i], acc_vec_gate[i + 1]), + (cutlass.Float32(alpha_val), cutlass.Float32(alpha_val)), + ) + tCompute_log2e = cute.arch.mul_packed_f32x2( + (acc_vec_gate_alpha[0], acc_vec_gate_alpha[1]), (-LOG2_E, -LOG2_E) + ) + ( + tCompute[i], + tCompute[i + 1], + ) = cute.arch.add_packed_f32x2( + ( + cute.math.exp2(tCompute_log2e[0], fastmath=True), + cute.math.exp2(tCompute_log2e[1], fastmath=True), + ), + (1.0, 1.0), + ) + tCompute[i] = cute.arch.rcp_approx(tCompute[i]) + tCompute[i + 1] = cute.arch.rcp_approx(tCompute[i + 1]) + ( + tCompute[i], + tCompute[i + 1], + ) = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (acc_vec_gate_alpha[0], acc_vec_gate_alpha[1]), + ) + ( + tCompute[i], + tCompute[i + 1], + ) = cute.arch.mul_packed_f32x2( + (tCompute[i], tCompute[i + 1]), + (acc_vec_up_alpha[0], acc_vec_up_alpha[1]), + ) + else: + # SwiGLU Unpacked Version: scalar operations + # Computes: output = (alpha * up) * silu(alpha * gate) + for i in cutlass.range_constexpr(cute.size(tTR_rAcc_up)): + acc_vec_up_alpha = acc_vec_up[i] * cutlass.Float32(alpha_val) + acc_vec_gate_alpha = acc_vec_gate[i] * cutlass.Float32(alpha_val) + tCompute[i] = acc_vec_up_alpha * silu_f32( + acc_vec_gate_alpha, fastmath=True + ) + + if cutlass.const_expr(self.generate_sfc): + # + # Quantization path for Float4E2M1FN output: + # 1. Compute per-vector absolute max from SwiGLU result + # 2. Generate scale factor C (SFC) based on max values + # 3. Store SFC to global memory + # 4. 
Quantize output by scaling with reciprocal of SFC + # + # Assume subtile partitioned always happens on n dimension + sfc_subtile_idx_mn = ( + tile_info[0] * self.epi_tile_cnt[0], + tile_info[1] * self.epi_tile_cnt[1] + subtile_idx // 2, + ) + tCgSFC = tCgSFC_mn[ + ( + None, + None, + None, + *sfc_subtile_idx_mn, + ) + ] + + # + # Get absolute max across a vector and Compute SFC + # + tTR_rAcc_frg = cute.logical_divide( + tCompute, cute.make_layout(self.sf_vec_size) + ) + acc_frg = tTR_rAcc_frg.load() + acc_frg = epilogue_op(acc_frg) + + # Apply element-wise absolute value using math.absf (supports vectors) + abs_acc_frg_ir = math.absf(acc_frg.ir_value()) + abs_acc_frg = type(acc_frg)(abs_acc_frg_ir, acc_frg.shape, acc_frg.dtype) + + if cutlass.const_expr(self.vectorized_f32): + for vi in cutlass.range_constexpr(abs_acc_frg.shape[1]): + tCrSFC_pvscale[vi] = abs_acc_frg[None, vi].reduce( + cute.ReductionOp.MAX, + cutlass.Float32(0.0), + 0, # Use 0.0 as init for abs values + ) + for vi in cutlass.range_constexpr(0, abs_acc_frg.shape[1], 2): + tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1] = ( + cute.arch.mul_packed_f32x2( + (tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1]), + ( + self.get_dtype_rcp_limits(self.c_dtype), + self.get_dtype_rcp_limits(self.c_dtype), + ), + ) + ) + tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1] = ( + cute.arch.mul_packed_f32x2( + (tCrSFC_pvscale[vi], tCrSFC_pvscale[vi + 1]), + (norm_const, norm_const), + ) + ) + else: + for vi in cutlass.range_constexpr(abs_acc_frg.shape[1]): + tCrSFC_pvscale[vi] = ( + abs_acc_frg[None, vi].reduce( + cute.ReductionOp.MAX, + cutlass.Float32(0.0), + 0, # Use 0.0 as init for abs values + ) + * self.get_dtype_rcp_limits(self.c_dtype) + * norm_const + ) + + # TODO: need to add f32x2 -> f8x2 conversion + tCrSFC.store(tCrSFC_pvscale.load().to(self.sf_dtype)) + + # + # Store SFC to global memory + # + # TODO: Need to think about predicate on it + # if cute.elem_less(): + cute.autovec_copy(tCrSFC, tCgSFC) + + # + # Compute quantized output values and convert to C type + # + # TODO: need to add f8x2 -> f32x2 conversion + tCrSFC_qpvscale_up = tCrSFC.load().to(cutlass.Float32) + fp32_max = cutlass.Float32(3.40282346638528859812e38) + if cutlass.const_expr(self.vectorized_f32): + for vi in cutlass.range_constexpr(0, cute.size(tCrSFC), 2): + acc_scale = cute.arch.mul_packed_f32x2( + ( + cute.arch.rcp_approx(tCrSFC_qpvscale_up[vi]), + cute.arch.rcp_approx(tCrSFC_qpvscale_up[vi + 1]), + ), + (norm_const, norm_const), + ) + acc_scale_min0 = fmin(acc_scale[0], fp32_max, nan=True) + acc_scale_min1 = fmin(acc_scale[1], fp32_max, nan=True) + + vec0 = tTR_rAcc_frg[None, vi] + vec1 = tTR_rAcc_frg[None, vi + 1] + for ei in cutlass.range_constexpr(self.sf_vec_size): + vec0[ei], vec1[ei] = cute.arch.mul_packed_f32x2( + (vec0[ei], vec1[ei]), + (acc_scale_min0, acc_scale_min1), + ) + else: + for vi in cutlass.range_constexpr(cute.size(tCrSFC)): + # TODO:Need to add E8M0 rcp approximation + acc_scale = norm_const * cute.arch.rcp_approx( + tCrSFC_qpvscale_up[vi] + ) + acc_scale = fmin(acc_scale, fp32_max, nan=True) + + vec = tTR_rAcc_frg[None, vi] + for ei in cutlass.range_constexpr(self.sf_vec_size): + vec[ei] = vec[ei] * acc_scale + + acc_vec = tiled_copy_r2s.retile(tCompute).load() + tRS_rC.store(acc_vec.to(self.c_dtype)) + else: + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tCompute).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + num_prev_subtiles = num_prev_subtiles + 1 + 
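
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this patch): the per-vector quantization
# done in the generate_sfc path above, reduced to scalar Python. It skips the
# rounding of the scale factor to sf_dtype and the fp32 clamp; rcp_limit is
# 1 / max|c_dtype| (1/6 for Float4E2M1FN) and the sample values are made up.
# ---------------------------------------------------------------------------
def quantize_vector(vec, norm_const, rcp_limit=1.0 / 6.0):
    amax = max(abs(v) for v in vec)            # per-vector absolute maximum
    sfc = amax * rcp_limit * norm_const        # scale factor stored to c_sf
    scale = norm_const / sfc                   # kernel: norm_const * rcp_approx(sfc)
    return sfc, [v * scale for v in vec]       # values now fit the c_dtype range

sfc, q = quantize_vector([0.5, -3.0, 1.25, 2.0], norm_const=1.0)
assert sfc == 0.5 and max(abs(x) for x in q) == 6.0   # 6.0 = Float4E2M1FN max
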
c_buffer = (num_prev_subtiles + subtile_idx // 2) % self.num_c_stage + + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx // 2)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory + (source) and register array (destination). 
+ + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc_up, tTR_rAcc_gate) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc_up: The partitioned accumulator tensor for acc up + - tTR_rAcc_gate: The partitioned accumulator tensor for acc gate + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + + # (T2R, T2R_M, T2R_N) + tTR_rAcc_up = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + # (T2R, T2R_M, T2R_N) + tTR_rAcc_gate = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc_up, tTR_rAcc_gate + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register + array (source) and shared memory (destination). 
+ + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing : + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + num_smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. 
+ :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout of operand C. + :type c_layout: utils.LayoutEnum + :param sf_dtype: Data type of scale factor. + :type sf_dtype: type[cutlass.Numeric] + :param sf_vec_size: Vector size of scale factor. + :type sf_vec_size: int + :param num_smem_capacity: Total available shared memory capacity in bytes. + :type num_smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default C stages + num_c_stage = 2 + + # Default Tile info stages + num_tile_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + # 1024B alignment + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + # cute.printf("num_smem_capacity: {}, occupancy: {}, " + # "mbar_helpers_bytes: {}, c_bytes: {}", + # num_smem_capacity, occupancy, mbar_helpers_bytes, c_bytes) + # cute.printf("ab_bytes_per_stage: {}", ab_bytes_per_stage) + num_ab_stage = ( + num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + num_smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage, num_tile_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the 
output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_tma_atom_kind( + atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean + ) -> Union[cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. + + :param atom_sm_cnt: The number of SMs + :type atom_sm_cnt: cutlass.Int32 + :param mcast: The multicast flag + :type mcast: cutlass.Boolean + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + """ + if atom_sm_cnt == 2 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + + raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") + + @staticmethod + def get_dtype_rcp_limits(dtype: Type[cutlass.Numeric]) -> float: + """ + Calculates the reciprocal of the maximum absolute value for a given data type. 
+ + :param dtype: Data type + :type dtype: Type[cutlass.Numeric] + + :return: An float representing the reciprocal of the maximum absolute value + :rtype: float + """ + if dtype == cutlass.Float4E2M1FN: + return 1 / 6.0 + if dtype == cutlass.Float8E4M3FN: + return 1 / 448.0 + if dtype == cutlass.Float8E5M2: + return 1 / 128.0 + return 1.0 + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float4E2M1FN, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: + is_valid = False + + if acc_dtype not in {cutlass.Float32}: + is_valid = False + # Check valid c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + cutlass.Float4E2M1FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_layouts( + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if layouts and dtypes are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major dimension of the A tensor + :type a_major: str + :param b_major: The major dimension of the B tensor + :type b_major: str + :param c_major: The major dimension of the C tensor + :type c_major: str + + :return: True if the layouts are valid, False otherwise + :rtype: bool + """ + is_valid = True + + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + # TODO: Currently we don't support m major output for Float4E2M1FN, + # Need to support it in the future. 
+ if c_dtype is cutlass.Float4E2M1FN and c_major == "m": + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m_aligned: int, + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m_aligned: The alignment requirement for group M dimension (default: 128) + :type m_aligned: int + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + # Skip invalid mma tile n + # Needs to have even iterations with Epi Tile N 64 for swiGeLU fusion + if mma_tiler_mn[1] not in (128, 256): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. + or cluster_shape_mn[0] > 4 + or cluster_shape_mn[1] > 4 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + cluster_tiler_m = (cluster_shape_mn[0] // (2 if use_2cta_instrs else 1)) * mma_tiler_mn[0] + + # Skip invalid cluster tiler shape since contiguous layout can't handle oob access + # The contiguous layout means the aligned data is stored in a contiguous manner. + # It can't handle runtime oob when alignment is not align with the tile_M, + # since the problem shape of TMA store can't be changed at runtime. 
+ if cluster_tiler_m not in [64, 128, 256]: + is_valid = False + + # Check if m_aligned is a multiple of cluster_tiler_m + # This ensures that each group's M dimension (which is a multiple of m_aligned) + # won't be split across tiles, preventing a single tile from loading data + # from multiple groups (which would access wrong B matrix data) + if m_aligned % mma_tiler_mn[0] != 0: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, # noqa: E741 + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, # noqa: E741 + a_major: str, + b_major: str, + c_major: str, + m_aligned: int, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type 
+    @staticmethod
+    def can_implement(
+        ab_dtype: Type[cutlass.Numeric],
+        sf_dtype: Type[cutlass.Numeric],
+        sf_vec_size: int,
+        acc_dtype: Type[cutlass.Numeric],
+        c_dtype: Type[cutlass.Numeric],
+        mma_tiler_mn: Tuple[int, int],
+        cluster_shape_mn: Tuple[int, int],
+        m: int,
+        n: int,
+        k: int,
+        l: int,  # noqa: E741
+        a_major: str,
+        b_major: str,
+        c_major: str,
+        m_aligned: int,
+    ) -> bool:
+        """
+        Check if the gemm can be implemented.
+
+        :param ab_dtype: The data type of the A and B operands
+        :type ab_dtype: Type[cutlass.Numeric]
+        :param sf_dtype: The data type of the scale factor
+        :type sf_dtype: Type[cutlass.Numeric]
+        :param sf_vec_size: The vector size of the scale factor
+        :type sf_vec_size: int
+        :param acc_dtype: The data type of the accumulator
+        :type acc_dtype: Type[cutlass.Numeric]
+        :param c_dtype: The data type of the output tensor
+        :type c_dtype: Type[cutlass.Numeric]
+        :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler
+        :type mma_tiler_mn: Tuple[int, int]
+        :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster
+        :type cluster_shape_mn: Tuple[int, int]
+        :param m: The number of rows in the A tensor
+        :type m: int
+        :param n: The number of columns in the B tensor
+        :type n: int
+        :param k: The number of columns in the A tensor
+        :type k: int
+        :param l: The batch (L) dimension of the A, B, and C tensors
+        :type l: int
+        :param a_major: The major axis of the A tensor
+        :type a_major: str
+        :param b_major: The major axis of the B tensor
+        :type b_major: str
+        :param c_major: The major axis of the C tensor
+        :type c_major: str
+        :param m_aligned: The alignment requirement for the group M dimension
+        :type m_aligned: int
+
+        :return: True if the gemm can be implemented, False otherwise
+        :rtype: bool
+        """
+        can_implement = True
+        # Skip unsupported types
+        if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_dtypes_and_scale_factor_vec_size(
+            ab_dtype, sf_dtype, sf_vec_size, acc_dtype, c_dtype
+        ):
+            can_implement = False
+
+        # Skip unsupported layouts
+        if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_layouts(
+            ab_dtype, c_dtype, a_major, b_major, c_major
+        ):
+            can_implement = False
+
+        # use_2cta_instrs is derived from the tiler rather than passed in
+        use_2cta_instrs = mma_tiler_mn[0] == 256
+        # Skip invalid mma tile shape and cluster shape
+        if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_mma_tiler_and_cluster_shape(
+            use_2cta_instrs, mma_tiler_mn, cluster_shape_mn, m_aligned
+        ):
+            can_implement = False
+        # Skip illegal problem shape for load/store alignment
+        if not BlockScaledContiguousGatherGroupedGemmKernel.is_valid_tensor_alignment(
+            m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major
+        ):
+            can_implement = False
+        # Skip unsupported A/B layout: the gather kernel requires K-major A and B
+        if not (a_major == "k" and b_major == "k"):
+            can_implement = False
+        return can_implement
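The m_aligned divisibility rule that can_implement delegates to is_valid_mma_tiler_and_cluster_shape can be reproduced with plain arithmetic. A minimal sketch (tile_can_straddle_groups is a hypothetical helper name; same reasoning as the kernel's check): group boundaries fall on multiples of m_aligned and tiles start on multiples of the tiler M, so a tile can straddle two groups exactly when the tiler M does not divide m_aligned.

def tile_can_straddle_groups(m_aligned: int, mma_tiler_m: int) -> bool:
    # A boundary at a multiple of m_aligned is tile-aligned only if
    # mma_tiler_m divides m_aligned; otherwise some tile covers rows
    # from two different groups (and would read the wrong B matrix).
    return m_aligned % mma_tiler_m != 0

print(tile_can_straddle_groups(128, 128))  # False: every boundary is tile-aligned
print(tile_can_straddle_groups(256, 128))  # False: boundaries are also tile-aligned
print(tile_can_straddle_groups(128, 256))  # True: a 256-row tile spans two groups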
+    @cute.jit
+    def wrapper(
+        self,
+        a_ptr: cute.Pointer,
+        b_ptr: cute.Pointer,
+        a_sf_ptr: cute.Pointer,
+        b_sf_ptr: cute.Pointer,
+        c_ptr: cute.Pointer,
+        c_sf_ptr: cute.Pointer,
+        alpha_ptr: cute.Pointer,
+        tile_idx_to_group_idx_ptr: cute.Pointer,
+        tile_idx_to_mn_limit_ptr: cute.Pointer,
+        token_id_mapping_ptr: cute.Pointer,
+        num_non_exiting_tiles_ptr: cute.Pointer,
+        global_sf_ptr: cute.Pointer,
+        orig_m: int,
+        m: int,
+        n: int,
+        k: int,
+        l: int,  # noqa: E741
+        tile_size: cutlass.Constexpr,
+        scaling_vector_size: cutlass.Constexpr,
+        max_active_clusters: cutlass.Constexpr,
+        stream: cuda.CUstream,
+        epilogue_op: cutlass.Constexpr = lambda x: x,
+    ):
+        scale_k = k // scaling_vector_size
+        # SwiGLU halves the GEMM N extent: the kernel writes interm_size columns
+        interm_size = n // 2
+        num_tiles = m // tile_size
+        # A and SFA keep the original (non-permuted) row count, orig_m
+        a = cute.make_tensor(
+            a_ptr, layout=cute.make_ordered_layout((orig_m, k, 1), order=(1, 0, 2))
+        )
+        b = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout((n, k, l), order=(1, 0, 2)))
+        a_sf = cute.make_tensor(
+            a_sf_ptr, layout=cute.make_ordered_layout((orig_m, scale_k, 1), order=(1, 0, 2))
+        )
+        # Scale factors use the swizzled (32, 4, M/128, 4, K/(4*vec), L) block layout
+        b_sf = cute.make_tensor(
+            b_sf_ptr,
+            layout=cute.make_ordered_layout(
+                (32, 4, n // 128, 4, scale_k // 4, l), order=(2, 1, 4, 0, 3, 5)
+            ),
+        )
+        c = cute.make_tensor(
+            c_ptr, layout=cute.make_ordered_layout((m, interm_size, 1), order=(1, 0, 2))
+        )
+        c_sf = cute.make_tensor(
+            c_sf_ptr,
+            layout=cute.make_ordered_layout(
+                (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), l),
+                order=(2, 1, 4, 0, 3, 5),
+            ),
+        )
+        alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,)))
+
+        tile_idx_to_group_idx = cute.make_tensor(
+            tile_idx_to_group_idx_ptr, layout=cute.make_layout((num_tiles,))
+        )
+        tile_idx_to_mn_limit = cute.make_tensor(
+            tile_idx_to_mn_limit_ptr, layout=cute.make_layout((num_tiles,))
+        )
+        token_id_mapping = cute.make_tensor(token_id_mapping_ptr, layout=cute.make_layout((m,)))
+        num_non_exiting_tiles = cute.make_tensor(
+            num_non_exiting_tiles_ptr, layout=cute.make_layout((1,))
+        )
+        global_sf = cute.make_tensor(global_sf_ptr, layout=cute.make_layout((1,)))
+
+        return self(
+            a,
+            b,
+            c,
+            a_sf,
+            b_sf,
+            c_sf,
+            global_sf,
+            tile_idx_to_group_idx,
+            tile_idx_to_mn_limit,
+            token_id_mapping,
+            num_non_exiting_tiles,
+            alpha,
+            max_active_clusters=max_active_clusters,
+            stream=stream,
+            epilogue_op=epilogue_op,
+        )
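As a sanity check on the swizzled scale-factor layouts built in wrapper: the ordered shape (32, 4, n // 128, 4, scale_k // 4, l) is a pure permutation of the n × scale_k × l scale elements, valid when n is a multiple of 128 and scale_k a multiple of 4. A quick arithmetic verification with illustrative sizes (not taken from the kernel):

# Illustrative sizes: n = 2 * interm_size for the SwiGLU-fused GEMM.
n, k, l, scaling_vector_size = 16384, 4096, 1, 16
scale_k = k // scaling_vector_size
swizzled_shape = (32, 4, n // 128, 4, scale_k // 4, l)
num_elements = 1
for extent in swizzled_shape:
    num_elements *= extent
# 32 * 4 * (n // 128) recovers n, and 4 * (scale_k // 4) recovers scale_k,
# so the swizzle reorders elements without adding or dropping any.
assert num_elements == n * scale_k * l
print(num_elements)  # 4194304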
diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py
index 5877a31132f..009eb2f7304 100644
--- a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py
+++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/custom_pipeline.py
@@ -48,8 +48,8 @@
 import cutlass.cute as cute
 from cutlass.cutlass_dsl import Boolean, if_generate
 
-from cutlass.pipeline import (CooperativeGroup, PipelineAsync, PipelineOp,
-                              PipelineState)
+from cutlass.pipeline import (Agent, CooperativeGroup, PipelineAsync,
+                              PipelineOp, PipelineState, agent_sync)
 
 
 def pipeline_init_wait(cta_layout_vmnk: Optional[cute.Layout] = None):
@@ -374,3 +374,153 @@ def then_body():
             self.producer_acquire(state)
 
         if_generate(is_leader_cta, then_body)
+
+
+@dataclass(frozen=True)
+class PipelineCpAsyncUmma(PipelineAsync):
+    """
+    PipelineCpAsyncUmma is used for LDGSTS (CpAsync) producers and UMMA consumers.
+
+    This pipeline is specifically designed for scenarios where:
+    - Producers use LDGSTS instructions (cp.async) to load data from global to shared memory
+    - Consumers are UMMA warps that perform MMA operations using the loaded data
+
+    Key differences from PipelineAsyncUmma:
+    - Suitable for gather/permutation operations during load
+    - Used in this kernel for A and SFA matrices with token-based gather addressing
+    """
+
+    cta_group: cute.nvgpu.tcgen05.CtaGroup
+
+    @staticmethod
+    def _compute_leading_cta_rank(cta_v_size):
+        """
+        Computes the leading CTA rank.
+        """
+        cta_rank_in_cluster = cute.arch.make_warp_uniform(
+            cute.arch.block_idx_in_cluster())
+        return cta_rank_in_cluster // cta_v_size * cta_v_size
+
+    @staticmethod
+    def _compute_is_leader_cta(cta_layout_vmnk: cute.Layout):
+        """
+        Computes leader threadblocks for 2CTA kernels. For 1CTA, all threadblocks are leaders.
+        """
+        bidx, bidy, _ = cute.arch.block_idx()
+        mma_coord_vmnk = (
+            bidx % cute.size(cta_layout_vmnk, mode=[0]),
+            bidx // cute.size(cta_layout_vmnk, mode=[0]),
+            bidy,
+            None,
+        )
+        return mma_coord_vmnk[0] == 0
+
+    @staticmethod
+    def _compute_peer_cta_mask(cta_layout_vmnk: cute.Layout):
+        """
+        Computes a mask for signaling arrivals to multicasting threadblocks.
+        """
+        cta_rank_in_cluster = cute.arch.make_warp_uniform(
+            cute.arch.block_idx_in_cluster())
+        cta_in_cluster_coord_vmnk = cta_layout_vmnk.get_flat_coord(
+            cta_rank_in_cluster)
+        mask_self = cute.nvgpu.cpasync.create_tma_multicast_mask(
+            cta_layout_vmnk, cta_in_cluster_coord_vmnk, mcast_mode=0)
+        block_in_cluster_coord_vmnk_peer = (
+            cta_in_cluster_coord_vmnk[0] ^ 1,
+            *cta_in_cluster_coord_vmnk[1:],
+        )
+        mask_peer = cute.nvgpu.cpasync.create_tma_multicast_mask(
+            cta_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=0)
+        return mask_self | mask_peer
+
+    @staticmethod
+    def create(
+        *,
+        num_stages: int,
+        producer_group: CooperativeGroup,
+        consumer_group: CooperativeGroup,
+        barrier_storage: cute.Pointer = None,
+        cta_layout_vmnk: Optional[cute.Layout] = None,
+        defer_sync: bool = False,
+        enable_cp_async: bool = False,
+    ):
+        """Creates and initializes a new PipelineCpAsyncUmma instance.
+
+        :param num_stages: Number of buffer stages for this pipeline
+        :type num_stages: int
+        :param producer_group: CooperativeGroup for the producer agent
+        :type producer_group: CooperativeGroup
+        :param consumer_group: CooperativeGroup for the consumer agent
+        :type consumer_group: CooperativeGroup
+        :param barrier_storage: Pointer to the shared memory address for this pipeline's mbarriers
+        :type barrier_storage: cute.Pointer, optional
+        :param cta_layout_vmnk: Layout of the cluster shape
+        :type cta_layout_vmnk: cute.Layout, optional
+        :param defer_sync: Whether to defer the sync
+        :type defer_sync: bool, optional
+        :param enable_cp_async: Whether to enable cp.async instructions
+        :type enable_cp_async: bool, optional
+        :raises ValueError: If barrier_storage is not a cute.Pointer instance
+        :return: A new PipelineCpAsyncUmma instance configured with the provided parameters
+        :rtype: PipelineCpAsyncUmma
+        """
+        if not isinstance(barrier_storage, cute.Pointer):
+            raise ValueError(
+                f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}"
+            )
+
+        producer_type = PipelineOp.AsyncLoad if enable_cp_async else PipelineOp.AsyncThread
+        consumer_type = PipelineOp.TCGen05Mma
+
+        producer = (producer_type, producer_group)
+        consumer = (consumer_type, consumer_group)
+
+        sync_object_full = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8),
+            num_stages,
+            producer,
+        )
+        sync_object_empty = PipelineAsync._make_sync_object(
+            barrier_storage.align(min_align=8) + num_stages, num_stages,
+            consumer)
+
+        cta_v_size = cute.size(cta_layout_vmnk,
+                               mode=[0]) if cta_layout_vmnk is not None else 1
+        cta_group = (cute.nvgpu.tcgen05.CtaGroup.ONE if cta_layout_vmnk is None
+                     or cute.size(cta_layout_vmnk, mode=[0]) == 1 else
+                     cute.nvgpu.tcgen05.CtaGroup.TWO)
+        if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1:
+            # No mcast mask if we're not using 2CTA tcgen05 MMA
+            producer_mask = None
+            consumer_mask = None
+        else:
+            # With 2CTA UMMAs, the producer arrives at the mbarrier on the
+            # leading CTA, so we need the target CTA rank
+            producer_mask = PipelineCpAsyncUmma._compute_leading_cta_rank(
+                cta_v_size)
+            # The consumer needs a multicast mask to signal both CTAs of the pair
+            consumer_mask = PipelineCpAsyncUmma._compute_peer_cta_mask(
+                cta_layout_vmnk)
+
+        if not defer_sync:
+            if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1:
+                agent_sync(Agent.ThreadBlock)
+            else:
+                agent_sync(Agent.ThreadBlockCluster, is_relaxed=True)
+
+        return PipelineCpAsyncUmma(
+            sync_object_full,
+            sync_object_empty,
+            num_stages,
+            producer_mask,
+            consumer_mask,
+            cta_group,
+        )
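Conceptually, create wires up two rings of mbarriers: "full" barriers that the producer arrives on and the consumer waits on, and "empty" barriers in the opposite direction. The toy model below mimics that handshake on the CPU, with Python threads standing in for the LDGSTS producer and UMMA consumer; it illustrates the protocol only and is not the GPU implementation.

import threading

NUM_STAGES = 2
full = [threading.Semaphore(0) for _ in range(NUM_STAGES)]   # producer -> consumer
empty = [threading.Semaphore(1) for _ in range(NUM_STAGES)]  # consumer -> producer
buffers = [None] * NUM_STAGES

def producer(num_items):
    for i in range(num_items):
        stage = i % NUM_STAGES
        empty[stage].acquire()   # producer_acquire: wait until the stage is free
        buffers[stage] = i       # stand-in for the cp.async fill
        full[stage].release()    # producer_commit: mark the stage full

def consumer(num_items):
    total = 0
    for i in range(num_items):
        stage = i % NUM_STAGES
        full[stage].acquire()    # consumer_wait: wait until the stage is full
        total += buffers[stage]  # stand-in for the UMMA consuming the stage
        empty[stage].release()   # consumer_release: mark the stage empty
    print(total)                 # 28 for num_items = 8

t = threading.Thread(target=producer, args=(8,))
t.start()
consumer(8)
t.join()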
+ """ + self.sync_object_empty.arrive(state.index, self.consumer_mask, + self.cta_group) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py index a087a4c87a8..0ecd3e3e856 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py @@ -273,22 +273,16 @@ def run_moe_nvfp4( local_num_experts=self.expert_size_per_partition, tile_tokens_dim=tile_size, ) - x, x_sf = torch.ops.trtllm.moe_permute( - input=x.view(torch.float4_e2m1fn_x2), - input_sf=x_sf, - tile_idx_to_mn_limit=tile_idx_to_mn_limit, - permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, - num_non_exiting_tiles=num_non_exiting_tiles, - tile_tokens_dim=tile_size, - top_k=self.routing_method.experts_per_token, - ) - x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell( + + x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell( input=x.view(torch.float4_e2m1fn_x2), weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2), input_scale=x_sf.view(torch.uint8), weight_scale=self.quant_scales.fc1_weight_block.view(torch.uint8), alpha=self.quant_scales.fc1_global, tile_idx_to_group_idx=tile_idx_to_expert_idx, + tile_idx_to_mn_limit=tile_idx_to_mn_limit, + permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx, num_non_exiting_tiles=num_non_exiting_tiles, global_sf=self.fc2_input_scale, num_experts=self.num_slots, diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py index 5f77a4c7a1b..dac655b1c33 100644 --- a/tensorrt_llm/_torch/utils.py +++ b/tensorrt_llm/_torch/utils.py @@ -291,6 +291,15 @@ def fp4_scale_infer_shape(input_shapes: List[List[int]]): return scale_shape * 2 +def fp4_unswizzled_scale_infer_shape(input_shapes: List[List[int]]): + """Calculate the dimensions of the fp4 scale tensor. 
+ """ + out_shape, scale_shape = fp4_utils.get_fp4_shape(input_shapes[0], + sf_vec_size=16, + is_swizzled_layout=False) + return scale_shape * 2 + + _enable_piecewise_cuda_graph = True diff --git a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py index 4faec5d6f13..f146573ff03 100644 --- a/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py +++ b/tests/unittest/_torch/thop/parallel/test_cute_dsl_moe.py @@ -1,7 +1,11 @@ import pytest import torch +from utils.util import check_accuracy -from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import GroupedGemmInputsHelper +from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import ( + GatherGroupedGemmInputsHelper, + GroupedGemmInputsHelper, +) from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import cute_dsl_nvfp4_grouped_gemm_ref from tensorrt_llm._torch.modules.fused_moe.quantization import interleave_linear_and_gate from tensorrt_llm._torch.utils import swizzle_sf, unswizzle_sf @@ -707,3 +711,204 @@ def test_nvfp4_grouped_gemm_swiglu_blackwell( c_sf[:num_sf_elements] == c_sf_ref[:num_sf_elements] ).sum().item() / num_sf_elements assert match_ratio > 0.95 + + +@pytest.mark.skipif( + get_sm_version() not in (100, 103), + reason="This test is only supported on SM 100 and SM 103 GPUs", +) +@pytest.mark.parametrize("tile_size", [128, 256]) +@pytest.mark.parametrize("ep_size", [1, 8, 32]) +@pytest.mark.parametrize("top_k", [1, 2, 8]) +@pytest.mark.parametrize("num_tokens", [128, 515, 1024, 8192]) +def test_nvfp4_gather_grouped_gemm_swiglu_blackwell( + num_tokens: int, top_k: int, ep_size: int, tile_size: int +): + """Test gather-based grouped GEMM with SwiGLU fusion. + + This test validates the gather kernel which: + 1. Uses LDGSTS for A/SFA loading with permuted_idx_to_expanded_idx + 2. Performs GEMM with interleaved weights + 3. Applies SwiGLU activation fusion + 4. 
+
+
+@pytest.mark.skipif(
+    get_sm_version() not in (100, 103),
+    reason="This test is only supported on SM 100 and SM 103 GPUs",
+)
+@pytest.mark.parametrize("tile_size", [128, 256])
+@pytest.mark.parametrize("ep_size", [1, 8, 32])
+@pytest.mark.parametrize("top_k", [1, 2, 8])
+@pytest.mark.parametrize("num_tokens", [128, 515, 1024, 8192])
+def test_nvfp4_gather_grouped_gemm_swiglu_blackwell(
+    num_tokens: int, top_k: int, ep_size: int, tile_size: int
+):
+    """Test gather-based grouped GEMM with SwiGLU fusion.
+
+    This test validates the gather kernel which:
+    1. Uses LDGSTS for A/SFA loading with permuted_idx_to_expanded_idx
+    2. Performs GEMM with interleaved weights
+    3. Applies SwiGLU activation fusion
+    4. Quantizes output to FP4 with scale factor generation
+    """
+    sf_vec_size = 16
+    hidden_size = 4096
+    interm_size = 8192
+    num_experts = 256
+    num_local_experts = num_experts // ep_size
+
+    # Generate routing information
+    routing_logits = torch.randn(num_tokens, num_experts, device="cuda")
+    _, token_selected_experts = routing_logits.topk(top_k, dim=-1)
+    token_selected_experts = token_selected_experts.to(torch.int32)
+    num_tokens_per_expert = torch.bincount(token_selected_experts.flatten(), minlength=num_experts)
+    num_tokens_per_expert = num_tokens_per_expert[:num_local_experts]
+    # Ensure at least one valid token
+    if num_tokens_per_expert.sum().item() == 0:
+        num_tokens_per_expert[0] = 1
+    num_tiles_per_expert = (num_tokens_per_expert + tile_size - 1) // tile_size
+    num_tokens_per_expert = num_tokens_per_expert.cpu()
+    num_tiles_per_expert = num_tiles_per_expert.cpu()
+    num_valid_tiles = num_tiles_per_expert.sum().item()
+    num_valid_permuted_tokens = num_valid_tiles * tile_size
+
+    # Create helper
+    helper = GatherGroupedGemmInputsHelper(num_experts, top_k, num_local_experts, 0, tile_size)
+    max_num_tiles = helper.get_max_num_tiles(num_tokens)
+    max_num_permuted_tokens = helper.get_max_num_permuted_tokens(num_tokens)
+    assert 0 <= num_valid_tiles <= max_num_tiles
+    assert 0 <= num_valid_permuted_tokens <= max_num_permuted_tokens
+
+    # Generate tile metadata
+    num_non_exiting_tiles = torch.tensor([num_valid_tiles], dtype=torch.int32, device="cuda")
+    tile_idx_to_group_idx = torch.empty(max_num_tiles, dtype=torch.int32)
+    tile_idx_to_mn_limit = torch.empty(max_num_tiles, dtype=torch.int32)
+    tile_idx_to_group_idx.fill_(int(-2e9))
+    tile_idx_to_mn_limit.fill_(int(-2e9))
+
+    tile_idx_to_group_idx_list = helper.generate_tile_idx_to_group_idx(
+        num_tokens_per_expert.tolist()
+    )
+    tile_idx_to_mn_limit_list = helper.generate_tile_idx_to_mn_limit(num_tokens_per_expert.tolist())
+
+    for idx, (group_idx, mn_limit) in enumerate(
+        zip(tile_idx_to_group_idx_list, tile_idx_to_mn_limit_list)
+    ):
+        tile_idx_to_group_idx[idx] = group_idx
+        tile_idx_to_mn_limit[idx] = mn_limit
+
+    tile_idx_to_group_idx = tile_idx_to_group_idx.cuda()
+    tile_idx_to_mn_limit = tile_idx_to_mn_limit.cuda()
+
+    # Generate permuted_idx_to_expanded_idx for gather operation
+    permuted_idx_to_expanded_idx_list = helper.generate_permuted_idx_to_expanded_idx(
+        num_tokens, num_tokens_per_expert.tolist(), max_num_permuted_tokens
+    )
+    permuted_idx_to_expanded_idx = torch.tensor(
+        permuted_idx_to_expanded_idx_list, dtype=torch.int32, device="cuda"
+    )
+    assert permuted_idx_to_expanded_idx.size(0) == max_num_permuted_tokens
+
+    # Create input tensors (original size, not permuted)
+    a = torch.randint(-5, 5, (num_tokens, hidden_size), dtype=torch.int32, device="cuda").to(
+        torch.bfloat16
+    )
+    b = torch.randint(
+        -5,
+        5,
+        (num_local_experts, interm_size * 2, hidden_size),
+        dtype=torch.int32,
+        device="cuda",
+    ).to(torch.bfloat16)
+
+    # Quantize inputs to FP4
+    a_global_sf = a.abs().max().float() / (448 * 6)
+    b_global_sf = b.abs().amax(dim=(1, 2)).float() / (448 * 6)
+    a, a_sf = torch.ops.trtllm.fp4_quantize(a, 1 / a_global_sf, sf_vec_size, False)
+    a = a.view(torch.float4_e2m1fn_x2)
+    a_sf_unswizzled = unswizzle_sf(a_sf, (num_tokens + 127) // 128 * 128, hidden_size)[:num_tokens]
+    b, b_sf = torch.ops.trtllm.fp4_quantize(b, 1 / b_global_sf, sf_vec_size, False)
+    b = b.view(torch.float4_e2m1fn_x2)
+    b_sf = b_sf.view(num_local_experts, interm_size * 2, hidden_size // sf_vec_size)
+    alpha = a_global_sf * b_global_sf
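The 448 * 6 divisor in the global scales above follows from the NVFP4 recipe: a stored value is fp4_value × fp8_block_scale × global_sf, FP4 E2M1 tops out at magnitude 6, and the FP8 E4M3 block scale at 448, so one global_sf unit covers magnitudes up to 2688. A small self-contained illustration:

import torch

FP4_MAX = 6.0          # largest magnitude representable in FP4 E2M1
FP8_SCALE_MAX = 448.0  # largest normal value of the FP8 E4M3 block scale

a = torch.randn(128, 4096)
a_global_sf = a.abs().max().float() / (FP8_SCALE_MAX * FP4_MAX)
# The quantizer receives 1 / a_global_sf so that block scales and FP4
# mantissas both land in range for the tensor's observed amax.
assert a_global_sf.item() > 0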
+
+    # Interleave weights for SwiGLU
+    b_interleaved = interleave_linear_and_gate(b.view(torch.uint8), group_size=64, dim=1).view(
+        torch.float4_e2m1fn_x2
+    )
+    b_sf_unswizzled = unswizzle_sf(b_sf, interm_size * 2, hidden_size).view(
+        num_local_experts, interm_size * 2, hidden_size // sf_vec_size
+    )
+    b_sf_unswizzled_interleaved = interleave_linear_and_gate(b_sf_unswizzled, group_size=64, dim=1)
+    b_sf_interleaved = swizzle_sf(b_sf_unswizzled_interleaved, interm_size * 2, hidden_size).view(
+        num_local_experts, interm_size * 2, hidden_size // sf_vec_size
+    )
+
+    # Compute reference: manually gather, compute GEMM, apply SwiGLU, then quantize
+    a_gathered = torch.empty(
+        max_num_permuted_tokens, hidden_size // 2, dtype=a.dtype, device=a.device
+    )
+    a_sf_gathered = torch.empty(
+        max_num_permuted_tokens, hidden_size // sf_vec_size, dtype=a_sf.dtype, device=a_sf.device
+    )
+    for i in range(num_valid_permuted_tokens):
+        expanded_idx = permuted_idx_to_expanded_idx[i].item()
+        if expanded_idx != helper.pad_val:
+            token_id = expanded_idx // top_k
+            a_gathered[i] = a[token_id]
+            a_sf_gathered[i] = a_sf_unswizzled[token_id]
+
+    # Swizzle a_sf_gathered for reference GEMM
+    a_sf_gathered_swizzled = swizzle_sf(
+        a_sf_gathered.view(max_num_permuted_tokens, hidden_size // sf_vec_size),
+        max_num_permuted_tokens,
+        hidden_size,
+    )
+
+    c_ref = cute_dsl_nvfp4_grouped_gemm_ref(
+        a_gathered,
+        b,
+        a_sf_gathered_swizzled,
+        b_sf,
+        alpha,
+        tile_idx_to_group_idx,
+        num_non_exiting_tiles,
+        tile_size=tile_size,
+        output_dtype=torch.bfloat16,
+        scaling_vector_size=sf_vec_size,
+    )
+    c_ref = swiglu_ref(c_ref)
+    global_sf = c_ref[:num_valid_permuted_tokens].abs().max().float() / (448 * 6)
+    c_ref, c_sf_ref = torch.ops.trtllm.fp4_quantize(c_ref, 1 / global_sf, sf_vec_size, False)
+
+    # Call gather kernel
+    c, c_sf = torch.ops.trtllm.cute_dsl_nvfp4_gather_grouped_gemm_swiglu_blackwell(
+        a,
+        b_interleaved,
+        a_sf_unswizzled,
+        b_sf_interleaved,
+        alpha,
+        tile_idx_to_group_idx,
+        tile_idx_to_mn_limit,
+        permuted_idx_to_expanded_idx,
+        num_non_exiting_tiles,
+        torch.tensor([1 / global_sf], dtype=torch.float32, device="cuda"),
+        num_experts=num_experts,
+        top_k=top_k,
+        num_local_experts=num_local_experts,
+        local_expert_offset=0,
+        tile_size=tile_size,
+        scaling_vector_size=sf_vec_size,
+    )
+
+    # Verify output: only compare valid tokens, skipping padding slots where
+    # permuted_idx_to_expanded_idx == helper.pad_val
+    # Create mask for valid tokens
+    valid_token_mask = torch.zeros(num_valid_permuted_tokens, dtype=torch.bool, device="cuda")
+    for i in range(num_valid_permuted_tokens):
+        if permuted_idx_to_expanded_idx[i].item() != helper.pad_val:
+            valid_token_mask[i] = True
+
+    num_valid_tokens = valid_token_mask.sum().item()
+    if num_valid_tokens > 0:
+        # Compare output values only for valid tokens
+        c_valid = c[:num_valid_permuted_tokens].view(torch.uint8)[valid_token_mask]
+        c_ref_valid = c_ref[:num_valid_permuted_tokens][valid_token_mask]
+        check_accuracy(c_valid, c_ref_valid, atol=1e-4, rtol=1e-4, percent=0.95)
+
+    c_sf_unswizzled = unswizzle_sf(c_sf, max_num_permuted_tokens, interm_size, sf_vec_size)
+    c_sf_ref_unswizzled = unswizzle_sf(
+        c_sf_ref, max_num_permuted_tokens, interm_size, sf_vec_size
+    )
+
+    # Compare scale factors only for valid tokens
+    c_sf_valid = []
+    c_sf_ref_valid = []
+    for i in range(num_valid_permuted_tokens):
+        if permuted_idx_to_expanded_idx[i].item() != helper.pad_val:
+            c_sf_valid.append(c_sf_unswizzled[i])
+            c_sf_ref_valid.append(c_sf_ref_unswizzled[i])
+
+    c_sf_valid = torch.cat(c_sf_valid)
+    c_sf_ref_valid = torch.cat(c_sf_ref_valid)
+    check_accuracy(c_sf_valid, c_sf_ref_valid, atol=1e-4, rtol=1e-4, percent=0.95)
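As a closing design note, the per-token Python loops used for masking in the test are easy to audit but run on the host; the same masks can be built in one shot on the device. A sketch under the same pad_val convention (toy sizes, top_k = 2; semantics assumed to match the loop version):

import torch

pad_val = int(-2e9)
top_k = 2
perm = torch.tensor([3, 7, pad_val, 1])  # stand-in for permuted_idx_to_expanded_idx
valid_token_mask = perm != pad_val       # one vectorized comparison
token_ids = torch.where(valid_token_mask, perm // top_k, torch.zeros_like(perm))
print(valid_token_mask.tolist())  # [True, True, False, True]
print(token_ids.tolist())         # [1, 3, 0, 0]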