Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 23 additions & 33 deletions lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
#include "ReduceScanCommon.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <vector>

using namespace mlir;
using namespace mlir::triton;
Expand Down Expand Up @@ -80,36 +77,16 @@ struct ReduceOpConversion
private:
const TargetInfoBase &targetInfo;

void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp,
SmallVector<Value> &acc, ValueRange cur, bool isFirst) const {
if (isFirst) {
acc = SmallVector<Value>(cur.begin(), cur.end());
return;
void accumulate(Location loc, ConversionPatternRewriter &rewriter,
Region &combineOp, SmallVector<Value> &acc, ValueRange cur,
Value pred = {}) const {
auto results = applyCombineOp(loc, rewriter, combineOp, acc, cur, pred);
if (acc.size() < results.size()) {
acc.resize(results.size());
}

// Create a new copy of the reduce block, and inline it
Block *currentBlock = rewriter.getBlock();
Region &parent = *currentBlock->getParent();
rewriter.cloneRegionBefore(combineOp, &parent.front());
auto &newReduce = parent.front();
auto returnOp = dyn_cast<triton::ReduceReturnOp>(newReduce.getTerminator());

llvm::SmallVector<Value> combineArgs(2 * acc.size());
for (unsigned i = 0; i < acc.size(); ++i) {
combineArgs[i] = acc[i];
combineArgs[acc.size() + i] = cur[i];
}

rewriter.inlineBlockBefore(&newReduce, &*rewriter.getInsertionPoint(),
combineArgs);

auto results = returnOp.getResult();
for (unsigned i = 0; i < acc.size(); ++i) {
acc[i] = results[i];
}

// Delete the terminator, which is no longer used
rewriter.eraseOp(returnOp);
}

SmallVector<SmallVector<Value>>
Expand Down Expand Up @@ -165,7 +142,7 @@ struct ReduceOpConversion
SmallVector<unsigned> key = offsets[i];
key[op.getAxis()] = 0;
bool isFirst = accs.find(key) == accs.end();
accumulate(rewriter, *combineOp, accs[key], srcValues[i], isFirst);
accumulate(op.getLoc(), rewriter, *combineOp, accs[key], srcValues[i]);
if (isFirst)
indices[key] = srcIndices[i];
}
Expand All @@ -175,17 +152,29 @@ struct ReduceOpConversion
// region and the accumulator values as source.
void warpReduce(ConversionPatternRewriter &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce, unsigned interleave) const {
unsigned numLaneToReduce, unsigned interleave,
Value pred = {}) const {
auto success = targetInfo.warpReduce(rewriter, loc, acc, op,
numLaneToReduce, interleave);
if (success)
return;

auto mod = op->getParentOfType<ModuleOp>();
unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
if (iWarpSize > numLaneToReduce) {
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(iWarpSize);
Value laneId = urem(threadId, warpSize);
Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
pred = pred ? and_(pred, lanePred) : lanePred;
}

for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
SmallVector<Value> shfl(acc.size());
for (unsigned i = 0; i < acc.size(); ++i) {
shfl[i] = targetInfo.shuffleXor(rewriter, loc, acc[i], N * interleave);
}
accumulate(rewriter, op.getCombineOp(), acc, shfl, false);
accumulate(op.getLoc(), rewriter, op.getCombineOp(), acc, shfl, pred);
}
}

Expand Down Expand Up @@ -344,7 +333,8 @@ struct ReduceOpConversion
acc[i] = targetInfo.loadShared(rewriter, loc, readPtr, elemTy,
threadIsNeeded);
}
warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */);
warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */,
threadIsNeeded);
// only the first thread in each sizeInterWarps is writing
Value writeOffset = readOffset;
SmallVector<Value> writePtrs(op.getNumOperands());
Expand Down
94 changes: 89 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/ReduceScanCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
// TODO: refactor so that it doesn't fail if Allocation.h
// is included after utility.h (due to conflict in `store` macro
// and <atomic>
#include "triton/Analysis/Allocation.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/DialectConversion.h"

#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
//
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include <set>
#include <iterator>
#include <type_traits>

