Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,13 +399,13 @@ struct AtomicCASOpConversion
auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, op.getVal().getType());

auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
auto valueTy = op.getResult().getType();
auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
Type valueElemTy =
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
: op.getResult().getType();
: valueTy;
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto tid = tid_val();
Value pred = icmp_eq(tid, i32_val(0));
Value mask = getMask(valueTy, rewriter, loc);
PTXBuilder ptxBuilderMemfence;
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfence();
Expand All @@ -425,7 +425,7 @@ struct AtomicCASOpConversion
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
atom.global().o("cas").o("b32");
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
barrier();

Expand All @@ -434,7 +434,7 @@ struct AtomicCASOpConversion
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
st.shared().o("b32");
st(dstOprStore, valOprStore).predicate(pred);
st(dstOprStore, valOprStore).predicate(mask);
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
barrier();
Expand Down Expand Up @@ -483,10 +483,11 @@ struct AtomicRMWOpConversion
maskElements = getTypeConverter()->unpackLLElements(
loc, llMask, rewriter, op.getMask().getType());

auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
auto valueTy = op.getResult().getType();
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
Type valueElemTy =
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
: op.getResult().getType();
: valueTy;
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getTotalElemsPerThread(val.getType());
// vec = 1, numElements = 1 for scalar
Expand All @@ -499,10 +500,7 @@ struct AtomicRMWOpConversion
// mask
numElems = tensorTy.getNumElements();
}
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
Value mask = getMask(valueTy, rewriter, loc);
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one adds a cmp eq, so it changes the order of operations, and thus the lit test order changes


auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);
Expand Down Expand Up @@ -582,7 +580,6 @@ struct AtomicRMWOpConversion
memfenc();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one seems redundant after the one I added in getMask

atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ class ConvertTritonGPUOpToLLVMPatternBase {
} else {
// If the tensor is not ranked, then it is a scalar and only thread 0 can
// write
mask = and_(mask, icmp_slt(tid, i32_val(1)));
mask = and_(mask, icmp_eq(tid, i32_val(0)));
}
return mask;
}
Expand Down
4 changes: 2 additions & 2 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.icmp "slt"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.add.f32
// CHECK: llvm.inline_asm
Expand All @@ -1026,6 +1025,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
// CHECK: llvm.icmp "eq"
// CHECK: llvm.inline_asm
// CHECK: llvm.inline_asm
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why I had to change this, will take a look

// CHECK-SAME: @$3 atom.global.gpu.add.f32
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
tt.return
Expand All @@ -1052,7 +1052,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32_scalar
tt.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
// CHECK: llvm.icmp "slt"
// CHECK: llvm.icmp "eq"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$2 st.global.b32
tt.store %arg0, %arg1 : f32
Expand Down