Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/triton/Dialect/TritonGPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,20 @@ def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init"
"mlir::triton::TritonDialect"];
}

def TritonGPUWSTaskPartition : Pass<"tritongpu-warp-spec-task-partition", "mlir::ModuleOp"> {
let summary = "Warp specialization task partition";

let description = "This pass computes a warp schedule partition by annoating anchor operations with async task ids";

let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"];
let options = [
Option<"numConsumerGroups", "num-consumer-groups",
"int32_t", /*default*/"0",
"number of consumer warp groups for warp specialization">
];
}

def TritonGPUTaskIdPropagate : Pass<"triton-gpu-taskid-propagate", "mlir::ModuleOp"> {
let summary = "Propagate async_task_id annotations based on dependencies";

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TritonGPU/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_triton_library(TritonGPUTransforms
ReorderInstructions.cpp
Utility.cpp
TaskIdPropagate.cpp
WSTaskPartition.cpp
WSDataPartition.cpp
WSCodePartition.cpp
WSLowering.cpp
Expand Down
89 changes: 82 additions & 7 deletions lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
Expand Down Expand Up @@ -169,15 +170,15 @@ struct AsyncTaskIdsCompare {

// Make sure the def chain contains the right taskId.
bool verifyTaskId(triton::FuncOp &funcOp,
const llvm::DenseSet<Operation *>& anchorOps) {
const llvm::DenseSet<Operation *> &anchorOps) {
bool retCode = true;
DenseSet<SmallVector<AsyncTaskId>, AsyncTaskIdsCompare> anchorAsyncTasks;
for (auto anchorOp : anchorOps) {
anchorAsyncTasks.insert(getAsyncTaskIds(anchorOp));
}

funcOp.walk([&](Operation *op) {
// Skip control ops
// Skip control ops
if (llvm::isa<ReturnOp, FuncOp, scf::YieldOp, scf::ForOp>(op))
return;

Expand Down Expand Up @@ -315,12 +316,63 @@ void backwardPropagateTaskIds(Operation *op,
}
}

void backwardPropagateTaskIds(llvm::DenseSet<Operation *> &anchorOps) {
for (Operation *op : anchorOps) {
void backwardPropagateTaskIds(llvm::DenseSet<Operation *> &rootOps,
llvm::DenseSet<Operation *> &anchorOps) {
for (Operation *op : rootOps) {
backwardPropagateTaskIds(op, anchorOps);
}
}

void forwardPropagateTaskIds(Operation *root,
const llvm::DenseSet<Operation *> &anchors) {
auto asyncTasks = getAsyncTaskIds(root);
SmallVector<Value> queue;
for (Value result : root->getResults())
queue.push_back(result);

DenseSet<Value> seen;
for (auto anchor : anchors) {
if (anchor != root)
for (auto result : anchor->getResults())
seen.insert(result);
}

while (!queue.empty()) {
auto v = queue.back();
queue.pop_back();
if (!seen.insert(v).second)
continue;

for (Operation *depOp : v.getUsers()) {
auto depAsyncTasks = getAsyncTaskIds(depOp);
// Skip depOp that already has task ids. Those could be either anchorOps
// or propagated backward from anchor ops.
if (!depAsyncTasks.empty() && depAsyncTasks != asyncTasks)
continue;
setAsyncTaskIds(depOp, asyncTasks);
// Go through yieldOp to propagate task ids to the result of parentOp.
if (auto yieldOp = dyn_cast<scf::YieldOp>(depOp)) {
auto parentOp = yieldOp->getParentOp();
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (operand.get() == v) {
queue.push_back(parentOp->getResult(operand.getOperandNumber()));
break;
}
}
} else {
for (Value result : depOp->getResults())
queue.push_back(result);
}
}
}
}

void forwardPropagateTaskIds(llvm::DenseSet<Operation *> &anchorOps) {
for (Operation *op : anchorOps) {
forwardPropagateTaskIds(op, anchorOps);
}
}

void populateTaskIdsForControlDependencies(
llvm::DenseSet<Operation *> &anchorOps) {
for (auto op : anchorOps) {
Expand Down Expand Up @@ -349,8 +401,11 @@ class TritonGPUTaskIdPropagatePass
llvm::DenseSet<Operation *> anchorOps;
funcOp.walk([&](mlir::Operation *op) {
auto asyncTasks = getAsyncTaskIds(op);
if (!asyncTasks.empty() &&
!isa<arith::ConstantOp, arith::ConstantIntOp>(op))
if (asyncTasks.empty())
return;
std::sort(asyncTasks.begin(), asyncTasks.end());
setAsyncTaskIds(op, asyncTasks);
if (!isa<arith::ConstantOp, arith::ConstantIntOp>(op))
anchorOps.insert(op);
});

Expand All @@ -361,13 +416,33 @@ class TritonGPUTaskIdPropagatePass
funcOp->dump();
});

backwardPropagateTaskIds(anchorOps);
backwardPropagateTaskIds(anchorOps, anchorOps);

LLVM_DEBUG({
LDBG("after backwardPropagateTaskIds ");
funcOp->dump();
});

forwardPropagateTaskIds(anchorOps);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious on why we need to forward then backward, is it needed because of auto annotation?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally backward and forward should be in a loop until no new ops gets a new task id. In practice, the backward->forward->backward pattern works well, similar to what has been done in data partition, because the forward pass is recursive.

Forward is not needed with user annotation as user will always annotate the final store.


LLVM_DEBUG({
LDBG("after forwardPropagateTaskIds ");
funcOp->dump();
});

llvm::DenseSet<Operation *> rootOps;
funcOp.walk([&](mlir::Operation *op) {
auto asyncTasks = getAsyncTaskIds(op);
if (!asyncTasks.empty() &&
!isa<arith::ConstantOp, arith::ConstantIntOp>(op))
rootOps.insert(op);
});
backwardPropagateTaskIds(rootOps, anchorOps);
LLVM_DEBUG({
LDBG("after final backwardPropagateTaskIds ");
funcOp->dump();
});

DenseSet<int> allAsyncTasks;
funcOp->walk([&](Operation *op) {
auto asyncTasks = getAsyncTaskIds(op);
Expand Down
168 changes: 168 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/WSTaskPartition.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"

namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = ::mlir::triton::nvidia_gpu;
namespace mlir {
namespace triton {
namespace gpu {

#define DEBUG_TYPE "tritongpu-warp-task-partition"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

#define GEN_PASS_DEF_TRITONGPUWSTASKPARTITION
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

struct TaskSchedule {
unsigned numTasks = 0;
DenseMap<Operation *, unsigned> opToTaskId;
};

// Compute a partition schedule for later passes to actually partition the
// program into async tasks.
void doPartition(triton::FuncOp &funcOp, unsigned numConsumerGroups) {

// Bail out in the presence of user annotations.
DenseSet<int> allAsyncTasks;
funcOp->walk([&](Operation *op) {
auto asyncTasks = getAsyncTaskIds(op);
allAsyncTasks.insert(asyncTasks.begin(), asyncTasks.end());
});

if (!allAsyncTasks.empty())
return;

SmallVector<scf::ForOp> loops;
SmallVector<Operation *> loads;
SmallVector<Operation *> dots;

funcOp.walk([&](Operation *op) {
if (scf::ForOp forOp = dyn_cast<scf::ForOp>(op))
loops.push_back(forOp);
else if (isa<nvidia_gpu::WarpGroupDotOp>(op))
dots.push_back(op);
else if (isa<triton::LoadOp, ExperimentalDescriptorLoadOp>(op))
loads.push_back(op);
});

if (loops.empty() || loads.empty() || dots.empty())
return;

auto getLoopLevel = [&](Operation *op) {
// Compute loop depth
unsigned depth = 0;
Operation *parent = op->getParentOp();
while (parent) {
if (isa<scf::ForOp>(parent)) {
++depth;
}
parent = parent->getParentOp();
}
return depth;
};

// Step 1. Select loads into the first task, which is the producer task by
// default. Place dots into the second task, which is the consumer.
// Only consider loads that are connected to a dot op in a loop.
SmallVector<Operation *> producerOps;
SmallVector<Operation *> consumerOps;
for (auto op : dots) {
if (getLoopLevel(op) == 0)
continue;
consumerOps.push_back(op);
auto dotOp = dyn_cast<nvidia_gpu::WarpGroupDotOp>(op);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should it be WarpGroupDot always?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean other mma operations? Ideally the heuristic should be able to handle all expansive ops, but for now we are limiting it to handling only wgmma, as that's the only warp-group level ops we have so far. Regular mma may not need a warp-group level of warp specialization.

if (!dotOp)
continue;
SetVector<Operation *> backwardSlice;
getBackwardSlice(dotOp.getA(), &backwardSlice);
getBackwardSlice(dotOp.getB(), &backwardSlice);

for (auto depOp : backwardSlice) {
if (isa<triton::LoadOp, ExperimentalDescriptorLoadOp>(depOp)) {
producerOps.push_back(depOp);
}
}
}

LLVM_DEBUG({
LDBG("Producer ops:\n");
for (auto op : producerOps) {
op->dump();
}

LDBG("\n");
LDBG("Consumer ops:\n");
for (auto op : consumerOps) {
op->dump();
}

LDBG("\n");
});

if (consumerOps.empty() || producerOps.empty())
return;

// Annoate the program with task ids
SmallVector<AsyncTaskId, 1> producerTaskIds{0};
SmallVector<AsyncTaskId, 2> consumerTaskIds;
for (unsigned i = 0; i < numConsumerGroups; ++i) {
consumerTaskIds.push_back(i + producerTaskIds.size());
}

for (auto op : producerOps) {
setAsyncTaskIds(op, producerTaskIds);
}

for (auto op : consumerOps) {
setAsyncTaskIds(op, consumerTaskIds);
}

LLVM_DEBUG({
LDBG("After task partition");
funcOp.dump();
LDBG("\n");
});
}

class TritonGPUWSTaskPartitionPass
: public impl::TritonGPUWSTaskPartitionBase<TritonGPUWSTaskPartitionPass> {
public:
using impl::TritonGPUWSTaskPartitionBase<
TritonGPUWSTaskPartitionPass>::TritonGPUWSTaskPartitionBase;

void runOnFuncOp(triton::FuncOp funcOp) {
if (numConsumerGroups == 0)
return;
doPartition(funcOp, numConsumerGroups);
}

void runOnOperation() override {
getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); });
}
};

} // namespace gpu
} // namespace triton
} // namespace mlir
2 changes: 2 additions & 0 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ void init_triton_passes_ttgpuir(py::module &&m) {
createTritonGPUCombineTensorSelectAndIf);
ADD_PASS_WRAPPER_0("add_optimize_accumulator_init",
createTritonGPUOptimizeAccumulatorInit);
ADD_PASS_OPTION_WRAPPER_1("add_ws_task_partition",
createTritonGPUWSTaskPartition, int);
ADD_PASS_OPTION_WRAPPER_1("add_ws_data_partition",
createTritonGPUWSDataPartition, int);
ADD_PASS_OPTION_WRAPPER_1("add_ws_lowering", createTritonGPUWSLowering, int);
Expand Down
Loading