diff --git a/python/src/llvm.cc b/python/src/llvm.cc index 4fa590db6355..84c2efbb740a 100644 --- a/python/src/llvm.cc +++ b/python/src/llvm.cc @@ -647,7 +647,7 @@ void init_triton_llvm(py::module &&m) { "optimize_module", [](llvm::Module *mod, const llvm::OptimizationLevel &opt, std::string arch, std::string features, std::vector flags, - bool enable_fp_fusion) { + bool enable_fp_fusion, bool disable_vector_combine) { if (mlir::triton::tools::getBoolEnv("DISABLE_LLVM_OPT")) return; // Check to see if we are passing a list of flags to disable @@ -678,11 +678,24 @@ void init_triton_llvm(py::module &&m) { PassInstrumentationCallbacks passInstrCb; StandardInstrumentations standardInstr(mod->getContext(), /*DebugLogging*/ true); + bool enablePassInstrumentation = false; + if (disable_vector_combine) { + // VectorCombinePass::name() returns the C++ class name, not the + // registry name "vector-combine". + const StringRef kVectorCombinePassName = "VectorCombinePass"; + passInstrCb.registerShouldRunOptionalPassCallback( + [kVectorCombinePassName](StringRef passName, Any) { + return passName != kVectorCombinePassName; + }); + enablePassInstrumentation = true; + } if (mlir::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) { setLLVMOption("print-after-all", true); standardInstr.registerCallbacks(passInstrCb, &mam); - instrCbPtr = &passInstrCb; + enablePassInstrumentation = true; } + if (enablePassInstrumentation) + instrCbPtr = &passInstrCb; PipelineTuningOptions tuningOptions; tuningOptions.LoopUnrolling = true; @@ -760,6 +773,7 @@ void init_triton_llvm(py::module &&m) { py::arg("arch") = "", py::arg("features") = "", py::arg("flags") = std::vector{}, py::arg("enable_fp_fusion") = false, + py::arg("disable_vector_combine") = false, py::call_guard()); m.def( diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index a61f38e021c2..1c12c74f3acf 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -479,7 +479,7 @@ def make_llir(src, metadata, options): if len(paths) > 0: llvm.link_extern_libs(llvm_mod, paths) - llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion) + llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3, options.arch, '', [], options.enable_fp_fusion, True) # Architectures with architected SGPRs store the workgroup id in ttmp9 (X) and ttmp7 (Y[15:0], Z[31:16]). # These attributes are used to determine if Z should be masked out when loading Y. They are inferred during