#define DEBUG_TYPE "ttgpu_to_llvm"
Expand All @@ -32,6 +31,91 @@ namespace ttng = ::mlir::triton::nvidia_gpu;
namespace mlir::triton {
class ReduceOp;
class ScanOp;

// Inline `combineBlock` at `insertionPoint` inside `insertionBlock`, mapping
// its block arguments to `combineArgs`, and return the values yielded by its
// terminator. The terminator itself is erased once its operands are captured.
inline SmallVector<Value>
inlineCombineBlock(ConversionPatternRewriter &rewriter, Block &combineBlock,
                   Block *insertionBlock, Block::iterator insertionPoint,
                   ValueRange combineArgs) {
  Operation *terminator = combineBlock.getTerminator();
  rewriter.inlineBlockBefore(&combineBlock, insertionBlock, insertionPoint,
                             combineArgs);

  // Capture the yielded values *after* inlining so they refer to the
  // remapped values rather than the original block arguments.
  SmallVector<Value> yielded(terminator->getOperands());

  // The terminator is dead now that its operands have been captured.
  rewriter.eraseOp(terminator);
  return yielded;
}

// Apply the reduce/scan combine region to `acc` and `cur` and return the
// combined values. If `pred` is provided and the region is not speculatable,
// the combination is guarded so it only executes when `pred` is true; when
// `pred` is false the (undefined) pass-through values are returned instead,
// so callers must not rely on the result for predicated-off lanes.
inline SmallVector<Value> applyCombineOp(Location loc,
                                         ConversionPatternRewriter &rewriter,
                                         Region &combineOp, ValueRange acc,
                                         ValueRange cur, Value pred = {}) {
  // Allows for passing an uninitialized acc and using cur as the neutral
  // element: an empty accumulator means this is the first element.
  if (acc.size() == 0) {
    return cur;
  }
  assert(cur.size() == acc.size());

  // Create a new copy of the combine block, and try to speculatively inline it
  // right after the current block.
  Block *currentBlock = rewriter.getBlock();
  Region &parent = *currentBlock->getParent();

  rewriter.cloneRegionBefore(combineOp, parent,
                             std::next(currentBlock->getIterator()));
  Block &newCombine = *currentBlock->getNextNode();

  // The combine block takes the accumulator values first, then the incoming
  // values: combine(acc..., cur...).
  llvm::SmallVector<Value> combineArgs(2 * acc.size());
  for (unsigned i = 0; i < acc.size(); ++i) {
    combineArgs[i] = acc[i];
    combineArgs[acc.size() + i] = cur[i];
  }

  // The region may be executed unconditionally only if every op in it is
  // speculatable (no side effects that must be guarded by `pred`).
  auto isRegionSpeculatable =
      std::all_of(newCombine.begin(), newCombine.end(),
                  [](auto &op) { return isSpeculatable(&op); });

  if (!pred || isRegionSpeculatable) {
    // Fast path, region has no side effects so we can unconditionally execute
    // it inline; predicated-off lanes simply compute an unused value.
    return inlineCombineBlock(rewriter, newCombine, currentBlock,
                              rewriter.getInsertionPoint(), combineArgs);
  }

  // Slow case, create an if to only execute region when pred is true
  // #currentBlock
  // if (pred) {
  //   #newCombine
  //   results = combineOp(cur, acc)
  //   yield results
  // } else {
  //    yield undef
  // }
  // #thenBlock
  Block *thenBlock =
      rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());

  // NOTE: here the return operands are still the combine block's own values;
  // they are forwarded to `thenBlock` via the branch created below.
  auto returnOp = newCombine.getTerminator();
  auto results = SmallVector<Value>(returnOp->getOperands());

  rewriter.setInsertionPointToEnd(currentBlock);
  // Undef placeholders feed `thenBlock` when `pred` is false; one block
  // argument per combine result carries the merged value.
  SmallVector<Value> thenBlockArgs;
  thenBlockArgs.reserve(results.size());
  for (auto result : results) {
    auto ty = result.getType();
    auto undef = rewriter.create<LLVM::UndefOp>(loc, ty);
    thenBlockArgs.push_back(undef);
    thenBlock->addArgument(ty, loc);
  }
  rewriter.create<cf::CondBranchOp>(loc, pred, &newCombine, combineArgs,
                                    thenBlock, thenBlockArgs);

