From fce53a9a2a661995de0473ef634bb7e63d49025d Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Mon, 26 Aug 2024 14:56:04 -0700 Subject: [PATCH] [BACKEND] Optimize code generation for load with other arg When other is there we should use it to initialize the reg before doing the load instead of initializing the reg with 0. Note that this does add a scoreboard dependency between the other def and the load but the user can remove it by using a select if other comes from a high-latency op. --- test/Conversion/tritongpu_to_llvm.mlir | 2 + .../LoadStoreOpToLLVM.cpp | 62 ++++++++++--------- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index eed4269fa5e2..dce45889d0c9 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -17,6 +17,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: basic_load tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm + // CHECK-SAME: mov.u32 $0, $1; + // CHECK-SAME: @$3 ld.global.b32 { $0 }, [ $2 + 0 ];", "=r,r,l,b" // CHECK: llvm.inline_asm %1 = tt.load %a_ptr_init, %cst, %cst_0 : tensor<256x!tt.ptr, #blocked0> tt.return diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 9ee532992d01..25a176dd1418 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -237,40 +237,14 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, // prepare asm operands auto *dstsOpr = ptxBuilder.newListOperand(); + // If there is a `other` value, use it to init. 
+ bool init = other == nullptr; for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { auto *opr = ptxBuilder.newOperand(writeConstraint, - /*init=*/true); // =r operations + init); // =r operations dstsOpr->listAppend(opr); } - auto *addrOpr = - ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); - - // Define the instruction opcode - auto &ld = ptxBuilder.create<>("ld") - ->o("volatile", op.getIsVolatile()) - .global() - .o("ca", op.getCache() == triton::CacheModifier::CA) - .o("cg", op.getCache() == triton::CacheModifier::CG) - .o("L1::evict_first", - op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) - .o("L1::evict_last", - op.getEvict() == triton::EvictionPolicy::EVICT_LAST) - .o("L1::cache_hint", hasL2EvictPolicy) - .v(nWords) - .b(width); - - PTXBuilder::Operand *evictOpr{}; - - // Here lack a mlir::Value to bind to this operation, so disabled. - // if (has_l2_evict_policy) - // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); - - if (!evictOpr) - ld(dstsOpr, addrOpr).predicate(pred, "b"); - else - ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); - if (other) { for (size_t ii = 0; ii < nWords; ++ii) { // PTX doesn't support mov.u8, so we need to use mov.u16 @@ -298,10 +272,38 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, } else opr = ptxBuilder.newOperand(v, readConstraint); - mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b"); + mov(dstsOpr->listGet(ii), opr); } } + auto *addrOpr = + ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); + + // Define the instruction opcode + auto &ld = ptxBuilder.create<>("ld") + ->o("volatile", op.getIsVolatile()) + .global() + .o("ca", op.getCache() == triton::CacheModifier::CA) + .o("cg", op.getCache() == triton::CacheModifier::CG) + .o("L1::evict_first", + op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) + .o("L1::evict_last", + op.getEvict() == triton::EvictionPolicy::EVICT_LAST) + .o("L1::cache_hint", hasL2EvictPolicy) + .v(nWords) + .b(width); + + PTXBuilder::Operand 
*evictOpr{}; + + // Here lack a mlir::Value to bind to this operation, so disabled. + // if (has_l2_evict_policy) + // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); + + if (!evictOpr) + ld(dstsOpr, addrOpr).predicate(pred, "b"); + else + ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); + // Create inline ASM signature SmallVector retTys(nWords, IntegerType::get(getContext(), width)); Type retTy = retTys.size() > 1