From a13b3249e951a965fafcf2e3dc14ab52a14a8009 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Sun, 23 Jun 2024 22:47:03 -0700 Subject: [PATCH 1/2] [BACKEND] Fix regression in pipeliner pre-checks. During some previous refactoring we changed the logic and started pipelining cases that had incompatible shared encoding. This was missed because one of the lit tests had not been updated :( --- .../Pipeliner/MatmulLoopPipeline.cpp | 18 +++++++++++++----- test/TritonGPU/loop-pipeline.mlir | 13 ++++++------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 6752881b3350..f45144ebb664 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -229,8 +229,9 @@ static void createTMAAsyncCopy( // the same dot operand encoding, return true and get the shared encoding that // needs to be used to be compatible with users' layouts. static std::optional -getSharedEncIfAllUsersAreDotEnc(Value val) { +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { ttg::SharedEncodingAttr attr; + incompatible = false; for (Operation *user : val.getUsers()) { ttg::SharedEncodingAttr tempAttr; if (user->getNumResults() != 1) return std::nullopt; if (auto memDesc = dyn_cast(user->getResult(0).getType())) { // First time we find a shared encoding in the chain, save it and try to // use it if it is compatible with the other users. 
tempAttr = cast(memDesc.getEncoding()); - if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) + .has_value()) return std::nullopt; } else { if (!isa(user)) @@ -260,8 +262,10 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); } // Check that the shared encodings needed by the users are compatible. - if (!tempAttr || (attr != nullptr && attr != tempAttr)) + if (attr != nullptr && attr != tempAttr) { + incompatible = true; return std::nullopt; + } attr = tempAttr; } return attr; @@ -451,9 +455,13 @@ assignMemoryLayouts(llvm::SmallVector> loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); } else if (auto dot = dyn_cast(use)) { + bool imcompatible = false; loadInfo.sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); - + getSharedEncIfAllUsersAreDotEnc(op->getResult(0), imcompatible) + .value_or(nullptr); + // If we can't agree on a shared encoding skip pipelinig the load. + if (imcompatible) + continue; // HACK: Triton LLVM codegen has a bug where local_loads from #shared to // #mma layout can lead to invalid code if the loaded shape is smaller // than the mma tile (e.g. 
loading a 128x1 tensor for an MMAv2 dot with diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 1ba1a3f6e60d..a5594b304309 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -671,7 +671,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK-NOT: triton_gpu.insert_slice_async + // check that the load didn't get pipelined. + // CHECK-NOT: alloc // CHECK: scf.for %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> @@ -1146,14 +1147,12 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %51 = tt.addptr %50, %47 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> // Check that both loads in the loop are pipelined. - // TODO(jlebar): https://github.com/triton-lang/triton/pull/3472 disables the - // relevant optimization. Once we've reenabled it, we can uncomment this test. 
// CHECK: scf.for - // COM: CHECK-NOT: tt.load + // CHECK-NOT: tt.load // CHECK: triton_gpu.async_copy_global_to_local - // COM: CHECK-NOT: tt.load - // COM: CHECK: triton_gpu.async_copy_global_to_local - // COM: CHECK-NOT: tt.load + // CHECK-NOT: tt.load + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK-NOT: tt.load // CHECK: scf.yield %54:3 = scf.for %arg9 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %41, %arg12 = %51) -> (tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>) : i32 { %78 = tt.load %arg11 : tensor<16x128x!tt.ptr, #blocked1> From 4333690facc1b554c853297a1470642eaa83c9b9 Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Mon, 24 Jun 2024 08:19:57 -0700 Subject: [PATCH 2/2] Fix typo --- .../Transforms/Pipeliner/MatmulLoopPipeline.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index f45144ebb664..d74e0a224949 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -226,8 +226,9 @@ static void createTMAAsyncCopy( } // If all the transitive uses of the given value have are used by a convert to -// the same dot operand encoding, return true and get the shared encoding that -// needs to be used to be compatible with users' layouts. +// the same dot operand encoding, return the shared encoding that needs to be +// used to be compatible with users' layouts. If there are incompatible shared +// encodings set `incompatible` to true. 
static std::optional getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { ttg::SharedEncodingAttr attr; @@ -455,12 +456,12 @@ assignMemoryLayouts(llvm::SmallVector> loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); } else if (auto dot = dyn_cast(use)) { - bool imcompatible = false; + bool incompatible = false; loadInfo.sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(op->getResult(0), imcompatible) + getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) .value_or(nullptr); // If we can't agree on a shared encoding skip pipelinig the load. - if (imcompatible) + if (incompatible) continue; // HACK: Triton LLVM codegen has a bug where local_loads from #shared to // #mma layout can lead to invalid code if the loaded shape is smaller