  // Split a block after the call.
  rewriter.setInsertionPointToEnd(&newCombine);
  rewriter.replaceOpWithNewOp<cf::BranchOp>(returnOp, thenBlock, results);
  rewriter.setInsertionPointToStart(thenBlock);
  return SmallVector<Value>(thenBlock->getArguments());
}

} // namespace mlir::triton

template <typename SourceOp>
Expand Down
90 changes: 32 additions & 58 deletions lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#include <iterator>

#include "ReduceScanCommon.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Utility.h"
Expand All @@ -16,37 +14,13 @@ using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::getTotalElemsPerThread;

// apply combine region to acc and cur and accumulate it into acc
// TODO(Lezcano) This is now duplicated with ReduceOpConversion::reduce.
// Deduplicate
static SmallVector<Value> accumulate(ConversionPatternRewriter &rewriter,
Region &combineOp, ValueRange acc,
ValueRange cur) {
// Allows for passing an unitialized acc and use cur as the neutral element
if (acc.size() == 0) {
return cur;
}
assert(cur.size() == acc.size());
// Create a new copy of the reduce block, and inline it
Block *currentBlock = rewriter.getBlock();
Region &parent = *currentBlock->getParent();
rewriter.cloneRegionBefore(combineOp, &parent.front());
auto &newScan = parent.front();
auto returnOp = dyn_cast<triton::ScanReturnOp>(newScan.getTerminator());

SmallVector<Value> combineArgs(2 * acc.size());
for (unsigned i = 0; i < acc.size(); ++i) {
combineArgs[i] = acc[i];
combineArgs[acc.size() + i] = cur[i];
}

rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(),
combineArgs);
SmallVector<Value> results;
llvm::transform(returnOp.getResult(), std::back_inserter(results),
[&](Value res) { return rewriter.getRemappedValue(res); });
// Delete the terminator, which is no longer used
rewriter.eraseOp(returnOp);
return results;
// Apply the scan's combine region to `acc` and `cur`, optionally guarded by
// `pred`, and return the combined values. Thin forwarding wrapper around
// applyCombineOp using the location and region held by `helper`.
static SmallVector<Value> accumulate(ScanLoweringHelper &helper,
                                     ConversionPatternRewriter &rewriter,
                                     ValueRange acc, ValueRange cur,
                                     Value pred = {}) {
  return applyCombineOp(helper.getLoc(), rewriter, helper.getCombineOp(), acc,
                        cur, pred);
}

// Scan a contiguous elements within a thread and update `srcValues` in place.
Expand All @@ -66,8 +40,8 @@ scanThreadContiguousElements(SmallVector<SmallVector<Value>> &srcValues,
unsigned accIndex = (srcIndex % stride) +
((srcIndex / stride) / scanElementsPerThreads) * stride;

accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex],
srcValues[srcIndex]);
accs[accIndex] =
accumulate(helper, rewriter, accs[accIndex], srcValues[srcIndex]);
srcValues[srcIndex] = accs[accIndex];
}
}
Expand Down Expand Up @@ -95,11 +69,11 @@ static void warpScan(SmallVector<SmallVector<Value>> &srcValues,
for (unsigned j = 0; j < acc.size(); ++j) {
shfl[j] = targetInfo.shuffleUp(rewriter, loc, acc[j], i * threadStride);
}
Value mask = icmp_sge(laneIdAxis, i32_val(i));
SmallVector<Value> tempAcc =
accumulate(rewriter, helper.getCombineOp(), shfl, acc);
Value mask = icmp_slt(laneIdAxis, i32_val(i));
accumulate(helper, rewriter, shfl, acc, mask);
for (unsigned j = 0; j < acc.size(); ++j) {
acc[j] = select(mask, acc[j], tempAcc[j]);
acc[j] = select(mask, tempAcc[j], acc[j]);
}
}
srcValues[srcIndex] = acc;
Expand Down Expand Up @@ -164,9 +138,9 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
unsigned elementStride = helper.getAxisElementStride();
unsigned threadStride = helper.getAxisThreadStride();
unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0));
Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);
Value maskNotFirstWarp = icmp_ne(warpId, i32_val(0));
Value maskNotFirstLane = icmp_ne(laneIdAxis, i32_val(0));
Value maskNotFirstThread = or_(maskNotFirstWarp, maskNotFirstLane);
struct Accumulator {
SmallVector<Value> acc;
SmallVector<Value> maskedAcc;
Expand Down Expand Up @@ -212,42 +186,43 @@ static void AddPartialReduce(SmallVector<SmallVector<Value>> &srcValues,
accumulator.maskedAcc = partialReduce;
continue;
}
accumulator.acc = accumulate(rewriter, helper.getCombineOp(),
accumulator.acc, partialReduce);
Value mask = icmp_slt(warpId, i32_val(i + 1));
Value mask = icmp_sge(warpId, i32_val(i + 1));
accumulator.acc =
accumulate(helper, rewriter, accumulator.acc, partialReduce, mask);
for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
accumulator.maskedAcc[j] =
select(mask, accumulator.maskedAcc[j], accumulator.acc[j]);
select(mask, accumulator.acc[j], accumulator.maskedAcc[j]);
}
}
auto temp = accumulate(rewriter, helper.getCombineOp(),
accumulator.maskedAcc, srcValues[srcIndex]);

