diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index de0eb140e2b7..43f0d6a95ba8 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -67,13 +67,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_f16 - tt.func @atomic_add_f16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + // CHECK-LABEL: atomic_add_f16x2 + tt.func @atomic_add_f16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> // CHECK: llvm.cond_br + // CHECK-NOT: rocdl.update.dpp // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> + // CHECK-NOT: rocdl.update.dpp %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> tt.return } @@ -83,13 +85,51 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : #blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { - // CHECK-LABEL: atomic_add_bf16 - tt.func @atomic_add_bf16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + // CHECK-LABEL: atomic_add_bf16x2 + tt.func @atomic_add_bf16x2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> // CHECK: llvm.cond_br + // CHECK-NOT: rocdl.update.dpp // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> + // CHECK-NOT: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> + tt.return + } +} + +// ----- + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f16_dpp + tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> + // CHECK: llvm.cond_br + // CHECK: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> + // CHECK: rocdl.update.dpp + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> + tt.return + } +} + +// ----- + +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_bf16_dpp + tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> + // CHECK: llvm.cond_br + // CHECK: rocdl.update.dpp + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> + // CHECK: rocdl.update.dpp %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> tt.return } diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 43a334d5cfa9..e1c96bee0cc2 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -714,6 +714,32 @@ struct AtomicCASOpConversion } }; +bool supportsGlobalAtomicF16PackedAndDpp(triton::AMD::ISAFamily isaFamily) { + return isaFamily == triton::AMD::ISAFamily::CDNA1 || + isaFamily == triton::AMD::ISAFamily::CDNA2 || + isaFamily == triton::AMD::ISAFamily::CDNA3; +} + +Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) { + assert(val.getType().isInteger(32)); + auto loc = val.getLoc(); + Value old = i32_val(0); + int rowMask = 0b1111; // enable all rows + int bankMask = 0b1111; // enable all banks + bool boundCtrl = false; + auto dppMovOp = rewriter.create( + loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl); + return dppMovOp.getResult(); +} + +Value shiftLeftI32ByDpp(PatternRewriter &rewriter, Value val) { + return generateI32DppMove(rewriter, val, 0x101); // shift left 1 lane +} + +Value shiftRightI32ByDpp(PatternRewriter &rewriter, Value val) { + return generateI32DppMove(rewriter, val, 0x111); // shift right 1 lane +} + struct AtomicRMWOpConversion : public ConvertOpToLLVMPattern, public LoadStoreConversionBase { @@ -785,14 +811,36 @@ struct AtomicRMWOpConversion // vec = 1, numElements = 1 for scalar auto vec = getVectorSize(ptr); int numElems = 1; + Type packF16Ty = vec_ty(valueElemTy, 2); + + // In the case of unpaired f16 elements utilize dpp instructions to + // accelerate atomics. Here is an algorithm of lowering + // tt::atomicRmwOp(%ptr, %val, %mask): + // 0. Group thread by pairs. Master thread is (tid % 2 == 0); + // 1. All the threads send %val to (tid - 1) thread via dppUpdateOp shl, so + // all the masters recieve value from secondary threads; + // 2. Take into account parity in the %mask value, build control flow + // structures according to it; + // 3. Generate llvm::atomicRmwOp in the threads enabled by %mask value; + // 4. All the threads send result of generated operation to (tid + 1) thread + // via dppUpdateOp shl, so all secondary thread also recieve their + // result. + // + // This approach enables us to use half the active threads committing atomic + // requests to avoid generating of code providing unified access to f16 + // element and reduce contantion. + bool useDppForPackedF16 = false; // tensor if (tensorTy) { auto valTy = cast(val.getType()); - Type elTy = valTy.getElementType(); - vec = std::min(vec, llvm::isa(elTy) && - elTy.getIntOrFloatBitWidth() == 16 - ? 2 - : 1); + bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16(); + unsigned availableVecSize = isF16Ty ? 2 : 1; + vec = std::min(vec, availableVecSize); + // Force F16 packing in the case it's not comming in as packed, but the + // ISA can support packed atomic instructions. + useDppForPackedF16 = + supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) && + vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD; // mask numElems = tensorTy.getNumElements(); } @@ -800,12 +848,15 @@ struct AtomicRMWOpConversion auto tid = tid_val(); mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); + if (useDppForPackedF16) + mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0))); auto memOrdering = op.getSem(); auto atomicMemOrdering = getMemoryOrdering(memOrdering); auto vecTy = vec_ty(valueElemTy, vec); auto retType = vec == 1 ? valueElemTy : vecTy; + retType = useDppForPackedF16 ? packF16Ty : retType; SmallVector resultVals(elemsPerThread); for (size_t i = 0; i < elemsPerThread; i += vec) { Value rmwPtr = ptrElements[i]; @@ -814,7 +865,24 @@ struct AtomicRMWOpConversion Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; Value operand; - if (vec == 1) { + if (useDppForPackedF16) { + // Move %val to left neighbour to proceed packed atomic further. + Value packedVal = null(packF16Ty); + packedVal = + insert_element(packF16Ty, packedVal, valElements[i], i32_val(0)); + // Pack to i32 type to simplify transaction + packedVal = bitcast(packedVal, i32_ty); + Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal); + // Unpack results back + Value unpackedDppRes = bitcast(dppMoveRes, packF16Ty); + operand = undef(packF16Ty); + operand = + insert_element(packF16Ty, operand, valElements[i], i32_val(0)); + operand = insert_element( + packF16Ty, operand, + extract_element(valueElemTy, unpackedDppRes, i32_val(0)), + i32_val(1)); + } else if (vec == 1) { operand = valElements[i]; } else { operand = undef(vecTy); @@ -856,10 +924,25 @@ struct AtomicRMWOpConversion rewriter.setInsertionPointToStart(endBlock); Value retVal = endBlock->getArgument(0); if (tensorTy) { - for (int ii = 0; ii < vec; ++ii) { - resultVals[i + ii] = - vec == 1 ? retVal - : extract_element(valueElemTy, retVal, i32_val(ii)); + if (useDppForPackedF16) { + // Return packed to i32 result after atomic operation back from master + // lane. + auto packedRet = bitcast(retVal, i32_ty); + Value dppMovRes = shiftRightI32ByDpp(rewriter, packedRet); + // Unpack results back + Value unpackedDppRes = bitcast(dppMovRes, packF16Ty); + retVal = insert_element( + packF16Ty, retVal, + extract_element(valueElemTy, unpackedDppRes, i32_val(1)), + i32_val(1)); + resultVals[i] = + extract_element(valueElemTy, retVal, urem(tid, i32_val(2))); + } else { + for (int ii = 0; ii < vec; ++ii) { + resultVals[i + ii] = + vec == 1 ? retVal + : extract_element(valueElemTy, retVal, i32_val(ii)); + } } } else { if (!atomicNeedsSharedMemory(op.getResult())) {