Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions third_party/intel/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,26 @@ def make_ttgir(mod, metadata, opt, device_arch):
pm.enable_debug()
passes.ttir.add_convert_to_ttgpuir(pm, f"xpu:{device_arch}", opt.num_warps, opt.threads_per_warp, opt.num_ctas)

is_lts_driver = Version(metadata["target"].arch['driver_version']) == Version("1.3.27642")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Be consistent with the string quote character within a file (the line commented on mixes double-quoted "target" with single-quoted 'driver_version').

enable_remat_cache = not is_lts_driver

# optimize TTGIR
intel.passes.ttgpuir.add_accelerate_matmul(pm)
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
intel.passes.ttgpuir.add_remove_layout_conversions(pm, enable_remat_cache)
intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
# FIXME: Use a better way to check if prefetch instructions are supported once available.
# Prefetch instruction is not available in older drivers.
if Version(metadata["target"].arch['driver_version']) > Version("1.3.28202"):
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)

passes.ttgpuir.add_coalesce(pm)
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
intel.passes.ttgpuir.add_remove_layout_conversions(pm, enable_remat_cache)
passes.ttgpuir.add_optimize_thread_locality(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
passes.common.add_cse(pm)
passes.ttgpuir.add_prefetch(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, True)
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
intel.passes.ttgpuir.add_remove_layout_conversions(pm, enable_remat_cache)
passes.ttgpuir.add_reduce_data_duplication(pm)
passes.ttgpuir.add_reorder_instructions(pm)
passes.common.add_cse(pm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def TritonIntelGPURemoveLayoutConversions : Pass<"tritonintelgpu-remove-layout-c
let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
"mlir::triton::TritonDialect"];

let options = [
Option<"enableRematCache", "enable-remat-cache", "bool", /*default*/"true", "Whether to enable the rematerialization cache (currently causing problems with the LTS driver).">,
];

}

def TritonIntelGPURewriteTensorPointer : Pass<"tritonintelgpu-rewrite-tensor-pointer", "mlir::ModuleOp"> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,16 @@ class LayoutRematerialization {
return it->second;
}
void cleanup();
void backwardRematerialization();
void backwardRematerialization(ConvertLayoutOp convertOp);
void backwardRematerialization(bool enableRematCache);
void backwardRematerialization(ConvertLayoutOp convertOp,
bool enableRematCache);
void hoistConvertOnTopOfExtOrBroadcast();
void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp);
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp, IRMapping &mapping);
ConvertLayoutOp convertOp, IRMapping &mapping,
bool enableRematCache);
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp);
ConvertLayoutOp convertOp, bool enableRematCache);

private:
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
Expand Down Expand Up @@ -944,15 +946,16 @@ void LayoutRematerialization::updateRematMapping(
void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp,
IRMapping &mapping) {
IRMapping &mapping,
bool enableRematCache) {
SetVector<Operation *> opsToRewrite;
// Keep track of yield operands that need to be duplicated.
DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
for (Value v : slice) {
auto layoutIt = layout.find(v);
assert(layoutIt != layout.end());
// If we already have a remat value for this value, use it.
if (hasRematValue(v, layoutIt->second)) {
if (enableRematCache && hasRematValue(v, layoutIt->second)) {
mapping.map(v, getRematValue(v, layoutIt->second));
continue;
}
Expand Down Expand Up @@ -1118,9 +1121,10 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,

void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp) {
ConvertLayoutOp convertOp,
bool enableRematCache) {
IRMapping mapping;
rewriteSlice(slice, layout, convertOp, mapping);
rewriteSlice(slice, layout, convertOp, mapping, enableRematCache);
}

LogicalResult getRematerializableSlice(
Expand All @@ -1142,13 +1146,13 @@ LogicalResult getRematerializableSlice(
return success();
}

void LayoutRematerialization::backwardRematerialization() {
void LayoutRematerialization::backwardRematerialization(bool enableRematCache) {
// Go through each ConvertLayoutOp.
SmallVector<ConvertLayoutOp> convertOps;
funcOp.walk(
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
for (ConvertLayoutOp convertOp : convertOps) {
backwardRematerialization(convertOp);
backwardRematerialization(convertOp, enableRematCache);
}
}

Expand All @@ -1163,7 +1167,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
}

void LayoutRematerialization::backwardRematerialization(
ConvertLayoutOp convertOp) {
ConvertLayoutOp convertOp, bool enableRematCache) {
RankedTensorType targetType = convertOp.getType();
// We don't backward propagate the dot layout with blocked layout as parent.
// It introduces a lot of duplicated values across multiple threads.
Expand Down Expand Up @@ -1202,7 +1206,7 @@ void LayoutRematerialization::backwardRematerialization(
DBGS() << " " << v << '\n';
});
// 2. Rewrite the slice.
rewriteSlice(slice, layout, convertOp);
rewriteSlice(slice, layout, convertOp, enableRematCache);
}

// For convert left we try to hoist them above type extension to reduce the cost
Expand Down Expand Up @@ -1286,13 +1290,13 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
mapping.map(extOrBroadcatOp->getResult(0), newExtOrBroadcast->getResult(0));
slice.remove(extOrBroadcatOp->getResult(0));
// 3. Rewrite the slice.
rewriteSlice(slice, layout, convertOp, mapping);
rewriteSlice(slice, layout, convertOp, mapping, /*enableRematCache=*/true);
}

void backwardRematerialization(ModuleOp module) {
module.walk([](FuncOp funcOp) {
void backwardRematerialization(ModuleOp module, bool enableRematCache) {
module.walk([enableRematCache](FuncOp funcOp) {
LayoutRematerialization layoutRemat(funcOp);
layoutRemat.backwardRematerialization();
layoutRemat.backwardRematerialization(enableRematCache);
layoutRemat.cleanup();
});
}
Expand All @@ -1310,6 +1314,9 @@ class TritonIntelGPURemoveLayoutConversionsPass
: public intel::impl::TritonIntelGPURemoveLayoutConversionsBase<
TritonIntelGPURemoveLayoutConversionsPass> {
public:
using TritonIntelGPURemoveLayoutConversionsBase::
TritonIntelGPURemoveLayoutConversionsBase;

void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp m = getOperation();
Expand Down Expand Up @@ -1341,7 +1348,7 @@ class TritonIntelGPURemoveLayoutConversionsPass

// 2. For remaining convert ops, try to rematerialize the slice of producer
// operation to avoid having to convert.
backwardRematerialization(m);
backwardRematerialization(m, enableRematCache);
LLVM_DEBUG({
DBGS() << "Module after backward remat:\n";
m.dump();
Expand Down
7 changes: 4 additions & 3 deletions third_party/intel/triton_xpu.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassManager.h"
#include "passes.h"

#include "llvm/IRReader/IRReader.h"
Expand Down Expand Up @@ -65,8 +65,9 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
gpu::intel::createIntelAllocateSharedMemory);
ADD_PASS_WRAPPER_OPT_2("add_pipeline",
gpu::intel::createTritonIntelGPUPipeline, int, bool);
ADD_PASS_WRAPPER_0("add_remove_layout_conversions",
gpu::intel::createTritonIntelGPURemoveLayoutConversions);
ADD_PASS_WRAPPER_OPT_1(
"add_remove_layout_conversions",
gpu::intel::createTritonIntelGPURemoveLayoutConversions, bool);
ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
gpu::intel::createTritonIntelGPURewriteTensorPointer);
ADD_PASS_WRAPPER_OPT_2("add_prefetch_block",
Expand Down