From 97ca961aca8c612372347f0c63483e95017626ff Mon Sep 17 00:00:00 2001 From: Zahi Moudallal Date: Thu, 4 May 2023 14:18:11 -0700 Subject: [PATCH] [BACKEND] Update predicate for atomic ops --- .../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp | 23 ++++++++----------- .../TritonGPUToLLVM/TritonGPUToLLVMBase.h | 2 +- test/Conversion/tritongpu_to_llvm.mlir | 4 ++-- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index 7973f63a0115..2846f137b2b3 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -399,13 +399,13 @@ struct AtomicCASOpConversion auto valElements = getTypeConverter()->unpackLLElements( loc, llVal, rewriter, op.getVal().getType()); - auto TensorTy = op.getResult().getType().dyn_cast(); + auto valueTy = op.getResult().getType(); + auto TensorTy = valueTy.dyn_cast(); Type valueElemTy = TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType()) - : op.getResult().getType(); + : valueTy; auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); - auto tid = tid_val(); - Value pred = icmp_eq(tid, i32_val(0)); + Value mask = getMask(valueTy, rewriter, loc); PTXBuilder ptxBuilderMemfence; auto memfence = ptxBuilderMemfence.create("membar")->o("gl"); memfence(); @@ -425,7 +425,7 @@ struct AtomicCASOpConversion auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r"); auto &atom = *ptxBuilderAtomicCAS.create("atom"); atom.global().o("cas").o("b32"); - atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred); + atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask); auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); barrier(); @@ -434,7 +434,7 @@ struct AtomicCASOpConversion auto *valOprStore = ptxBuilderStore.newOperand(old, "r"); auto &st = *ptxBuilderStore.create("st"); st.shared().o("b32"); - st(dstOprStore, valOprStore).predicate(pred); + st(dstOprStore, valOprStore).predicate(mask); ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy); barrier(); @@ -483,10 +483,11 @@ struct AtomicRMWOpConversion maskElements = getTypeConverter()->unpackLLElements( loc, llMask, rewriter, op.getMask().getType()); - auto tensorTy = op.getResult().getType().dyn_cast(); + auto valueTy = op.getResult().getType(); + auto tensorTy = valueTy.dyn_cast(); Type valueElemTy = tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType()) - : op.getResult().getType(); + : valueTy; const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); // vec = 1, numElements = 1 for scalar @@ -499,10 +500,7 @@ struct AtomicRMWOpConversion // mask numElems = tensorTy.getNumElements(); } - Value mask = int_val(1, 1); - auto tid = tid_val(); - mask = and_(mask, - icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); + Value mask = getMask(valueTy, rewriter, loc); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -582,7 +580,6 @@ struct AtomicRMWOpConversion memfenc(); auto ASMReturnTy = void_ty(ctx); ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy); - rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0))); atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy); Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index cdfa731a7fb3..871369e41cfb 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -457,7 +457,7 @@ class ConvertTritonGPUOpToLLVMPatternBase { } else { // If the tensor is not ranked, then it is a scalar and only thread 0 can // write - mask = and_(mask, icmp_slt(tid, i32_val(1))); + mask = and_(mask, icmp_eq(tid, i32_val(0))); } return mask; } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 342310eceed5..8d5db3141bbf 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1009,7 +1009,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: atomic_add_f32 tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) { - // CHECK: llvm.icmp "slt" // CHECK: llvm.inline_asm // CHECK-SAME: @$3 atom.global.gpu.add.f32 // CHECK: llvm.inline_asm @@ -1026,6 +1025,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr, %arg1 : i1, %arg2 : f32) { // CHECK: llvm.icmp "eq" // CHECK: llvm.inline_asm + // CHECK: llvm.inline_asm // CHECK-SAME: @$3 atom.global.gpu.add.f32 %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr, f32, i1) -> f32 tt.return @@ -1052,7 +1052,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: store_f32_scalar tt.func @store_f32_scalar(%arg0 : !tt.ptr, %arg1 : f32) { - // CHECK: llvm.icmp "slt" + // CHECK: llvm.icmp "eq" // CHECK: llvm.inline_asm // CHECK-SAME: @$2 st.global.b32 tt.store %arg0, %arg1 : f32