Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-LABEL: basic_load
tt.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: mov.u32 $0, $1;
// CHECK-SAME: @$3 ld.global.b32 { $0 }, [ $2 + 0 ];", "=r,r,l,b"
// CHECK: llvm.inline_asm
%1 = tt.load %a_ptr_init, %cst, %cst_0 : tensor<256x!tt.ptr<f32>, #blocked0>
tt.return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,40 +237,14 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,

// prepare asm operands
auto *dstsOpr = ptxBuilder.newListOperand();
// If there is a `other` value, use it to init.
bool init = other == nullptr;
for (size_t wordIdx = 0; wordIdx < nWords; ++wordIdx) {
auto *opr = ptxBuilder.newOperand(writeConstraint,
/*init=*/true); // =r operations
init); // =r operations
dstsOpr->listAppend(opr);
}

auto *addrOpr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);

// Define the instruction opcode
auto &ld = ptxBuilder.create<>("ld")
->o("volatile", op.getIsVolatile())
.global()
.o("ca", op.getCache() == triton::CacheModifier::CA)
.o("cg", op.getCache() == triton::CacheModifier::CG)
.o("L1::evict_first",
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last",
op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
.o("L1::cache_hint", hasL2EvictPolicy)
.v(nWords)
.b(width);

PTXBuilder::Operand *evictOpr{};

// Here lack a mlir::Value to bind to this operation, so disabled.
// if (has_l2_evict_policy)
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");

if (!evictOpr)
ld(dstsOpr, addrOpr).predicate(pred, "b");
else
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");

if (other) {
for (size_t ii = 0; ii < nWords; ++ii) {
// PTX doesn't support mov.u8, so we need to use mov.u16
Expand Down Expand Up @@ -298,10 +272,38 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
} else
opr = ptxBuilder.newOperand(v, readConstraint);

mov(dstsOpr->listGet(ii), opr).predicateNot(pred, "b");
mov(dstsOpr->listGet(ii), opr);
}
}

auto *addrOpr =
ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);

// Define the instruction opcode
auto &ld = ptxBuilder.create<>("ld")
->o("volatile", op.getIsVolatile())
.global()
.o("ca", op.getCache() == triton::CacheModifier::CA)
.o("cg", op.getCache() == triton::CacheModifier::CG)
.o("L1::evict_first",
op.getEvict() == triton::EvictionPolicy::EVICT_FIRST)
.o("L1::evict_last",
op.getEvict() == triton::EvictionPolicy::EVICT_LAST)
.o("L1::cache_hint", hasL2EvictPolicy)
.v(nWords)
.b(width);

PTXBuilder::Operand *evictOpr{};

// Here lack a mlir::Value to bind to this operation, so disabled.
// if (has_l2_evict_policy)
// evictOpr = ptxBuilder.newOperand(l2Evict, "l");

if (!evictOpr)
ld(dstsOpr, addrOpr).predicate(pred, "b");
else
ld(dstsOpr, addrOpr, evictOpr).predicate(pred, "b");

// Create inline ASM signature
SmallVector<Type> retTys(nWords, IntegerType::get(getContext(), width));
Type retTy = retTys.size() > 1
Expand Down