Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f16
tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
// CHECK-LABEL: atomic_add_f16x2
tt.func @atomic_add_f16x2(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
%range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
%base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
%ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
// CHECK: llvm.cond_br
// CHECK-NOT: rocdl.update.dpp
// CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
// CHECK-NOT: rocdl.update.dpp
%0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
tt.return
}
Expand All @@ -83,13 +85,51 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :

#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_bf16
tt.func @atomic_add_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
// CHECK-LABEL: atomic_add_bf16x2
tt.func @atomic_add_bf16x2(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
%range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
%base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
%ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
// CHECK: llvm.cond_br
// CHECK-NOT: rocdl.update.dpp
// CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
// CHECK-NOT: rocdl.update.dpp
%0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
tt.return
}
}

// -----

#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f16_dpp
tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
%range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
%base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
%ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
// CHECK: llvm.cond_br
// CHECK: rocdl.update.dpp
// CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
// CHECK: rocdl.update.dpp
%0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
tt.return
}
}

// -----

#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_bf16_dpp
tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
%range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
%base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
%ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
// CHECK: llvm.cond_br
// CHECK: rocdl.update.dpp
// CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
// CHECK: rocdl.update.dpp
%0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
tt.return
}
Expand Down
103 changes: 93 additions & 10 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,32 @@ struct AtomicCASOpConversion
}
};

bool supportsGlobalAtomicF16PackedAndDpp(triton::AMD::ISAFamily isaFamily) {
return isaFamily == triton::AMD::ISAFamily::CDNA1 ||
isaFamily == triton::AMD::ISAFamily::CDNA2 ||
isaFamily == triton::AMD::ISAFamily::CDNA3;
}

Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) {
assert(val.getType().isInteger(32));
auto loc = val.getLoc();
Value old = i32_val(0);
int rowMask = 0b1111; // enable all rows
int bankMask = 0b1111; // enable all banks
bool boundCtrl = false;
auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl);
return dppMovOp.getResult();
}

Value shiftLeftI32ByDpp(PatternRewriter &rewriter, Value val) {
return generateI32DppMove(rewriter, val, 0x101); // shift left 1 lane
}

Value shiftRightI32ByDpp(PatternRewriter &rewriter, Value val) {
return generateI32DppMove(rewriter, val, 0x111); // shift right 1 lane
}

struct AtomicRMWOpConversion
: public ConvertOpToLLVMPattern<triton::AtomicRMWOp>,
public LoadStoreConversionBase {
Expand Down Expand Up @@ -785,27 +811,52 @@ struct AtomicRMWOpConversion
// vec = 1, numElements = 1 for scalar
auto vec = getVectorSize(ptr);
int numElems = 1;
Type packF16Ty = vec_ty(valueElemTy, 2);

// In the case of unpaired f16 elements utilize dpp instructions to
// accelerate atomics. Here is an algorithm of lowering
// tt::atomicRmwOp(%ptr, %val, %mask):
// 0. Group thread by pairs. Master thread is (tid % 2 == 0);
// 1. All the threads send %val to (tid - 1) thread via dppUpdateOp shl, so
// all the masters recieve value from secondary threads;
// 2. Take into account parity in the %mask value, build control flow
// structures according to it;
// 3. Generate llvm::atomicRmwOp in the threads enabled by %mask value;
// 4. All the threads send result of generated operation to (tid + 1) thread
// via dppUpdateOp shl, so all secondary thread also recieve their
// result.
//
// This approach enables us to use half the active threads committing atomic
// requests to avoid generating of code providing unified access to f16
// element and reduce contantion.
bool useDppForPackedF16 = false;
// tensor
if (tensorTy) {
auto valTy = cast<RankedTensorType>(val.getType());
Type elTy = valTy.getElementType();
vec = std::min<unsigned>(vec, llvm::isa<FloatType>(elTy) &&
elTy.getIntOrFloatBitWidth() == 16
? 2
: 1);
bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16();
unsigned availableVecSize = isF16Ty ? 2 : 1;
vec = std::min<unsigned>(vec, availableVecSize);
// Force F16 packing in the case it's not comming in as packed, but the
// ISA can support packed atomic instructions.
useDppForPackedF16 =
supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) &&
vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD;
// mask
numElems = tensorTy.getNumElements();
}
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
if (useDppForPackedF16)
mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0)));

auto memOrdering = op.getSem();
auto atomicMemOrdering = getMemoryOrdering(memOrdering);

auto vecTy = vec_ty(valueElemTy, vec);
auto retType = vec == 1 ? valueElemTy : vecTy;
retType = useDppForPackedF16 ? packF16Ty : retType;
SmallVector<Value> resultVals(elemsPerThread);
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmwPtr = ptrElements[i];
Expand All @@ -814,7 +865,24 @@ struct AtomicRMWOpConversion
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;

Value operand;
if (vec == 1) {
if (useDppForPackedF16) {
// Move %val to left neighbour to proceed packed atomic further.
Value packedVal = null(packF16Ty);
packedVal =
insert_element(packF16Ty, packedVal, valElements[i], i32_val(0));
// Pack to i32 type to simplify transaction
packedVal = bitcast(packedVal, i32_ty);
Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal);
// Unpack results back
Value unpackedDppRes = bitcast(dppMoveRes, packF16Ty);
operand = undef(packF16Ty);
operand =
insert_element(packF16Ty, operand, valElements[i], i32_val(0));
operand = insert_element(
packF16Ty, operand,
extract_element(valueElemTy, unpackedDppRes, i32_val(0)),
i32_val(1));
} else if (vec == 1) {
operand = valElements[i];
} else {
operand = undef(vecTy);
Expand Down Expand Up @@ -856,10 +924,25 @@ struct AtomicRMWOpConversion
rewriter.setInsertionPointToStart(endBlock);
Value retVal = endBlock->getArgument(0);
if (tensorTy) {
for (int ii = 0; ii < vec; ++ii) {
resultVals[i + ii] =
vec == 1 ? retVal
: extract_element(valueElemTy, retVal, i32_val(ii));
if (useDppForPackedF16) {
// Return packed to i32 result after atomic operation back from master
// lane.
auto packedRet = bitcast(retVal, i32_ty);
Value dppMovRes = shiftRightI32ByDpp(rewriter, packedRet);
// Unpack results back
Value unpackedDppRes = bitcast(dppMovRes, packF16Ty);
retVal = insert_element(
packF16Ty, retVal,
extract_element(valueElemTy, unpackedDppRes, i32_val(1)),
i32_val(1));
resultVals[i] =
extract_element(valueElemTy, retVal, urem(tid, i32_val(2)));
} else {
for (int ii = 0; ii < vec; ++ii) {
resultVals[i + ii] =
vec == 1 ? retVal
: extract_element(valueElemTy, retVal, i32_val(ii));
}
}
} else {
if (!atomicNeedsSharedMemory(op.getResult())) {
Expand Down