Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,13 @@ static void createTMAAsyncCopy(
}

// If all the transitive uses of the given value are used by a convert to
// the same dot operand encoding, return true and get the shared encoding that
// needs to be used to be compatible with users' layouts.
// the same dot operand encoding, return the shared encoding that needs to be
used to be compatible with users' layouts. If there are incompatible shared
// encodings set `incompatible` to true.
static std::optional<ttg::SharedEncodingAttr>
getSharedEncIfAllUsersAreDotEnc(Value val) {
getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
ttg::SharedEncodingAttr attr;
incompatible = false;
for (Operation *user : val.getUsers()) {
ttg::SharedEncodingAttr tempAttr;
if (user->getNumResults() != 1)
Expand All @@ -240,7 +242,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
// First time we find a shared encoding in the chain, save it and try to
// use it if it is compatible with the other users.
tempAttr = cast<ttg::SharedEncodingAttr>(memDesc.getEncoding());
if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0)).has_value())
if (!getSharedEncIfAllUsersAreDotEnc(user->getResult(0), incompatible)
.has_value())
return std::nullopt;
} else {
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
Expand All @@ -260,8 +263,10 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false);
}
// Check that the shared encodings needed by the users are compatible.
if (!tempAttr || (attr != nullptr && attr != tempAttr))
if (attr != nullptr && attr != tempAttr) {
incompatible = true;
return std::nullopt;
}
attr = tempAttr;
}
return attr;
Expand Down Expand Up @@ -451,9 +456,13 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
loadInfo.sharedEncoding =
getSharedEncoding(op, /*loadIsMMAv3=*/true).value_or(nullptr);
} else if (auto dot = dyn_cast<tt::DotOp>(use)) {
bool incompatible = false;
loadInfo.sharedEncoding =
getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr);

getSharedEncIfAllUsersAreDotEnc(op->getResult(0), incompatible)
.value_or(nullptr);
// If we can't agree on a shared encoding, skip pipelining the load.
if (incompatible)
continue;
// HACK: Triton LLVM codegen has a bug where local_loads from #shared to
// #mma layout can lead to invalid code if the loaded shape is smaller
// than the mma tile (e.g. loading a 128x1 tensor for an MMAv2 dot with
Expand Down
13 changes: 6 additions & 7 deletions test/TritonGPU/loop-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,8 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
%14 = tt.broadcast %11 : tensor<1x16x!tt.ptr<f16>, #blocked> -> tensor<64x16x!tt.ptr<f16>, #blocked>
%15 = tt.broadcast %13 : tensor<64x1xi32, #blocked> -> tensor<64x16xi32, #blocked>
%16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16>, #blocked>, tensor<64x16xi32, #blocked>
// CHECK-NOT: triton_gpu.insert_slice_async
// Check that the load didn't get pipelined.
// CHECK-NOT: alloc
// CHECK: scf.for
%17:2 = scf.for %arg2 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg3 = %cst_1, %arg4 = %cst_2) -> (tensor<128x16xf32, #mma>, tensor<128x64xf32, #mma>) : i32 {
%18 = tt.load %16 : tensor<64x16x!tt.ptr<f16>, #blocked>
Expand Down Expand Up @@ -1146,14 +1147,12 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
%51 = tt.addptr %50, %47 : tensor<64x256x!tt.ptr<i8>, #blocked>, tensor<64x256xi32, #blocked>

// Check that both loads in the loop are pipelined.
// TODO(jlebar): https://github.com/triton-lang/triton/pull/3472 disables the
// relevant optimization. Once we've reenabled it, we can uncomment this test.
// CHECK: scf.for
// COM: CHECK-NOT: tt.load
// CHECK-NOT: tt.load
// CHECK: triton_gpu.async_copy_global_to_local
// COM: CHECK-NOT: tt.load
// COM: CHECK: triton_gpu.async_copy_global_to_local
// COM: CHECK-NOT: tt.load
// CHECK-NOT: tt.load
// CHECK: triton_gpu.async_copy_global_to_local
// CHECK-NOT: tt.load
// CHECK: scf.yield
%54:3 = scf.for %arg9 = %c0_i32 to %c16_i32 step %c1_i32 iter_args(%arg10 = %cst_3, %arg11 = %41, %arg12 = %51) -> (tensor<16x256xf32, #mma>, tensor<16x128x!tt.ptr<f16>, #blocked1>, tensor<64x256x!tt.ptr<i8>, #blocked>) : i32 {
%78 = tt.load %arg11 : tensor<16x128x!tt.ptr<f16>, #blocked1>
Expand Down