Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/GatherOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void GatherOpConversion::emitGatherInShared(
Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA);
// Emit the offset into the shared memory and then store the value.
Value ptr = b.gep(smemBase.getType(), elemType, smemBase, offset);
b.store(value, ptr);
targetInfo.storeShared(rewriter, loc, ptr, value, b.true_val());
}

// Synchronize the whole CTA.
Expand Down Expand Up @@ -127,7 +127,8 @@ void GatherOpConversion::emitGatherInShared(
indices[axis] = convertIndexToI32(loc, idx, rewriter);
Value offset = LLVM::linearize(rewriter, loc, indices, srcShapePerCTA);
Value ptr = b.gep(smemBase.getType(), elemType, smemBase, offset);
results[i] = b.load(elemType, ptr);
results[i] =
targetInfo.loadShared(rewriter, loc, ptr, elemType, b.true_val());
}

Value packed =
Expand Down
19 changes: 10 additions & 9 deletions lib/Conversion/TritonGPUToLLVM/HistogramOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@ static void atomicAddOne(Value ptr, Location loc,
b.i32_val(1), LLVM::AtomicOrdering::monotonic);
}

static SmallVector<Value>
computeHistogram(Location loc, ConversionPatternRewriter &rewriter,
Value baseSharedMemPtr, const SmallVector<Value> &srcValues,
const SmallVector<Value> &maskValues, int numBins,
int numThreadPerWarp, const SmallVector<Value> &indices,
Value threadId, int numWarps) {
static SmallVector<Value> computeHistogram(
Location loc, ConversionPatternRewriter &rewriter, Value baseSharedMemPtr,
const SmallVector<Value> &srcValues, const SmallVector<Value> &maskValues,
int numBins, int numThreadPerWarp, const SmallVector<Value> &indices,
Value threadId, int numWarps, const TargetInfoBase &targetInfo) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
SmallVector<Value> histogramValues;
// Initialize the shared memory with zeros.
Expand All @@ -31,7 +30,8 @@ computeHistogram(Location loc, ConversionPatternRewriter &rewriter,
offset = b.urem(offset, b.i32_val(numBins));
Value sharedMemPtr =
b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, offset);
b.store(b.i32_val(0), sharedMemPtr);
targetInfo.storeShared(rewriter, loc, sharedMemPtr, b.i32_val(0),
b.true_val());
}
b.barrier(triton::gpu::AddrSpace::Local);

Expand All @@ -57,7 +57,8 @@ computeHistogram(Location loc, ConversionPatternRewriter &rewriter,
for (Value index : indices) {
Value sharedMemPtr =
b.gep(baseSharedMemPtr.getType(), i32_ty, baseSharedMemPtr, index);
Value val = b.load(i32_ty, sharedMemPtr);
Value val = targetInfo.loadShared(rewriter, loc, sharedMemPtr, i32_ty,
b.true_val());
histogramValues.push_back(val);
}
return histogramValues;
Expand Down Expand Up @@ -111,7 +112,7 @@ struct HistogramOpConversion
innerDimIndices.push_back(indices[i][0]);
SmallVector<Value> histogramValue = computeHistogram(
loc, rewriter, baseSharedMemPtr, srcValues, maskValues, numBins,
numThreadsPerWarp, innerDimIndices, threadId, numWarps);
numThreadsPerWarp, innerDimIndices, threadId, numWarps, targetInfo);

// Depending on the layout, some threads may have duplicate data. We can
// account for this by calculating a "replication factor" and dividing the
Expand Down
22 changes: 11 additions & 11 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ using namespace mlir::triton::gpu;
// Helper for LocalGather/ScatterOpConversion.
// For gather: storeVals is empty, returns loaded values.
// For scatter: storeVals contains values to store, returns empty.
SmallVector<Value> lowerLocalScGt(Location loc, MLIRContext *ctx,
MemDescType memDescTy,
SharedMemoryObject smemObj, Type llvmElemTy,
ArrayRef<Value> idxValues,
ArrayRef<SmallVector<Value>> coords,
unsigned axis, ArrayRef<Value> storeVals,
RewriterBase &rewriter) {
SmallVector<Value>
lowerLocalScGt(Location loc, MLIRContext *ctx, MemDescType memDescTy,
SharedMemoryObject smemObj, Type llvmElemTy,
ArrayRef<Value> idxValues, ArrayRef<SmallVector<Value>> coords,
unsigned axis, ArrayRef<Value> storeVals, RewriterBase &rewriter,
const TargetInfoBase &targetInfo) {
auto b = TritonLLVMOpBuilder(loc, rewriter);
bool isScatter = !storeVals.empty();

Expand Down Expand Up @@ -110,9 +109,10 @@ SmallVector<Value> lowerLocalScGt(Location loc, MLIRContext *ctx,
}

if (isScatter) {
b.store(storeVals[i], ptr);
targetInfo.storeShared(rewriter, loc, ptr, storeVals[i], b.true_val());
} else {
results[i] = b.load(llvmElemTy, ptr);
results[i] =
targetInfo.loadShared(rewriter, loc, ptr, llvmElemTy, b.true_val());
}
}

Expand Down Expand Up @@ -375,7 +375,7 @@ struct LocalGatherOpConversion : public ConvertOpToLLVMPattern<LocalGatherOp> {

auto results = lowerLocalScGt(loc, ctx, memDescTy, smemObj, llvmElemTy,
idxValues, dstIndices, op.getAxis(),
/*storeVals=*/{}, rewriter);
/*storeVals=*/{}, rewriter, targetInfo);

Value result = packLLElements(loc, typeConverter, results, rewriter, regTy);
rewriter.replaceOp(op, result);
Expand Down Expand Up @@ -425,7 +425,7 @@ struct LocalScatterOpConversion
/*withCTAOffset=*/true);

