diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 764d9df1115f..9205e5bdcd90 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1682,6 +1682,15 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) { auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames()); StringAttr dimM = mfmaOutDims[0]; StringAttr dimN = mfmaOutDims[1]; + unsigned destIdxInBases = isMfma32 ? 3 : 4; + // The column swap below exchanges N-dim basis bit 2 with bit + // `destIdxInBases`. The target bit only exists when the N dimension has at + // least `1 << (destIdxInBases + 1)` columns: 16 for mfma32x32 and 32 for + // mfma16x16. Smaller N dimensions produce fewer basis vectors, so the swap + // would otherwise index past the end of `dimNBases`. + if (mfmaLL.getOutDimSizeLog2(dimN) <= destIdxInBases) + return {}; + auto swapLL = LinearLayout::empty(); // The rows are kept as is with an identity linear layout. swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM); @@ -1777,7 +1786,6 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) { the original mfma16 LL. clang-format on */ - auto destIdxInBases = isMfma32 ? 3 : 4; std::vector> dimNBases(mfmaLL.getOutDimSizeLog2(dimN)); std::generate(dimNBases.begin(), dimNBases.end(), [i = 0]() mutable { return std::vector{1 << i++}; }); diff --git a/test/TritonGPU/amd/amd-optimize-epilogue.mlir b/test/TritonGPU/amd/amd-optimize-epilogue.mlir index bc82a30c876e..9fb06a1f2176 100644 --- a/test/TritonGPU/amd/amd-optimize-epilogue.mlir +++ b/test/TritonGPU/amd/amd-optimize-epilogue.mlir @@ -21,6 +21,24 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} // ----- +// CHECK-LABEL: store_dword_mfma32_small_n +// CHECK-NOT: tensor<32x8x!tt.ptr, #linear> +// CHECK: tt.store %{{.*}}, %{{.*}} : tensor<32x8x!tt.ptr, #mma> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}> +#mma = #ttg.amd_mfma<{version = 4, warpsPerCTA = [4, 1], instrShape = [32, 32, 16], isTransposed = true}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @store_dword_mfma32_small_n(%arg0: !tt.ptr) { + %cst = arith.constant dense<0.000000e+00> : tensor<32x8xf32, #mma> + %0 = ttg.convert_layout %cst : tensor<32x8xf32, #mma> -> tensor<32x8xf32, #blocked> + %1 = arith.truncf %0 : tensor<32x8xf32, #blocked> to tensor<32x8xf16, #blocked> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<32x8x!tt.ptr, #blocked> + tt.store %2, %1 : tensor<32x8x!tt.ptr, #blocked> + tt.return + } +} + +// ----- + // CHECK-LABEL: two_ops_in_chain // CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked> // CHECK: tt.store %{{.*}}, %{{.*}} : tensor<32x32x!tt.ptr, #mma>