diff --git a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp index a2fa1baec953..4748f8fbe848 100644 --- a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp +++ b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp @@ -51,7 +51,8 @@ static std::vector passNamesTable = {ADD_PLUGIN_PASS_NAME}; // Key APIs: TRITON_PLUGIN_API -tritonAddPluginPass(mlir::PassManager *pm, const char *passName) { +tritonAddPluginPass(mlir::PassManager *pm, const char *passName, + const std::vector &args) { std::string passNameStr(passName); if (passMap.find(passNameStr) == passMap.end()) return TP_GENERIC_FAILURE; diff --git a/examples/plugins/Passes.td b/examples/plugins/Passes.td index a8007a09e84b..8971e5bdc678 100644 --- a/examples/plugins/Passes.td +++ b/examples/plugins/Passes.td @@ -5,5 +5,12 @@ include "mlir/Pass/PassBase.td" def TritonGPUMLIRPlugin : Pass<"tritongpu-plugin", "mlir::ModuleOp"> { let summary = "Triton MLIR Plugin Pass"; + + let options = [ + Option<"num_warps", "num-warps", + "int32_t", /*default*/"4", + "Number of warps">, + ]; + } #endif diff --git a/examples/plugins/TritonPlugin.cpp b/examples/plugins/TritonPlugin.cpp index a714d18618f5..c79c60d91ee1 100644 --- a/examples/plugins/TritonPlugin.cpp +++ b/examples/plugins/TritonPlugin.cpp @@ -10,17 +10,27 @@ namespace mlir { namespace triton { namespace plugin { +#define GEN_PASS_DECL_TRITONGPUMLIRPLUGIN #define GEN_PASS_DEF_TRITONGPUMLIRPLUGIN #include "Passes.h.inc" struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase { + using TritonGPUMLIRPluginBase::TritonGPUMLIRPluginBase; + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); + + std::string name; + llvm::raw_string_ostream sstr(name); + sstr << "foo"; + if (num_warps != 4) + sstr << "_num_warps_" << num_warps; + mod.walk([&](FunctionOpInterface funcOp) { StringAttr funcNameAttr = funcOp.getNameAttr(); - funcOp.setName("foo"); + funcOp.setName(name); }); } }; @@ -29,8 +39,16 @@ struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase { } // namespace triton } // namespace mlir -static void addTritonPluginPass(mlir::PassManager *pm) { - pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin()); +static void addTritonPluginPass(mlir::PassManager *pm, + const std::vector &args) { + if (args.empty()) { + pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin()); + return; + } + + mlir::triton::plugin::TritonGPUMLIRPluginOptions opts; + opts.num_warps = std::atoi(args[0].c_str()); + pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin((opts))); } static void registerTritonPluginPass() { @@ -40,20 +58,21 @@ static void registerTritonPluginPass() { } static const char *ADD_PLUGIN_PASS_NAME = "add_plugin"; -static std::unordered_map passMap = +static std::unordered_map passMap = {{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}}; -static std::unordered_map registryMap = { - {ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}}; +static std::unordered_map + registryMap = {{ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}}; static std::vector passNamesTable = {ADD_PLUGIN_PASS_NAME}; // Key APIs: TRITON_PLUGIN_API -tritonAddPluginPass(mlir::PassManager *pm, const char *passName) { +tritonAddPluginPass(mlir::PassManager *pm, const char *passName, + const std::vector &args) { std::string passNameStr(passName); if (passMap.find(passNameStr) == passMap.end()) return TP_GENERIC_FAILURE; - passMap[passNameStr](pm); + passMap[passNameStr](pm, args); return TP_SUCCESS; } diff --git a/include/triton/Tools/PluginUtils.h b/include/triton/Tools/PluginUtils.h index 5878af01bb91..77b921ae6b48 100644 --- a/include/triton/Tools/PluginUtils.h +++ b/include/triton/Tools/PluginUtils.h @@ -36,10 +36,10 @@ struct TritonPlugin { using EnumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *, const char **); - using AddPassType = - std::function; - using AddPassCType = TritonPluginResult (*)(mlir::PassManager *, - const char *); + using AddPassType = std::function &)>; + using AddPassCType = TritonPluginResult (*)(mlir::PassManager *, const char *, + const std::vector &); using RegisterPassType = std::function; using RegisterPassCType = TritonPluginResult (*)(const char *); @@ -78,8 +78,9 @@ struct TritonPlugin { llvm::Expected getDialectHandles(std::vector &handles); - llvm::Expected addPass(mlir::PassManager *pm, - const char *passHandle); + llvm::Expected + addPass(mlir::PassManager *pm, const char *passHandle, + const std::vector &args); llvm::Expected registerPass(const char *passHandle); diff --git a/lib/Tools/PluginUtils.cpp b/lib/Tools/PluginUtils.cpp index c439086666f2..f9ef78d9f2e8 100644 --- a/lib/Tools/PluginUtils.cpp +++ b/lib/Tools/PluginUtils.cpp @@ -138,10 +138,11 @@ TritonPlugin::getDialectHandles(std::vector &dialectNames) { } llvm::Expected -TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle) { +TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle, + const std::vector &args) { if (auto Err = loadPlugin()) return Err; - return checkAPIResult(addPassAPI(pm, passHandle), passHandle); + return checkAPIResult(addPassAPI(pm, passHandle, args), passHandle); } llvm::Expected diff --git a/python/src/passes.cc b/python/src/passes.cc index fae2c8f95d16..836812afd096 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -115,13 +115,16 @@ void init_plugin_passes(py::module &&m) { for (unsigned i = 0; i < passNames.size(); ++i) { const char *passName = passNames.data()[i]; - m.def(passName, [passName](mlir ::PassManager &pm) { - std::string filename = - mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); - TritonPlugin TP(filename); - if (auto result = TP.addPass(&pm, passName); !result) - throw TP.err2exp(result.takeError()); - }); + m.def( + passName, + [passName](mlir ::PassManager &pm, std::vector args) { + std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + TritonPlugin TP(filename); + if (auto result = TP.addPass(&pm, passName, args); !result) + throw TP.err2exp(result.takeError()); + }, + py::arg("pm"), py::arg("args") = std::vector()); } } diff --git a/python/test/unit/plugins/custom_stages.py b/python/test/unit/plugins/custom_stages.py index f3fda7c4db50..de52de3bf289 100644 --- a/python/test/unit/plugins/custom_stages.py +++ b/python/test/unit/plugins/custom_stages.py @@ -20,6 +20,8 @@ def get_hash(): # Keep custom pipeline stages in a seperate file from kernels as any change to the file # will trigger a recompile. +num_warps = 4 + def inspect_stages_hook(self=None, stages=None, options=None, language=None, capability=None): # If the hook is called with no arguments we assume were just after the key and hash and don't want to @@ -31,7 +33,10 @@ def make_ttir_wrapper(mod, metadata, opt, capability): mod = self.make_ttir(mod, metadata, opt, capability) pm = ir.pass_manager(mod.context) pm.enable_debug() - passes.plugin.add_plugin(pm) + if num_warps != 4: + passes.plugin.add_plugin(pm, {str(num_warps)}) + else: + passes.plugin.add_plugin(pm) pm.run(mod, 'make_ttir_plugin') return mod diff --git a/python/test/unit/plugins/test_plugin.py b/python/test/unit/plugins/test_plugin.py index 9a895174b1b8..c56c17b3af97 100644 --- a/python/test/unit/plugins/test_plugin.py +++ b/python/test/unit/plugins/test_plugin.py @@ -21,6 +21,12 @@ def kernel2(BLOCK_SIZE: tl.constexpr): return +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel3(BLOCK_SIZE: tl.constexpr): + return + + def test_op(capfd, device: str): if os.environ.get('LLVM_BUILD_SHARED_LIBS', '0') == '0': return @@ -41,3 +47,8 @@ def test_op(capfd, device: str): knobs.runtime.add_stages_inspection_hook = None h = kernel2[grid](BLOCK_SIZE=1024) assert "tt.func public @foo" not in h.asm["ttgir"] + + knobs.runtime.add_stages_inspection_hook = custom_stages.inspect_stages_hook + custom_stages.num_warps = 8 + h = kernel3[grid](BLOCK_SIZE=1024) + assert "tt.func public @foo_num_warps_8" in h.asm["ttgir"]