lowerLocalScGt(loc, ctx, memDescTy, smemObj, llvmElemTy, idxValues,
srcIndices, op.getAxis(), values, rewriter);
srcIndices, op.getAxis(), values, rewriter, targetInfo);

rewriter.eraseOp(op);
return success();
Expand Down
3 changes: 2 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
auto elemTy = smemTypes[j];
Value ptr = b.gep(smemBases[j].getType(), elemTy, smemBases[j], index);
partialReduce[j] = b.load(elemTy, ptr);
partialReduce[j] =
targetInfo.loadShared(rewriter, loc, ptr, elemTy, b.true_val());
}

if (accumulator.acc.size() == 0) {
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ lowerLdSt(Location loc, MLIRContext *ctx, LinearLayout cvt,
auto vals = to_vector(valsArray);
bool isStore = !vals.empty();
auto b = TritonLLVMOpBuilder(loc, rewriter);
auto smemPtrTy = ptr_ty(ctx, 3);
auto smemPtrTy = ptr_ty(ctx, targetInfo.getSharedAddressSpace());
auto kReg = str_attr("register");
auto kLane = str_attr("lane");
auto kWarp = str_attr("warp");
Expand Down
5 changes: 3 additions & 2 deletions lib/Conversion/TritonGPUToLLVM/WarpSpecializeUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,8 @@ static void rewritePartitionRegions(WarpSpecializeOp ws, Block *switchLoop,
/*isPacked=*/true);
Value capturePtr =
LLVM::getSharedMemoryBase(b.getLoc(), b, targetInfo, ws);
LLVM::LLVMPointerType ptrTy = ptr_ty(b.getContext(), 3);
LLVM::LLVMPointerType ptrTy =
ptr_ty(b.getContext(), targetInfo.getSharedAddressSpace());
for (auto [i, arg] :
llvm::zip(llvm::seq<int32_t>(partition->getNumArguments()),
partition->getArguments())) {
Expand Down Expand Up @@ -403,7 +404,7 @@ LogicalResult mlir::triton::lowerWarpSpecializeCommon(

TritonLLVMIRRewriter b(func.getLoc(), ctx);
Type int8Type = b.getIntegerType(8);
LLVM::LLVMPointerType ptrTy = ptr_ty(ctx, 3);
LLVM::LLVMPointerType ptrTy = ptr_ty(ctx, targetInfo.getSharedAddressSpace());

b.setInsertionPointToStart(switchLoop);
callbacks.reallocRegisters(b, wsOps[0], RegisterReallocPhase::SwitchLoopStart,
Expand Down
30 changes: 10 additions & 20 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2297,21 +2297,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

tt.func @gather_in_shared(%arg0: tensor<16x4xi32, #blocked1>, %arg1: tensor<8x4xf32, #blocked>) {
// CHECK-LABEL: gather_in_shared

// CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0]

// CHECK: [[SMEM_BASE:%.*]] = llvm.mlir.addressof @global_smem
// CHECK-NEXT: [[SMEM:%.*]] = llvm.getelementptr [[SMEM_BASE]]
// CHECK: store [[S0]]
// CHECK: store
Comment thread
neildhar marked this conversation as resolved.
// CHECK-NEXT: nvvm.barrier0

// CHECK: [[I0:%.*]] = llvm.extractvalue %arg0[0]

// CHECK: [[IDX:%.*]] = llvm.add {{.*}}, [[I0]]
// CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM]][[[IDX]]]
// CHECK-NEXT: [[OUT0:%.*]] = llvm.load [[PTR]]

// CHECK: insertvalue [[OUT0]], {{.*}}[0]
// CHECK: llvm.load [[PTR]]
// CHECK: llvm.load
// CHECK-NOT: llvm.load
// CHECK: return

%0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #blocked>, tensor<16x4xi32, #blocked1>) -> tensor<16x4xf32, #blocked1>
tt.return
Expand All @@ -2329,27 +2327,19 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

tt.func @gather_in_shared_dot_input(%arg0: tensor<16x4xi32, #blocked>, %arg1: tensor<8x4xf32, #dot>) {
// CHECK-LABEL: gather_in_shared_dot_input

// CHECK: [[S0:%.*]] = llvm.extractvalue %arg1[0]
// CHECK: [[S1:%.*]] = llvm.extractvalue %arg1[1]
// CHECK: [[S2:%.*]] = llvm.extractvalue %arg1[2]
// CHECK: [[S3:%.*]] = llvm.extractvalue %arg1[3]

// CHECK: [[SMEM_BASE:%.*]] = llvm.mlir.addressof @global_smem
// CHECK-NEXT: [[SMEM:%.*]] = llvm.getelementptr [[SMEM_BASE]]
// CHECK: store [[S0]]
// CHECK: store [[S1]]
// CHECK: store [[S2]]
// CHECK: store [[S3]]
// CHECK-COUNT-4: store
// CHECK-NEXT: nvvm.barrier0

// CHECK: [[I0:%.*]] = llvm.extractvalue %arg0[0]

// CHECK: [[IDX:%.*]] = llvm.add {{.*}}, [[I0]]
// CHECK-NEXT: [[PTR:%.*]] = llvm.getelementptr [[SMEM]][[[IDX]]]
// CHECK-NEXT: [[OUT0:%.*]] = llvm.load [[PTR]]

// CHECK: insertvalue [[OUT0]], {{.*}}[0]
// CHECK: llvm.load [[PTR]]
// CHECK: llvm.load
// CHECK-NOT: llvm.load
// CHECK: return

%0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<8x4xf32, #dot>, tensor<16x4xi32, #blocked>) -> tensor<16x4xf32, #blocked>
tt.return
Expand Down
Loading