From f63979819b907d66377191e8c1a5728df0cbc15b Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Thu, 6 Jun 2024 10:55:56 -0700 Subject: [PATCH 1/5] [TMP]: Disable layout rematerialization cache --- .../lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 87087074f3..e2df9649d5 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -952,10 +952,12 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, auto layoutIt = layout.find(v); assert(layoutIt != layout.end()); // If we already have a remat value for this value, use it. + #if 0 // FIXME: Fails on LTS driver if (hasRematValue(v, layoutIt->second)) { mapping.map(v, getRematValue(v, layoutIt->second)); continue; } + #endif if (v.getDefiningOp()) { opsToRewrite.insert(v.getDefiningOp()); if (auto ifOp = v.getDefiningOp()) { From 58b674d5c75b95bfd511f3eac39513be9a5d6d9f Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Thu, 6 Jun 2024 12:32:31 -0700 Subject: [PATCH 2/5] fix formatting --- .../lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index e2df9649d5..33c5ee3577 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -952,12 +952,12 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, auto layoutIt = layout.find(v); assert(layoutIt != layout.end()); // If we already have a remat value for this value, use it. 
- #if 0 // FIXME: Fails on LTS driver +#if 0 // FIXME: Fails on LTS driver if (hasRematValue(v, layoutIt->second)) { mapping.map(v, getRematValue(v, layoutIt->second)); continue; } - #endif +#endif if (v.getDefiningOp()) { opsToRewrite.insert(v.getDefiningOp()); if (auto ifOp = v.getDefiningOp()) { From 42610aa4541d6feb3d1c45cf55b92e9130ac8ef1 Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Fri, 7 Jun 2024 09:35:50 -0700 Subject: [PATCH 3/5] [WIP]: Conditionally disable remat cache --- third_party/intel/backend/compiler.py | 9 +++-- .../TritonIntelGPU/Transforms/Passes.td | 4 ++ .../RemoveLayoutConversions.cpp | 40 ++++++++++--------- third_party/intel/triton_xpu.cc | 6 +-- 4 files changed, 34 insertions(+), 25 deletions(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 34ac8fe784..c2cf725c9a 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -160,9 +160,12 @@ def make_ttgir(mod, metadata, opt, device_arch): pm.enable_debug() passes.ttir.add_convert_to_ttgpuir(pm, f"xpu:{device_arch}", opt.num_warps, opt.threads_per_warp, opt.num_ctas) + is_lts_driver = Version(metadata["target"].arch['driver_version']) == Version("1.3.27642") + enable_remat_cache = False # not is_lts_driver + # optimize TTGIR intel.passes.ttgpuir.add_accelerate_matmul(pm) - intel.passes.ttgpuir.add_remove_layout_conversions(pm) + intel.passes.ttgpuir.add_remove_layout_conversions(pm, enable_remat_cache) intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) # FIXME: Use a better way to check if prefetch instructions are supported once available. # Prefetch instruction is not available in older drivers. 
@@ -170,13 +173,13 @@ def make_ttgir(mod, metadata, opt, device_arch): intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False) passes.ttgpuir.add_coalesce(pm) - intel.passes.ttgpuir.add_remove_layout_conversions(pm) + intel.passes.ttgpuir.add_remove_layout_conversions(pm, enable_remat_cache) passes.ttgpuir.add_optimize_thread_locality(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) passes.common.add_cse(pm) passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, True) - intel.passes.ttgpuir.add_remove_layout_conversions(pm) + intel.passes.ttgpuir.add_remove_layout_conversions(pm, enable_remat_cache) passes.ttgpuir.add_reduce_data_duplication(pm) passes.ttgpuir.add_reorder_instructions(pm) passes.common.add_cse(pm) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index 613732bfdf..7535905b79 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td @@ -156,6 +156,10 @@ def TritonIntelGPURemoveLayoutConversions : Pass<"tritonintelgpu-remove-layout-c let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", "mlir::triton::TritonDialect"]; + let options = [ + Option<"enableRematCache", "enable-remat-cache", "bool", /*default*/"true", "Whether to enable the rematerialization cache (currently causing problems with the LTS driver).">, + ]; + } def TritonIntelGPURewriteTensorPointer : Pass<"tritonintelgpu-rewrite-tensor-pointer", "mlir::ModuleOp"> { diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 33c5ee3577..953368918b 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -165,14 +165,14 @@ 
class LayoutRematerialization { return it->second; } void cleanup(); - void backwardRematerialization(); - void backwardRematerialization(ConvertLayoutOp convertOp); + void backwardRematerialization(bool enableRematCache); + void backwardRematerialization(ConvertLayoutOp convertOp, bool enableRematCache); void hoistConvertOnTopOfExtOrBroadcast(); void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); void rewriteSlice(SetVector &slice, DenseMap &layout, - ConvertLayoutOp convertOp, IRMapping &mapping); + ConvertLayoutOp convertOp, IRMapping &mapping, bool enableRematCache); void rewriteSlice(SetVector &slice, DenseMap &layout, - ConvertLayoutOp convertOp); + ConvertLayoutOp convertOp, bool enableRematCache); private: void updateRematMapping(SmallVector> &values); @@ -944,7 +944,7 @@ void LayoutRematerialization::updateRematMapping( void LayoutRematerialization::rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp, - IRMapping &mapping) { + IRMapping &mapping, bool enableRematCache) { SetVector opsToRewrite; // Keep track of yield operands that need to be duplicated. DenseMap> yieldOperandsMap; @@ -952,12 +952,12 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, auto layoutIt = layout.find(v); assert(layoutIt != layout.end()); // If we already have a remat value for this value, use it. 
-#if 0 // FIXME: Fails on LTS driver - if (hasRematValue(v, layoutIt->second)) { +#if 0 + if (false && hasRematValue(v, layoutIt->second)) { mapping.map(v, getRematValue(v, layoutIt->second)); continue; } -#endif +#endif if (v.getDefiningOp()) { opsToRewrite.insert(v.getDefiningOp()); if (auto ifOp = v.getDefiningOp()) { @@ -1120,9 +1120,9 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, void LayoutRematerialization::rewriteSlice(SetVector &slice, DenseMap &layout, - ConvertLayoutOp convertOp) { + ConvertLayoutOp convertOp, bool enableRematCache) { IRMapping mapping; - rewriteSlice(slice, layout, convertOp, mapping); + rewriteSlice(slice, layout, convertOp, mapping, enableRematCache); } LogicalResult getRematerializableSlice( @@ -1144,13 +1144,13 @@ LogicalResult getRematerializableSlice( return success(); } -void LayoutRematerialization::backwardRematerialization() { +void LayoutRematerialization::backwardRematerialization(bool enableRematCache) { // Go through each ConvertLayoutOp. SmallVector convertOps; funcOp.walk( [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); for (ConvertLayoutOp convertOp : convertOps) { - backwardRematerialization(convertOp); + backwardRematerialization(convertOp, enableRematCache); } } @@ -1165,7 +1165,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() { } void LayoutRematerialization::backwardRematerialization( - ConvertLayoutOp convertOp) { + ConvertLayoutOp convertOp, bool enableRematCache) { RankedTensorType targetType = convertOp.getType(); // We don't backward propagate the dot layout with blocked layout as parent. // It introduces a lot of duplicated values in multiple-threads. @@ -1204,7 +1204,7 @@ void LayoutRematerialization::backwardRematerialization( DBGS() << " " << v << '\n'; }); // 2. Rewrite the slice. 
- rewriteSlice(slice, layout, convertOp); + rewriteSlice(slice, layout, convertOp, enableRematCache); } // For convert left we try to hoist them above type extension to reduce the cost @@ -1288,13 +1288,13 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( mapping.map(extOrBroadcatOp->getResult(0), newExtOrBroadcast->getResult(0)); slice.remove(extOrBroadcatOp->getResult(0)); // 3. Rewrite the slice. - rewriteSlice(slice, layout, convertOp, mapping); + rewriteSlice(slice, layout, convertOp, mapping, /*enableRematCache=*/false); } -void backwardRematerialization(ModuleOp module) { - module.walk([](FuncOp funcOp) { +void backwardRematerialization(ModuleOp module, bool enableRematCache) { + module.walk([enableRematCache](FuncOp funcOp) { LayoutRematerialization layoutRemat(funcOp); - layoutRemat.backwardRematerialization(); + layoutRemat.backwardRematerialization(enableRematCache); layoutRemat.cleanup(); }); } @@ -1312,6 +1312,8 @@ class TritonIntelGPURemoveLayoutConversionsPass : public intel::impl::TritonIntelGPURemoveLayoutConversionsBase< TritonIntelGPURemoveLayoutConversionsPass> { public: + using TritonIntelGPURemoveLayoutConversionsBase::TritonIntelGPURemoveLayoutConversionsBase; + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp m = getOperation(); @@ -1343,7 +1345,7 @@ class TritonIntelGPURemoveLayoutConversionsPass // 2. For remaining convert ops, try to rematerialize the slice of producer // operation to avoid having to convert. 
- backwardRematerialization(m); + backwardRematerialization(m, enableRematCache); LLVM_DEBUG({ DBGS() << "Module after backward remat:\n"; m.dump(); diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index 4dfd892463..54d804515d 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -1,4 +1,4 @@ -#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassManager.h" #include "passes.h" #include "llvm/IRReader/IRReader.h" @@ -65,8 +65,8 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { gpu::intel::createIntelAllocateSharedMemory); ADD_PASS_WRAPPER_OPT_2("add_pipeline", gpu::intel::createTritonIntelGPUPipeline, int, bool); - ADD_PASS_WRAPPER_0("add_remove_layout_conversions", - gpu::intel::createTritonIntelGPURemoveLayoutConversions); + ADD_PASS_WRAPPER_OPT_1("add_remove_layout_conversions", + gpu::intel::createTritonIntelGPURemoveLayoutConversions, bool); ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", gpu::intel::createTritonIntelGPURewriteTensorPointer); ADD_PASS_WRAPPER_OPT_2("add_prefetch_block", From 452f4ffe4ea28db93378766be701965971de742c Mon Sep 17 00:00:00 2001 From: Alex Baden Date: Fri, 7 Jun 2024 13:43:37 -0700 Subject: [PATCH 4/5] Conditionally disable remat cache under LTS driver --- third_party/intel/backend/compiler.py | 2 +- .../RemoveLayoutConversions.cpp | 19 +++++++++++-------- third_party/intel/triton_xpu.cc | 5 +++-- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index c2cf725c9a..f9a7947f93 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -161,7 +161,7 @@ def make_ttgir(mod, metadata, opt, device_arch): passes.ttir.add_convert_to_ttgpuir(pm, f"xpu:{device_arch}", opt.num_warps, opt.threads_per_warp, opt.num_ctas) is_lts_driver = Version(metadata["target"].arch['driver_version']) == Version("1.3.27642") - enable_remat_cache = False # 
not is_lts_driver + enable_remat_cache = not is_lts_driver # optimize TTGIR intel.passes.ttgpuir.add_accelerate_matmul(pm) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 953368918b..04034d08e6 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -166,11 +166,13 @@ class LayoutRematerialization { } void cleanup(); void backwardRematerialization(bool enableRematCache); - void backwardRematerialization(ConvertLayoutOp convertOp, bool enableRematCache); + void backwardRematerialization(ConvertLayoutOp convertOp, + bool enableRematCache); void hoistConvertOnTopOfExtOrBroadcast(); void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp); void rewriteSlice(SetVector &slice, DenseMap &layout, - ConvertLayoutOp convertOp, IRMapping &mapping, bool enableRematCache); + ConvertLayoutOp convertOp, IRMapping &mapping, + bool enableRematCache); void rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp, bool enableRematCache); @@ -944,7 +946,8 @@ void LayoutRematerialization::updateRematMapping( void LayoutRematerialization::rewriteSlice(SetVector &slice, DenseMap &layout, ConvertLayoutOp convertOp, - IRMapping &mapping, bool enableRematCache) { + IRMapping &mapping, + bool enableRematCache) { SetVector opsToRewrite; // Keep track of yield operands that need to be duplicated. DenseMap> yieldOperandsMap; @@ -952,12 +955,10 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, auto layoutIt = layout.find(v); assert(layoutIt != layout.end()); // If we already have a remat value for this value, use it. 
-#if 0 - if (false && hasRematValue(v, layoutIt->second)) { + if (enableRematCache && hasRematValue(v, layoutIt->second)) { mapping.map(v, getRematValue(v, layoutIt->second)); continue; } -#endif if (v.getDefiningOp()) { opsToRewrite.insert(v.getDefiningOp()); if (auto ifOp = v.getDefiningOp()) { @@ -1120,7 +1121,8 @@ void LayoutRematerialization::rewriteSlice(SetVector &slice, void LayoutRematerialization::rewriteSlice(SetVector &slice, DenseMap &layout, - ConvertLayoutOp convertOp, bool enableRematCache) { + ConvertLayoutOp convertOp, + bool enableRematCache) { IRMapping mapping; rewriteSlice(slice, layout, convertOp, mapping, enableRematCache); } @@ -1312,7 +1314,8 @@ class TritonIntelGPURemoveLayoutConversionsPass : public intel::impl::TritonIntelGPURemoveLayoutConversionsBase< TritonIntelGPURemoveLayoutConversionsPass> { public: - using TritonIntelGPURemoveLayoutConversionsBase::TritonIntelGPURemoveLayoutConversionsBase; + using TritonIntelGPURemoveLayoutConversionsBase:: + TritonIntelGPURemoveLayoutConversionsBase; void runOnOperation() override { MLIRContext *context = &getContext(); diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc index 54d804515d..be826d35c1 100644 --- a/third_party/intel/triton_xpu.cc +++ b/third_party/intel/triton_xpu.cc @@ -65,8 +65,9 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) { gpu::intel::createIntelAllocateSharedMemory); ADD_PASS_WRAPPER_OPT_2("add_pipeline", gpu::intel::createTritonIntelGPUPipeline, int, bool); - ADD_PASS_WRAPPER_OPT_1("add_remove_layout_conversions", - gpu::intel::createTritonIntelGPURemoveLayoutConversions, bool); + ADD_PASS_WRAPPER_OPT_1( + "add_remove_layout_conversions", + gpu::intel::createTritonIntelGPURemoveLayoutConversions, bool); ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer", gpu::intel::createTritonIntelGPURewriteTensorPointer); ADD_PASS_WRAPPER_OPT_2("add_prefetch_block", From a7aa8bf83cbb15aae794d676752d4f19483bc8f1 Mon Sep 17 00:00:00 2001 From: Alex Baden 
Date: Fri, 7 Jun 2024 13:50:29 -0700 Subject: [PATCH 5/5] Keep remat cache enabled for hoist convert layout materialization phase. This phase does not affect the LTS driver. --- .../lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 04034d08e6..24ca2ef748 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -1290,7 +1290,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( mapping.map(extOrBroadcatOp->getResult(0), newExtOrBroadcast->getResult(0)); slice.remove(extOrBroadcatOp->getResult(0)); // 3. Rewrite the slice. - rewriteSlice(slice, layout, convertOp, mapping, /*enableRematCache=*/false); + rewriteSlice(slice, layout, convertOp, mapping, /*enableRematCache=*/true); } void backwardRematerialization(ModuleOp module, bool enableRematCache) {