diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index eed4269fa5e2..dce45889d0c9 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -17,6 +17,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK-LABEL: basic_load tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm + // CHECK-SAME: mov.u32 $0, $1; + // CHECK-SAME: @$3 ld.global.b32 { $0 }, [ $2 + 0 ];", "=r,r,l,b" // CHECK: llvm.inline_asm %1 = tt.load %a_ptr_init, %cst, %cst_0 : tensor<256x!tt.ptr, #blocked0> tt.return diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 9ee532992d01..25a176dd1418 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -237,40 +237,14 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, // prepare asm operands auto *dstsOpr = ptxBuilder.newListOperand(); + // If there is a `other` value, use it to init. + bool init = other == nullptr; for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) { auto *opr = ptxBuilder.newOperand(writeConstraint, - /*init=*/true); // =r operations + init); // =r operations dstsOpr->listAppend(opr); } - auto *addrOpr = - ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); - - // Define the instruction opcode - auto &ld = ptxBuilder.create<>("ld") - ->o("volatile", op.getIsVolatile()) - .global() - .o("ca", op.getCache() == triton::CacheModifier::CA) - .o("cg", op.getCache() == triton::CacheModifier::CG) - .o("L1::evict_first", - op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) - .o("L1::evict_last", - op.getEvict() == triton::EvictionPolicy::EVICT_LAST) - .o("L1::cache_hint", hasL2EvictPolicy) - .v(nWords) - .b(width); - - PTXBuilder::Operand *evictOpr{}; - - // Here lack a mlir::Value to bind to this operation, so disabled. - // if (has_l2_evict_policy) - // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); - - if (!evictOpr) - ld(dstsOpr, addrOpr).predicate(pred, "b"); - else - ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); - if (other) { for (size_t ii = 0; ii < nWords; ++ii) { // PTX doesn't support mov.u8, so we need to use mov.u16 @@ -298,10 +272,38 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern, } else opr = ptxBuilder.newOperand(v, readConstraint); - mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b"); + mov(dstsOpr->listGet(ii), opr); } } + auto *addrOpr = + ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off); + + // Define the instruction opcode + auto &ld = ptxBuilder.create<>("ld") + ->o("volatile", op.getIsVolatile()) + .global() + .o("ca", op.getCache() == triton::CacheModifier::CA) + .o("cg", op.getCache() == triton::CacheModifier::CG) + .o("L1::evict_first", + op.getEvict() == triton::EvictionPolicy::EVICT_FIRST) + .o("L1::evict_last", + op.getEvict() == triton::EvictionPolicy::EVICT_LAST) + .o("L1::cache_hint", hasL2EvictPolicy) + .v(nWords) + .b(width); + + PTXBuilder::Operand *evictOpr{}; + + // Here lack a mlir::Value to bind to this operation, so disabled. + // if (has_l2_evict_policy) + // evictOpr = ptxBuilder.newOperand(l2Evict, "l"); + + if (!evictOpr) + ld(dstsOpr, addrOpr).predicate(pred, "b"); + else + ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b"); + // Create inline ASM signature SmallVector retTys(nWords, IntegerType::get(getContext(), width)); Type retTy = retTys.size() > 1