-
Notifications
You must be signed in to change notification settings - Fork 52
Automating warp-group partition #20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,168 @@ | ||
| #include "mlir/Analysis/SliceAnalysis.h" | ||
| #include "mlir/Dialect/SCF/IR/SCF.h" | ||
| #include "mlir/IR/BuiltinAttributes.h" | ||
| #include "mlir/IR/Dominance.h" | ||
| #include "mlir/IR/IRMapping.h" | ||
| #include "mlir/IR/Matchers.h" | ||
| #include "mlir/IR/PatternMatch.h" | ||
| #include "mlir/IR/Verifier.h" | ||
| #include "mlir/Interfaces/InferTypeOpInterface.h" | ||
| #include "mlir/Pass/Pass.h" | ||
| #include "mlir/Pass/PassManager.h" | ||
| #include "mlir/Support/LLVM.h" | ||
| #include "mlir/Support/LogicalResult.h" | ||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
| #include "mlir/Transforms/Passes.h" | ||
| #include "mlir/Transforms/RegionUtils.h" | ||
| #include "triton/Analysis/Utility.h" | ||
| #include "triton/Dialect/Triton/IR/Utility.h" | ||
| #include "triton/Dialect/TritonGPU/IR/Dialect.h" | ||
| #include "triton/Dialect/TritonGPU/Transforms/Passes.h" | ||
| #include "triton/Dialect/TritonGPU/Transforms/Utility.h" | ||
| #include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" | ||
|
|
||
| namespace tt = mlir::triton; | ||
| namespace ttg = mlir::triton::gpu; | ||
| namespace ttng = ::mlir::triton::nvidia_gpu; | ||
| namespace mlir { | ||
| namespace triton { | ||
| namespace gpu { | ||
|
|
||
| #define DEBUG_TYPE "tritongpu-warp-task-partition" | ||
| #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") | ||
| #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") | ||
|
|
||
| #define GEN_PASS_DEF_TRITONGPUWSTASKPARTITION | ||
| #include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" | ||
|
|
||
| // Plain aggregate pairing a task count with a per-operation task id table. | ||
| // NOTE(review): nothing in the visible part of this file reads or writes it; confirm it is still needed. | ||
| struct TaskSchedule { | ||
| // Task count; starts at zero. | ||
| unsigned numTasks{0}; | ||
| // Presumably the task id chosen for each operation; verify against users. | ||
| DenseMap<Operation *, unsigned> opToTaskId; | ||
| }; | ||
|
|
||
| // Compute a partition schedule for later passes to actually partition the program into async tasks: loads found in the backward slice of | ||
| // an in-loop WarpGroupDotOp's A/B operands get task id 0, those dots get task ids 1..numConsumerGroups. No-op if any op already has task ids. | ||
| void doPartition(triton::FuncOp &funcOp, unsigned numConsumerGroups) { | ||
|
|
||
| // Bail out if any op under funcOp already carries async task ids (e.g. user annotations). | ||
| DenseSet<int> allAsyncTasks; | ||
| funcOp->walk([&](Operation *op) { | ||
| auto asyncTasks = getAsyncTaskIds(op); | ||
| allAsyncTasks.insert(asyncTasks.begin(), asyncTasks.end()); | ||
| }); | ||
|
|
||
| if (!allAsyncTasks.empty()) | ||
| return; | ||
|
|
||
| SmallVector<scf::ForOp> loops; | ||
| SmallVector<Operation *> loads; | ||
| SmallVector<Operation *> dots; | ||
|
|
||
| funcOp.walk([&](Operation *op) { | ||
| if (scf::ForOp forOp = dyn_cast<scf::ForOp>(op)) | ||
| loops.push_back(forOp); | ||
| else if (isa<nvidia_gpu::WarpGroupDotOp>(op)) | ||
| dots.push_back(op); | ||
| else if (isa<triton::LoadOp, ExperimentalDescriptorLoadOp>(op)) | ||
| loads.push_back(op); | ||
| }); | ||
|
|
||
| if (loops.empty() || loads.empty() || dots.empty()) | ||
| return; | ||
|
|
||
| auto getLoopLevel = [&](Operation *op) { | ||
| // Count the scf::ForOp ancestors of op; a result of 0 means op is not nested in any scf::ForOp. | ||
| unsigned depth = 0; | ||
| Operation *parent = op->getParentOp(); | ||
| while (parent) { | ||
| if (isa<scf::ForOp>(parent)) { | ||
| ++depth; | ||
| } | ||
| parent = parent->getParentOp(); | ||
| } | ||
| return depth; | ||
| }; | ||
|
|
||
| // Step 1. Select loads into the first task, which is the producer task by | ||
| // default. Place dots nested in at least one scf::ForOp into the consumer task; dots at loop level 0 are skipped. | ||
| // Only loads (LoadOp / ExperimentalDescriptorLoadOp) in the backward slice of such a dot's A or B operand are considered. | ||
| SmallVector<Operation *> producerOps; | ||
| SmallVector<Operation *> consumerOps; | ||
| for (auto op : dots) { | ||
| if (getLoopLevel(op) == 0) | ||
| continue; | ||
| consumerOps.push_back(op); | ||
| auto dotOp = dyn_cast<nvidia_gpu::WarpGroupDotOp>(op); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should it be WarpGroupDot always?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean other mma operations? Ideally the heuristic should be able to handle all expansive ops, but for now we are limiting it to handling only wgmma, as that's the only warp-group level ops we have so far. Regular mma may not need a warp-group level of warp specialization. |
||
| if (!dotOp) | ||
| continue; | ||
| SetVector<Operation *> backwardSlice; | ||
| getBackwardSlice(dotOp.getA(), &backwardSlice); | ||
| getBackwardSlice(dotOp.getB(), &backwardSlice); | ||
|
|
||
| for (auto depOp : backwardSlice) { | ||
| if (isa<triton::LoadOp, ExperimentalDescriptorLoadOp>(depOp)) { | ||
| producerOps.push_back(depOp); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| LLVM_DEBUG({ | ||
| LDBG("Producer ops:\n"); | ||
| for (auto op : producerOps) { | ||
| op->dump(); | ||
| } | ||
|
|
||
| LDBG("\n"); | ||
| LDBG("Consumer ops:\n"); | ||
| for (auto op : consumerOps) { | ||
| op->dump(); | ||
| } | ||
|
|
||
| LDBG("\n"); | ||
| }); | ||
|
|
||
| if (consumerOps.empty() || producerOps.empty()) | ||
| return; | ||
|
|
||
| // Annotate the program with task ids: the producer is task 0 and the consumers are tasks 1..numConsumerGroups. | ||
| SmallVector<AsyncTaskId, 1> producerTaskIds{0}; | ||
| SmallVector<AsyncTaskId, 2> consumerTaskIds; | ||
| for (unsigned i = 0; i < numConsumerGroups; ++i) { | ||
| consumerTaskIds.push_back(i + producerTaskIds.size()); | ||
| } | ||
|
|
||
| for (auto op : producerOps) { | ||
| setAsyncTaskIds(op, producerTaskIds); | ||
| } | ||
|
|
||
| for (auto op : consumerOps) { | ||
| setAsyncTaskIds(op, consumerTaskIds); | ||
| } | ||
|
|
||
| LLVM_DEBUG({ | ||
| LDBG("After task partition"); | ||
| funcOp.dump(); | ||
| LDBG("\n"); | ||
| }); | ||
| } | ||
|
|
||
| // Pass wrapper: visits every triton::FuncOp under the root operation and runs doPartition on it. | ||
| class TritonGPUWSTaskPartitionPass | ||
| : public impl::TritonGPUWSTaskPartitionBase<TritonGPUWSTaskPartitionPass> { | ||
| public: | ||
| // Reuse the constructors of the generated base class. | ||
| using impl::TritonGPUWSTaskPartitionBase< | ||
| TritonGPUWSTaskPartitionPass>::TritonGPUWSTaskPartitionBase; | ||
|
|
||
| // Partition a single function; does nothing when numConsumerGroups is zero. | ||
| void runOnFuncOp(triton::FuncOp funcOp) { | ||
| if (numConsumerGroups != 0) { | ||
| doPartition(funcOp, numConsumerGroups); | ||
| } | ||
| } | ||
|
|
||
| // Pass entry point: forward each nested triton::FuncOp to runOnFuncOp. | ||
| void runOnOperation() override { | ||
| getOperation()->walk( | ||
| [this](triton::FuncOp nestedFunc) { runOnFuncOp(nestedFunc); }); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace gpu | ||
| } // namespace triton | ||
| } // namespace mlir | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Curious why we need to go forward and then backward; is it needed because of auto annotation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally, the backward and forward passes should run in a loop until no op gets a new task id. In practice, the backward->forward->backward pattern works well, similar to what has been done in data partition, because the forward pass is recursive.
The forward pass is not needed with user annotation, as the user will always annotate the final store.