Value pred = axisBlockId == 0 ? maskNotFirstWarp : Value{};
auto temp = accumulate(helper, rewriter, accumulator.maskedAcc,
srcValues[srcIndex], pred);
if (axisBlockId == 0) {
// For the first warp and first chunk we don't have anything to
// accumulate.
auto val = srcValues[srcIndex];
for (unsigned i = 0; i < helper.getNumOperands(); ++i) {
temp[i] = select(maskFirstWarp, val[i], temp[i]);
temp[i] = select(maskNotFirstWarp, temp[i], val[i]);
}
}
srcValues[srcIndex] = temp;
// Update the rest of the contiguous elements.
SmallVector<Value> lastElement(helper.getNumOperands());
for (unsigned i = 0; i < helper.getNumOperands(); ++i) {
auto elem = targetInfo.shuffleUp(rewriter, loc, temp[i], threadStride);
lastElement[i] = select(maskFirstLane, accumulator.maskedAcc[i], elem);
lastElement[i] = select(maskNotFirstLane, elem, accumulator.maskedAcc[i]);
}
for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
pred = axisBlockId == 0 ? maskNotFirstThread : Value{};
auto laneValue = srcValues[srcIndex - i * elementStride];
laneValue =
accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue);
laneValue = accumulate(helper, rewriter, lastElement, laneValue, pred);
if (axisBlockId == 0) {
// For the first warp and first chunk we don't have anything to
// accumulate.
for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
laneValue[j] =
select(maskFirstThread,
srcValues[srcIndex - i * elementStride][j], laneValue[j]);
laneValue[j] = select(maskNotFirstThread, laneValue[j],
srcValues[srcIndex - i * elementStride][j]);
}
}
srcValues[srcIndex - i * elementStride] = laneValue;
Expand Down Expand Up @@ -300,8 +275,8 @@ static void AddPartialReduceOneWarp(SmallVector<SmallVector<Value>> &srcValues,
if (axisBlockId == 0) // First chunk and first block
accumulator = srcValues[srcIndex];
else
srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(),
accumulator, srcValues[srcIndex]);
srcValues[srcIndex] =
accumulate(helper, rewriter, accumulator, srcValues[srcIndex]);
// Update the rest of the contiguous elements.
auto lastElement = srcValues[srcIndex];
if (scanDim > 1) {
Expand All @@ -319,8 +294,7 @@ static void AddPartialReduceOneWarp(SmallVector<SmallVector<Value>> &srcValues,
}
for (unsigned i = 1; i < scanElementsPerThreads; ++i) {
auto laneValue = srcValues[srcIndex - i * elementStride];
laneValue =
accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue);
laneValue = accumulate(helper, rewriter, lastElement, laneValue);
if (axisBlockId == 0) {
for (unsigned j = 0; j < helper.getNumOperands(); ++j) {
// For the first warp and first chunk we don't have anything to
Expand Down
Loading