diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 6752881b3350..d74e0a224949 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -226,11 +226,13 @@ static void createTMAAsyncCopy( } // If all the transitive uses of the given value have are used by a convert to -// the same dot operand encoding, return true and get the shared encoding that -// needs to be used to be compatible with users' layouts. +// the same dot operand encoding, return the shared encoding that needs to be +// used to be compatible with users' layouts. If there are incompatible shared +// encodings set `incompatible` to true. static std::optional -getSharedEncIfAllUsersAreDotEnc(Value val) { +getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { ttg::SharedEncodingAttr attr; + incompatible = false; for (Operation *user : val.getUsers()) { ttg::SharedEncodingAttr tempAttr; if (user->getNumResults() != 1) @@ -240,7 +242,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { // First time we find a shared encoding in the chain, save it and try to // use it if it is compatible with the other users. tempAttr = cast(memDesc.getEncoding()); - if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value()) + if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible) + .has_value()) return std::nullopt; } else { if (!isa(user)) @@ -260,8 +263,10 @@ getSharedEncIfAllUsersAreDotEnc(Value val) { srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false); } // Check that the shared encodings needed by the users are compatible. 
- if (!tempAttr || (attr != nullptr && attr != tempAttr)) + if (attr != nullptr && attr != tempAttr) { + incompatible = true; return std::nullopt; + } attr = tempAttr; } return attr; @@ -451,9 +456,13 @@ assignMemoryLayouts(llvm::SmallVector> loadInfo.sharedEncoding = getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr); } else if (auto dot = dyn_cast(use)) { + bool incompatible = false; loadInfo.sharedEncoding = - getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); - + getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible) + .value_or(nullptr); + // If we can't agree on a shared encoding skip pipelining the load. + if (incompatible) + continue; // HACK: Triton LLVM codegen has a bug where local_loads from #shared to // #mma layout can lead to invalid code if the loaded shape is smaller // than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 1ba1a3f6e60d..a5594b304309 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -671,7 +671,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %14 = tt.broadcast %11 : tensor<1x16x!tt.ptr, #blocked> -> tensor<64x16x!tt.ptr, #blocked> %15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked> %16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr, #blocked>, tensor<64x16xi32, #blocked> - // CHECK-NOT: triton_gpu.insert_slice_async + // check that the load didn't get pipelined. 
+ // CHECK-NOT: alloc // CHECK: scf.for %17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 { %18 = tt.load %16 : tensor<64x16x!tt.ptr, #blocked> @@ -1146,14 +1147,12 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 : %51 = tt.addptr %50, %47 : tensor<64x256x!tt.ptr, #blocked>, tensor<64x256xi32, #blocked> // Check that both loads in the loop are pipelined. - // TODO(jlebar): https://github.com/triton-lang/triton/pull/3472 disables the - // relevant optimization. Once we've reenabled it, we can uncomment this test. // CHECK: scf.for - // COM: CHECK-NOT: tt.load + // CHECK-NOT: tt.load // CHECK: triton_gpu.async_copy_global_to_local - // COM: CHECK-NOT: tt.load - // COM: CHECK: triton_gpu.async_copy_global_to_local - // COM: CHECK-NOT: tt.load + // CHECK-NOT: tt.load + // CHECK: triton_gpu.async_copy_global_to_local + // CHECK-NOT: tt.load // CHECK: scf.yield %54:3 = scf.for %arg9 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %41, %arg12 = %51) -> (tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr, #blocked1>, tensor<64x256x!tt.ptr, #blocked>) : i32 { %78 = tt.load %arg11 : tensor<16x128x!tt.ptr, #blocked1>