diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py index 0e574326ba91..59305d517b9b 100644 --- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py +++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py @@ -180,6 +180,10 @@ def compute_num_stages( # that is not fully captured by the simple stage_size model above. if is_persistent and (lhs_dtype == FP32 or rhs_dtype == FP32): smem_capacity -= 32 * 1024 + if is_persistent and not has_native_mxfp and epilogue_reduction_n > 1: + # Hopper fused reductions materialize an additional reduced-N output + # tile in smem. + smem_capacity -= int(block_m * acc_block_n * out_itemsize) smem_capacity = max(smem_capacity, 0) max_stages = 5 if rhs_dtype == FP4 else 4 # maybe 5 everywhere; just haven't tested num_stages = min(smem_capacity // int(stage_size), max_stages) diff --git a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py index 0bdff7b3ed8e..ba68c071e28e 100644 --- a/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py +++ b/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_value_shuffled.py @@ -104,13 +104,11 @@ def swizzle_data(self, data: torch.Tensor) -> torch.Tensor: Target layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed] This matches the baseline TMA block shape [block_n, packed_block_k] after swapping. """ - if data.ndim == 2: - data = data.unsqueeze(0) - if data.ndim != 3: - raise ValueError(f"Expected 2D or 3D canonical data, got {data.ndim}D") - data = self._canonical_to_physical(data) - E, K_packed, N = data.shape + leading_shape = data.shape[:-2] + E = math.prod(leading_shape) + K_packed, N = data.shape[-2:] + data = data.reshape(E, K_packed, N) tile_k_packed, tile_n, padded_K_packed, padded_N, num_tiles_k, num_tiles_n = \ self._compute_params(E, K_packed, N) @@ -139,6 +137,7 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor: Input layout: [E, num_tiles_k, num_tiles_n, tile_n, tile_k_packed] """ E = data.shape[0] + leading_shape = self.shape[:-2] # Recover original shape from self.shape (the logical shape passed to convert_layout) orig_K_packed = self.shape[-2] // 2 if self.is_fp4 else self.shape[-2] orig_N = self.shape[-1] @@ -159,4 +158,6 @@ def unswizzle_data(self, data: torch.Tensor) -> torch.Tensor: # Trim padding back to original shape data = data[:, :orig_K_packed, :orig_N].contiguous() data = self._physical_to_canonical(data) - return data if len(self.shape) == 3 else data.squeeze(0) + if not leading_shape: + return data.squeeze(0) + return data.reshape(*leading_shape, data.shape[-2], data.shape[-1])