Skip to content

Commit

Permalink
Rebase after PR #4369
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseros committed Sep 5, 2024
1 parent b6c3d0a commit b78315e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 21 deletions.
10 changes: 4 additions & 6 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using namespace mlir::triton::gpu;

using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::getSharedMemoryBase;
using ::mlir::LLVM::AMD::llLoad;
using ::mlir::LLVM::AMD::llStore;
using ::mlir::triton::gpu::getTotalElemsPerThread;

namespace {
Expand Down Expand Up @@ -311,18 +313,14 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
const size_t wordNElems = width / valueElemNBits;
assert(wordNElems * nWords * numVecs == elemsPerThread);

// TODO(Superjomn) Add cache policy fields to StoreOp.
// TODO(Superjomn) Deal with cache policy here.

Type valArgTy = IntegerType::get(ctx, width);
auto wordTy = vec_ty(valueElemTy, wordNElems);

SmallVector<std::pair<Value, std::string>> asmArgs;
Value elem = valueElems[vecStart];
Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]);


// // Create the store val
// Create the store val
Value storeVal = undef(vecTy);
for (size_t s = 0; s < vec; ++s) {
Value otherElem = valueElems[vecStart + s];
Expand All @@ -331,7 +329,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
storeVal = insert_element(vecTy, storeVal, otherElem, indexVal);
}
llStore(rewriter, loc, ptr, storeVal, pred, cacheMod);
}
} // end vec
rewriter.eraseOp(op);
return success();
}
Expand Down
26 changes: 11 additions & 15 deletions third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@ std::string mangleFunc(std::string name, Type type) {
}

// Create a constant vector mask of length `vec` with the same `pred` value
Value createVectorMaskFromPredicate(RewriterBase& rewriter, Location loc,
Value createVectorMaskFromPredicate(RewriterBase &rewriter, Location loc,
Value pred, int64_t vec) {
auto vecMaskTy = LLVM::getFixedVectorType(rewriter.getI1Type(), vec);
Value maskVal = undef(vecMaskTy);
for (size_t s = 0; s < vec; ++s) {
Value indexVal = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64IntegerAttr(s));
Value indexVal =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64IntegerAttr(s));
maskVal = insert_element(vecMaskTy, maskVal, pred, indexVal);
}
return maskVal;
Expand Down Expand Up @@ -173,15 +174,12 @@ Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp,
Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,
Value pred, Value falseVal, triton::CacheModifier cm) {

// If we are not volatile, then use llvm.masked.load
if (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE){
if (cm == triton::CacheModifier::CG || cm == triton::CacheModifier::NONE) {
int64_t vec = cast<VectorType>(elemTy).getNumElements();
Value maskVal = createVectorMaskFromPredicate(
rewriter, loc, pred, vec);
bool nt = (cm ==triton::CacheModifier::CG);
return rewriter.create<LLVM::MaskedLoadOp>(
loc, elemTy, ptr, maskVal, falseVal, vec, nt);

Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vec);
bool nt = (cm == triton::CacheModifier::CG);
return rewriter.create<LLVM::MaskedLoadOp>(loc, elemTy, ptr, maskVal,
falseVal, vec, nt);
}

Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal}));
Expand Down Expand Up @@ -212,12 +210,10 @@ Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy,

void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val,
Value pred, triton::CacheModifier cm) {
Type elemTy = ptr.getType();
// If we use no cache modifier, simply issue a masked store
if (cm == triton::CacheModifier::NONE){
if (cm == triton::CacheModifier::NONE) {
Type elemTy = val.getType();
int64_t vec = cast<VectorType>(elemTy).getNumElements();
Value maskVal = createVectorMaskFromPredicate(
rewriter, loc, pred, vec);
Value maskVal = createVectorMaskFromPredicate(rewriter, loc, pred, vec);
rewriter.create<LLVM::MaskedStoreOp>(loc, val, ptr, maskVal, vec);
}
auto ctx = ptr.getContext();
Expand Down

0 comments on commit b78315e

Please sign in to comment.