diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 34ac8fe784..f9a7947f93 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -160,9 +160,12 @@ def make_ttgir(mod, metadata, opt, device_arch): pm.enable_debug() passes.ttir.add_convert_to_ttgpuir(pm, f"xpu:{device_arch}", opt.num_warps, opt.threads_per_warp, opt.num_ctas) + is_lts_driver = Version(metadata["target"].arch['driver_version']) == Version("1.3.27642") + enable_remat_cache = not is_lts_driver + # optimize TTGIR intel.passes.ttgpuir.add_accelerate_matmul(pm) - intel.passes.ttgpuir.add_remove_layout_conversions(pm) + intel.passes.ttgpuir.add_remove_layout_conversions(pm, enable_remat_cache) intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) # FIXME: Use a better way to check if prefetch instructions are supported once available. # Prefetch instruction is not available in older drivers. @@ -170,13 +173,13 @@ def make_ttgir(mod, metadata, opt, device_arch): intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False) passes.ttgpuir.add_coalesce(pm) - intel.passes.ttgpuir.add_remove_layout_conversions(pm) + intel.passes.ttgpuir.add_remove_layout_conversions(pm, enable_remat_cache) passes.ttgpuir.add_optimize_thread_locality(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) passes.common.add_cse(pm) passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) - intel.passes.ttgpuir.add_remove_layout_conversions(pm) + intel.passes.ttgpuir.add_remove_layout_conversions(pm, enable_remat_cache) passes.ttgpuir.add_reduce_data_duplication(pm) passes.ttgpuir.add_reorder_instructions(pm) passes.common.add_cse(pm) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index 613732bfdf..7535905b79 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -156,6 +156,10 @@ def TritonIntelGPURemoveLayoutConversions : Pass<"tritonintelgpu-remove-layout-c let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::TritonDialect"]; + let options = [ + Option<"enableRematCache", "enable-remat-cache", "bool", /*default*/"true", "Whether to enable the rematerialization cache (currently causing problems with the LTS driver).">, + ]; + } def TritonIntelGPURewriteTensorPointer : Pass<"tritonintelgpu-rewrite-tensor-pointer", "mlir::ModuleOp"> { diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 87087074f3..24ca2ef748 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -165,14 +165,16 @@ class LayoutRematerialization { return it->second; } void cleanup(); - void backwardRematerialization(); - void backwardRematerialization(ConvertLayoutOp convertOp); + void backwardRematerialization(bool enableRematCache); + void backwardRematerialization(ConvertLayoutOp convertOp, + bool enableRematCache); void hoistConvertOnTopOfExtOrBroadcast(); void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); void rewriteSlice(SetVector &slice, DenseMap &layout, - ConvertLayoutOp convertOp, IRMapping &mapping); + ConvertLayoutOp convertOp, IRMapping &mapping, + bool enableRematCache); void rewriteSlice(SetVector &slice, DenseMap &layout, - ConvertLayoutOp convertOp); + ConvertLayoutOp convertOp, bool enableRematCache); private: void updateRematMapping(SmallVector> &values); @@ -944,7 +946,8 @@ void LayoutRematerialization::updateRematMapping( void LayoutRematerialization::rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp, - IRMapping &mapping) { + IRMapping &mapping, + bool enableRematCache) { SetVector opsToRewrite; // Keep track of yield operands that need to be duplicated. DenseMap> yieldOperandsMap; @@ -952,7 +955,7 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, auto layoutIt = layout.find(v); assert(layoutIt != layout.end()); // If we already have a remat value for this value, use it. - if (hasRematValue(v, layoutIt->second)) { + if (enableRematCache && hasRematValue(v, layoutIt->second)) { mapping.map(v, getRematValue(v, layoutIt->second)); continue; } @@ -1118,9 +1121,10 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, void LayoutRematerialization::rewriteSlice(SetVector &slice, DenseMap &layout, - ConvertLayoutOp convertOp) { + ConvertLayoutOp convertOp, + bool enableRematCache) { IRMapping mapping; - rewriteSlice(slice, layout, convertOp, mapping); + rewriteSlice(slice, layout, convertOp, mapping, enableRematCache); } LogicalResult getRematerializableSlice( @@ -1142,13 +1146,13 @@ LogicalResult getRematerializableSlice( return success(); } -void LayoutRematerialization::backwardRematerialization() { +void LayoutRematerialization::backwardRematerialization(bool enableRematCache) { // Go through each ConvertLayoutOp. SmallVector convertOps; funcOp.walk( [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); for (ConvertLayoutOp convertOp : convertOps) { - backwardRematerialization(convertOp); + backwardRematerialization(convertOp, enableRematCache); } } @@ -1163,7 +1167,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { } void LayoutRematerialization::backwardRematerialization( - ConvertLayoutOp convertOp) { + ConvertLayoutOp convertOp, bool enableRematCache) { RankedTensorType targetType = convertOp.getType(); // We don't backward propagate the dot layout with blocked layout as parent. // It introduces a lot of duplicated values in multiple-threads. @@ -1202,7 +1206,7 @@ void LayoutRematerialization::backwardRematerialization( DBGS() << " " << v << '\n'; }); // 2. Rewrite the slice. - rewriteSlice(slice, layout, convertOp); + rewriteSlice(slice, layout, convertOp, enableRematCache); } // For convert left we try to hoist them above type extension to reduce the cost @@ -1286,13 +1290,13 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( mapping.map(extOrBroadcatOp->getResult(0), newExtOrBroadcast->getResult(0)); slice.remove(extOrBroadcatOp->getResult(0)); // 3. Rewrite the slice. - rewriteSlice(slice, layout, convertOp, mapping); + rewriteSlice(slice, layout, convertOp, mapping, /*enableRematCache=*/true); } -void backwardRematerialization(ModuleOp module) { - module.walk([](FuncOp funcOp) { +void backwardRematerialization(ModuleOp module, bool enableRematCache) { + module.walk([enableRematCache](FuncOp funcOp) { LayoutRematerialization layoutRemat(funcOp); - layoutRemat.backwardRematerialization(); + layoutRemat.backwardRematerialization(enableRematCache); layoutRemat.cleanup(); }); } @@ -1310,6 +1314,9 @@ class TritonIntelGPURemoveLayoutConversionsPass : public intel::impl::TritonIntelGPURemoveLayoutConversionsBase< TritonIntelGPURemoveLayoutConversionsPass> { public: + using TritonIntelGPURemoveLayoutConversionsBase:: + TritonIntelGPURemoveLayoutConversionsBase; + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp m = getOperation(); @@ -1341,7 +1348,7 @@ class TritonIntelGPURemoveLayoutConversionsPass // 2. For remaining convert ops, try to rematerialize the slice of producer // operation to avoid having to convert. - backwardRematerialization(m); + backwardRematerialization(m, enableRematCache); LLVM_DEBUG({ DBGS() << "Module after backward remat:\n"; m.dump(); diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index 4dfd892463..be826d35c1 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -1,4 +1,4 @@ -#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassManager.h" #include "passes.h" #include "llvm/IRReader/IRReader.h" @@ -65,8 +65,9 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { gpu::intel::createIntelAllocateSharedMemory); ADD_PASS_WRAPPER_OPT_2("add_pipeline", gpu::intel::createTritonIntelGPUPipeline, int, bool); - ADD_PASS_WRAPPER_0("add_remove_layout_conversions", - gpu::intel::createTritonIntelGPURemoveLayoutConversions); + ADD_PASS_WRAPPER_OPT_1( + "add_remove_layout_conversions", + gpu::intel::createTritonIntelGPURemoveLayoutConversions, bool); ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", gpu::intel::createTritonIntelGPURewriteTensorPointer); ADD_PASS_WRAPPER_OPT_2("add_prefetch_block",