From 9db0355498efa07de5da93d2974d1efd6e3b27b6 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Fri, 10 Apr 2026 15:57:22 -0700 Subject: [PATCH 1/2] [KERNELS] fix hopper smem heuristic --- .../matmul_details/opt_flags_details/opt_flags_nvidia.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py index 0e574326ba91..59305d517b9b 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py @@ -180,6 +180,10 @@ def compute_num_stages( # that is not fully captured by the simple stage_size model above. if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32): smem_capacity -= 32 * 1024 + if is_persistent and not has_native_mxfp and epilogue_reduction_n > 1: + # Hopper fused reductions materialize an additional reduced-N output + # tile in smem. + smem_capacity -= int(block_m * acc_block_n * out_itemsize) smem_capacity = max(smem_capacity, 0) max_stages = 5 if rhs_dtype == FP4 else 4 # maybe 5 everywhere; just haven't tested num_stages = min(smem_capacity // int(stage_size), max_stages) From 9667825e3bdec4ba059b4634dd7e45066a7fc547 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Fri, 10 Apr 2026 15:59:17 -0700 Subject: [PATCH 2/2] fix leading shape --- .../layout_details/blackwell_value_shuffled.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py index 0bdff7b3ed8e..ba68c071e28e 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py @@ -104,13 +104,11 @@ def swizzle_data(self, data: torch.Tensor) -> torch.Tensor: Target layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed] This matches the baseline TMA block shape [block_n, packed_block_k] after swapping. """ - if data.ndim == 2: - data = data.unsqueeze(0) - if data.ndim != 3: - raise ValueError(f"Expected 2D or 3D canonical data, got {data.ndim}D") - data = self._canonical_to_physical(data) - E, K_packed, N = data.shape + leading_shape = data.shape[:-2] + E = math.prod(leading_shape) + K_packed, N = data.shape[-2:] + data = data.reshape(E, K_packed, N) tile_k_packed, tile_n, padded_K_packed, padded_N, num_tiles_k, num_tiles_n = \ self._compute_params(E, K_packed, N) @@ -139,6 +137,7 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor: Input layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed] """ E = data.shape[0] + leading_shape = self.shape[:-2] # Recover original shape from self.shape (the logical shape passed to convert_layout) orig_K_packed = self.shape[-2] // 2 if self.is_fp4 else self.shape[-2] orig_N = self.shape[-1] @@ -159,4 +158,6 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor: # Trim padding back to original shape data = data[:, :orig_K_packed, :orig_N].contiguous() data = self._physical_to_canonical(data) - return data if len(self.shape) == 3 else data.squeeze(0) + if not leading_shape: + return data.squeeze(0) + return data.reshape(*leading_shape, data.shape[-2], data.shape[-1])