Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "triton/Analysis/Utility.h"

#include <fstream>
#include <optional>

#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
#include "mlir/Analysis/SliceAnalysis.h"
Expand Down Expand Up @@ -119,6 +120,55 @@ unsigned getElementBitWidth(RankedTensorType type) {
return typeForMem.getIntOrFloatBitWidth();
}

static std::optional<unsigned>
getAtomicWriteElementsPerThreadCap(Operation *op) {
if (isa<triton::AtomicCASOp>(op))
return 1;

auto atomicRmw = dyn_cast<triton::AtomicRMWOp>(op);
if (!atomicRmw)
return std::nullopt;

Type elemTy = getElementTypeOrSelf(atomicRmw.getVal().getType());
if (elemTy.isInteger() || elemTy.isF64())
return 1;

if (atomicRmw.getAtomicRmwOp() != RMWOp::FADD)
return std::nullopt;

auto moduleOp = op->getParentOfType<ModuleOp>();
auto targetAttr =
moduleOp ? moduleOp->getAttrOfType<StringAttr>(ttg::AttrTargetName)
: nullptr;
if (!targetAttr || !targetAttr.getValue().starts_with("cuda:"))
return std::nullopt;

int computeCapability = getNVIDIAComputeCapability(moduleOp);
if (computeCapability >= 90)
return std::nullopt;
Comment on lines +146 to +148
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of curiosity, what changed in Hopper?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Vectorised atomics were added with Hopper, its buried in https://docs.nvidia.com/cuda/parallel-thread-execution/

"Support for vector types requires sm_90 or higher."

See also a similar-ish discussion: llvm/llvm-project#122760


if (elemTy.isF32() || elemTy.isBF16())
return 1;
if (elemTy.isF16())
return 2;
return std::nullopt;
}

static unsigned getMaxElementsPerThread(Operation *op) {
Value val = getMemAccessPtr(op);
auto ty = cast<RankedTensorType>(val.getType());
unsigned elemNumBits = getElementBitWidth(ty);
unsigned maxElementsPerThread = 128 / elemNumBits;
// Some atomic lowerings are narrower than a plain store. TTGIR currently
// exposes the target architecture but not the PTX version, so we only cap
// cases that are unambiguous from the available target metadata and the
// current backend lowering.
if (auto atomicCap = getAtomicWriteElementsPerThreadCap(op)) {
maxElementsPerThread = std::min(maxElementsPerThread, *atomicCap);
}
return maxElementsPerThread;
}

unsigned getNumElementsPerThread(Operation *op, SmallVector<unsigned> order,
ModuleAxisInfoAnalysis &axisInfoAnalysis,
ArrayRef<int64_t> shapePerCTA) {
Expand All @@ -132,10 +182,12 @@ unsigned getNumElementsPerThread(Operation *op, SmallVector<unsigned> order,
unsigned maxContig =
std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]);
unsigned alignment = std::min(maxMultiple, maxContig);
unsigned currPerThread = std::min(alignment, 128 / elemNumBits);
unsigned maxElementsPerThread = getMaxElementsPerThread(op);
unsigned currPerThread = std::min(alignment, maxElementsPerThread);
LDBG("elemNumBytes: " << elemNumBytes
<< ", divisibility: " << maxMultipleBytes
<< ", contig: " << valInfo.getContiguity(order[0])
<< ", maximum: " << maxElementsPerThread
<< ", alignment: " << alignment);
return currPerThread;
}
Expand Down
71 changes: 71 additions & 0 deletions test/TritonGPU/coalesce.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,77 @@ tt.func public @load_tensors_two_types(%arg0: !tt.ptr<f32> {tt.divisibility = 16

// -----

#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @atomic_add_i32
// CHECK-NOT: sizePerThread = [4]
// CHECK: tt.atomic_rmw add, relaxed, gpu, %{{.*}}, %{{.*}}, %{{.*}} : (tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>, tensor<1024xi1, #blocked>) -> tensor<1024xi32, #blocked>
// CHECK-NOT: sizePerThread = [4]
tt.func public @atomic_add_i32(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}, %arg1: i32) {
%c1024_i32 = arith.constant 1024 : i32
%c1_i32 = arith.constant dense<1> : tensor<1024xi32, #blocked>
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg1 : i32 -> tensor<1024xi32, #blocked>
%6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked>
%7 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<1024x!tt.ptr<i32>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.atomic_rmw add, relaxed, gpu, %8, %c1_i32, %6 : (tensor<1024x!tt.ptr<i32>, #blocked>, tensor<1024xi32, #blocked>, tensor<1024xi1, #blocked>) -> tensor<1024xi32, #blocked>
tt.return
}
}
// -----

#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @atomic_add_f32_cuda80
// CHECK-NOT: sizePerThread = [4]
// CHECK: tt.atomic_rmw fadd, relaxed, gpu, %{{.*}}, %{{.*}}, %{{.*}} : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xi1, #blocked>) -> tensor<1024xf32, #blocked>
// CHECK-NOT: sizePerThread = [4]
tt.func public @atomic_add_f32_cuda80(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32) {
%c1024_i32 = arith.constant 1024 : i32
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked>
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg1 : i32 -> tensor<1024xi32, #blocked>
%6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked>
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.atomic_rmw fadd, relaxed, gpu, %8, %cst, %6 : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>, tensor<1024xi1, #blocked>) -> tensor<1024xf32, #blocked>
tt.return
}
}
// -----

#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:80", "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @atomic_add_f16_cuda80
// CHECK: ttg.convert_layout %{{.*}} : tensor<1024x!tt.ptr<f16>, #blocked> -> tensor<1024x!tt.ptr<f16>, #[[ATOMIC_F16_LAYOUT:.*]]>
// CHECK: tt.atomic_rmw fadd, relaxed, gpu, %{{.*}}, %{{.*}}, %{{.*}} : (tensor<1024x!tt.ptr<f16>, #[[ATOMIC_F16_LAYOUT]]>, tensor<1024xf16, #[[ATOMIC_F16_LAYOUT]]>, tensor<1024xi1, #[[ATOMIC_F16_LAYOUT]]>) -> tensor<1024xf16, #[[ATOMIC_F16_LAYOUT]]>
tt.func public @atomic_add_f16_cuda80(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32) {
%c1024_i32 = arith.constant 1024 : i32
%cst = arith.constant dense<1.000000e+00> : tensor<1024xf16, #blocked>
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
%5 = tt.splat %arg1 : i32 -> tensor<1024xi32, #blocked>
%6 = arith.cmpi "slt", %4, %5 : tensor<1024xi32, #blocked>
%7 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<1024x!tt.ptr<f16>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.atomic_rmw fadd, relaxed, gpu, %8, %cst, %6 : (tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xf16, #blocked>, tensor<1024xi1, #blocked>) -> tensor<1024xf16, #blocked>
tt.return
}
}
// -----

// COM: Reproducer for issue #5122
// CHECK-LABEL: @test_5122
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
Expand Down
Loading