diff --git a/Makefile b/Makefile index 4425b057d97d..1fba4e28eea5 100644 --- a/Makefile +++ b/Makefile @@ -40,11 +40,11 @@ test-unit: all $(PYTEST) python/tutorials/06-fused-attention.py TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \ $(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py - TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \ + TRITON_PLUGIN_PATHS=python/triton/plugins/libTritonPluginsTestLib.so \ $(PYTEST) -vvv python/test/unit/plugins/test_plugin.py - TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \ + TRITON_PLUGIN_PATHS=python/triton/plugins/libMLIRDialectPlugin.so \ $(PYTEST) -vvv python/test/unit/plugins/test_dialect_plugin.py - TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \ + TRITON_PLUGIN_PATHS=python/triton/plugins/libMLIRDialectPlugin.so \ $(PYTEST) -s -vvv python/test/unit/plugins/custom_ops.py .PHONY: test-gluon diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 0a3b2dd3451a..da8646a376f4 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -149,32 +149,10 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::proton::gpu::registerScheduleBufferStorePass(); mlir::triton::proton::gpu::registerAddSchedBarriersPass(); - // Plugin passes - if (std::string filename = - mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); - !filename.empty()) { - - TritonPlugin TP(filename); - std::vector passNames; - if (auto result = TP.getPassHandles(passNames); !result) - llvm::report_fatal_error(result.takeError()); - - for (const char *passName : passNames) - if (auto result = TP.registerPass(passName); !result) - llvm::report_fatal_error(result.takeError()); - - std::vector dialectNames; - if (auto result = TP.getDialectHandles(dialectNames); !result) - llvm::report_fatal_error(result.takeError()); - - for (unsigned i = 0; i < dialectNames.size(); ++i) { - const char *dialectName = dialectNames.data()[i]; - auto result = TP.getDialectPluginInfo(dialectName); - if (!result) - llvm::report_fatal_error(result.takeError()); - ::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result; - dialectPluginInfo.registerDialectRegistryCallbacks(®istry); - } + // Register plugin passes and dialects. + for (const auto &plugin : mlir::triton::plugin::loadPlugins()) { + plugin.registerPasses(); + plugin.registerDialects(registry); } registry.insert< diff --git a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp index b605b92de1b1..0349a8a9a1a5 100644 --- a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp +++ b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp @@ -3,6 +3,7 @@ #include "DialectPlugin/DialectPluginTypes.h" using namespace mlir; +using namespace mlir::triton; using namespace mlir::triton::plugin; #include "DialectPlugin/DialectPluginOpsDialect.cpp.inc" @@ -30,9 +31,13 @@ void DialectPluginDialect::initialize() { #include "triton/Tools/PluginUtils.h" #include "llvm/Config/llvm-config.h" -using namespace mlir; +static const char *PLUGIN_NAME = "DialectPlugin"; +static const char *DIALECT_NAME = "DialectPlugin"; +static const char *PASS_NAME = "plugingpu_conversion"; +static const char *VERSION = "0.1.0"; -static void addTritonPluginPass(mlir::PassManager *pm) { +static void addTritonPluginPass(mlir::PassManager *pm, + const std::vector &args) { pm->addPass(mlir::triton::plugin::createConvertPluginGPUToLLVMPass()); } @@ -42,88 +47,37 @@ static void registerTritonPluginPass() { }); } -static const char *ADD_PLUGIN_PASS_NAME = "plugingpu_conversion"; -static std::unordered_map passMap = - {{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}}; -static std::unordered_map registryMap = { - {ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}}; -static std::vector passNamesTable = {ADD_PLUGIN_PASS_NAME}; - -// Key APIs: - -TRITON_PLUGIN_API -tritonAddPluginPass(mlir::PassManager *pm, const char *passName, - const std::vector &args) { - std::string passNameStr(passName); - if (passMap.find(passNameStr) == passMap.end()) - return TP_GENERIC_FAILURE; - passMap[passNameStr](pm); - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonRegisterPluginPass(const char *passName) { - std::string passNameStr(passName); - if (registryMap.find(passNameStr) == registryMap.end()) - return TP_GENERIC_FAILURE; - registryMap[passNameStr](); - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) { - if (!passCount) - return TP_GENERIC_FAILURE; - auto count = passMap.size(); - assert(count == registryMap.size() && - "Expected register and add passes map size to match"); - *passCount = count; - if (!passNames) - return TP_SUCCESS; - unsigned i = 0; - for (auto passName : passNamesTable) { - passNames[i] = passName; - } - return TP_SUCCESS; +static void registerTritonPluginDialect(DialectRegistry *registry) { + registry->insert(); + mlir::triton::plugin::registerpluginPasses(); } -TRITON_PLUGIN_API -tritonEnumeratePluginDialects(uint32_t *dialectCount, - const char **dialectNames) { - *dialectCount = 1; - if (!dialectNames) - return TP_SUCCESS; - dialectNames[0] = "DialectPlugin"; - return TP_SUCCESS; -} - -TRITON_PLUGIN_API_TYPE(DialectPluginLibraryInfo) -tritonGetDialectPluginInfo(const char *name) { - return {MLIR_PLUGIN_API_VERSION, "DialectPlugin", LLVM_VERSION_STRING, - [](DialectRegistry *registry) { - registry->insert(); - mlir::triton::plugin::registerpluginPasses(); - }}; -} - -TRITON_PLUGIN_API -tritonEnumeratePluginCustomOps(uint32_t *count, const char **handles) { - if (!count) - return TP_GENERIC_FAILURE; - *count = 1; - if (!handles) - return TP_SUCCESS; - handles[0] = "create_custom_op"; - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonAddPluginCustomOp(const char *handle, TritonOpBuilder &self, - std::vector &operands) { +static void addTritonPluginCustomOp(TritonOpBuilder &self, + std::vector &operands) { ::mlir::Value &dst = operands[0]; ::mlir::Value &src = operands[1]; dst = self.create(src, src); operands[0] = dst; - return TP_SUCCESS; +} + +TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() { + static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass, + registerTritonPluginPass}; + static plugin::PassInfo passes[] = {pass}; + static plugin::DialectInfo dialect = {DIALECT_NAME, VERSION, + registerTritonPluginDialect}; + static plugin::DialectInfo dialects[] = {dialect}; + static plugin::OpInfo op = {"create_custom_op", addTritonPluginCustomOp}; + static plugin::OpInfo ops[] = {op}; + static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION, + PLUGIN_NAME, + VERSION, + passes, + 1, + dialects, + 1, + ops, + 1}; + return &info; } diff --git a/examples/plugins/README.md b/examples/plugins/README.md index a5523935cd1d..ebd511bef5b0 100644 --- a/examples/plugins/README.md +++ b/examples/plugins/README.md @@ -18,7 +18,7 @@ long as the libtriton.so is linked to the plugin and the Triton include passes a ## Example 1: Developing a custom pass and running triton-opt to inspect the modified IR ``` bash export TRITON_EXT_ENABLED=1; make dev-install-llvm -TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so triton-opt -tritongpu-plugin test/Plugins/test-plugin.mlir +TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so triton-opt -tritongpu-plugin test/Plugins/test-plugin.mlir ``` ``` MLIR module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} { @@ -85,7 +85,7 @@ Running same code but loading the plugin library also produces the same results pass manager it is not inserted into the compiler pass pipeline: ``` bash -TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py +TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py ``` ``` MLIR @@ -151,7 +151,7 @@ if __name__ == '__main__': h = kernel[grid](BLOCK_SIZE=1024) print(h.asm["ttgir"]) - if "TRITON_PASS_PLUGIN_PATH" in os.environ: + if "TRITON_PLUGIN_PATHS" in os.environ: knobs.runtime.add_stages_inspection_hook = inspect_stages_hook h = kernel[grid](BLOCK_SIZE=1024) print(h.asm["ttgir"]) @@ -163,7 +163,7 @@ if __name__ == '__main__': ``` ``` bash -TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py +TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py ``` Shows the pass ran and modified the kernel name but only after the hook is set. Any kernels before the hook or after the hook is unset are left unchanged. @@ -330,7 +330,7 @@ if __name__ == '__main__': if "add_loop_unroll" in line: outfile.write("\n passes.plugin.add_plugin(pm)\n") outfile.write(line) - if "TRITON_PASS_PLUGIN_PATH" in os.environ: + if "TRITON_PLUGIN_PATHS" in os.environ: knobs.runtime.add_stages_inspection_hook = override_stages h = kernel2[grid](BLOCK_SIZE=1024) print(h.asm["ttgir"]) diff --git a/examples/plugins/TritonPlugin.cpp b/examples/plugins/TritonPlugin.cpp index c79c60d91ee1..6c18f0c3a073 100644 --- a/examples/plugins/TritonPlugin.cpp +++ b/examples/plugins/TritonPlugin.cpp @@ -57,47 +57,24 @@ static void registerTritonPluginPass() { }); } -static const char *ADD_PLUGIN_PASS_NAME = "add_plugin"; -static std::unordered_map passMap = - {{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}}; -static std::unordered_map - registryMap = {{ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}}; -static std::vector passNamesTable = {ADD_PLUGIN_PASS_NAME}; - -// Key APIs: - -TRITON_PLUGIN_API -tritonAddPluginPass(mlir::PassManager *pm, const char *passName, - const std::vector &args) { - std::string passNameStr(passName); - if (passMap.find(passNameStr) == passMap.end()) - return TP_GENERIC_FAILURE; - passMap[passNameStr](pm, args); - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonRegisterPluginPass(const char *passName) { - std::string passNameStr(passName); - if (registryMap.find(passNameStr) == registryMap.end()) - return TP_GENERIC_FAILURE; - registryMap[passNameStr](); - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) { - if (!passCount) - return TP_GENERIC_FAILURE; - auto count = passMap.size(); - assert(count == registryMap.size() && - "Expected register and add passes map size to match"); - *passCount = count; - if (!passNames) - return TP_SUCCESS; - unsigned i = 0; - for (auto passName : passNamesTable) { - passNames[i++] = passName; - } - return TP_SUCCESS; +static const char *PLUGIN_NAME = "TritonPlugin"; +static const char *PASS_NAME = "add_plugin"; +static const char *VERSION = "0.1.0"; + +using namespace mlir::triton; + +TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() { + static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass, + registerTritonPluginPass}; + static plugin::PassInfo passes[] = {pass}; + static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION, + PLUGIN_NAME, + VERSION, + passes, + 1, + nullptr, + 0, + nullptr, + 0}; + return &info; } diff --git a/include/triton/Tools/PluginUtils.h b/include/triton/Tools/PluginUtils.h index 86536948105b..fd1c149942f7 100644 --- a/include/triton/Tools/PluginUtils.h +++ b/include/triton/Tools/PluginUtils.h @@ -1,124 +1,197 @@ +// Defines the external and internal interface for Triton plugins. +// +// This is largely meant to follow the plugin pattern outlined in upstream MLIR +// ([DialectPlugin], [PassPlugin]); use those as references for further +// additions. +// +// [DialectPlugin]: +// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/DialectPlugin.h +// [PassPlugin]: +// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/PassPlugin.h + #ifndef TRITON_PLUGIN_UTILS_H #define TRITON_PLUGIN_UTILS_H +#include "mlir/IR/DialectRegistry.h" #include "mlir/Pass/PassManager.h" #include "mlir/Tools/Plugins/DialectPlugin.h" #include "python/src/ir.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/DynamicLibrary.h" #include "llvm/Support/Error.h" #include +#include + +/// Identifies the API version understood by this plugin. +/// +/// This version should be incremented for ABI-breaking changes in the structs +/// below; we check this version when loading a new \c TritonPlugin. See +/// similar: [MLIR_PLUGIN_API_VERSION]. +/// +/// [MLIR_PLUGIN_API_VERSION]: +/// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/PassPlugin.h#L32 +#define TRITON_PLUGIN_API_VERSION 1 + +/// Use this helper macro on the public entry point for a Triton plugin. +#define TRITON_PLUGIN_API extern "C" __attribute__((visibility("default"))) + +namespace mlir::triton::plugin { + +// Types for plugin callback functions. +using AddPassCallback = void (*)(mlir::PassManager *, + const std::vector &); +using RegisterPassCallback = void (*)(); +using RegisterDialectCallback = void (*)(mlir::DialectRegistry *); +using AddOpCallback = void (*)(TritonOpBuilder &, std::vector &); + +/// Information provided by a plugin for loading its passes. +struct PassInfo { + const char *name; + const char *version; + AddPassCallback addPass; + RegisterPassCallback registerPass; +}; -extern "C" { -enum TritonPluginResult { - TP_SUCCESS = 0, - TP_GENERIC_FAILURE = 1, +/// Information provided by a plugin for loading its dialects. +struct DialectInfo { + const char *name; + const char *version; + RegisterDialectCallback registerDialect; }; + +/// Information provided by a plugin for loading its custom ops. +struct OpInfo { + const char *name; + AddOpCallback addOp; }; -#define TRITON_PLUGIN_API \ - extern "C" __attribute__((visibility("default"))) TritonPluginResult -#define TRITON_PLUGIN_API_TYPE(_TYPE) \ - extern "C" __attribute__((visibility("default"))) _TYPE -struct TritonPlugin { - TritonPlugin() = delete; - TritonPlugin(std::string filename) : filename(filename) {} +/// Container for all plugin information; this is returned by the plugin +/// library's public entry point, @ref tritonGetPluginInfo. +struct PluginInfo { + /// The API version used by this plugin, see \c TRITON_PLUGIN_API_VERSION. + uint32_t apiVersion; -public: - llvm::Error checkLibraryValid(const std::string &error) const; - static constexpr char ENUMERATE_PASSES[] = "tritonEnumeratePluginPasses"; - static constexpr char ENUMERATE_DIALECTS[] = "tritonEnumeratePluginDialects"; - static constexpr char DIALECT_PLUGININFO[] = "tritonGetDialectPluginInfo"; - static constexpr char ADD_PASS[] = "tritonAddPluginPass"; - static constexpr char REGISTER_PASS[] = "tritonRegisterPluginPass"; - static constexpr char ENUMERATE_CUSTOMOPS[] = - "tritonEnumeratePluginCustomOps"; - static constexpr char ADD_CUSTOMOP[] = "tritonAddPluginCustomOp"; + /// A meaningful name of the plugin. + const char *pluginName; + /// The version of the plugin. + const char *pluginVersion; -private: - using EnumeratePyBindHandlesType = - std::function; - using EnumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *, - const char **); - - using AddPassType = std::function &)>; - using AddPassCType = TritonPluginResult (*)(mlir::PassManager *, const char *, - const std::vector &); - - using RegisterPassType = std::function; - using RegisterPassCType = TritonPluginResult (*)(const char *); - - using DialectPluginInfoType = - std::function<::mlir::DialectPluginLibraryInfo(const char *)>; - using DialectPluginInfoCType = - ::mlir::DialectPluginLibraryInfo (*)(const char *); - - using AddCustomOpType = std::function &operands)>; - using AddCustomOpCType = - TritonPluginResult (*)(const char *handle, TritonOpBuilder &self, - std::vector &operands); - - llvm::Expected getAddressOfSymbol(const std::string &symbol) const; - - template - llvm::Expected getAPI(const std::string &symbol) const { - llvm::Expected getDetailsFn = getAddressOfSymbol(symbol); - if (auto Err = getDetailsFn.takeError()) { - return Err; - } - auto func = reinterpret_cast(*getDetailsFn); - return func; - } - - llvm::Expected checkAPIResult(TritonPluginResult result, - const char *handle) const; - llvm::Expected - enumeratePyBindHandles(EnumeratePyBindHandlesType &enumeratePyBindHandles, - std::vector &passNames); + /// The list of passes. + PassInfo *passes; + size_t numPasses; + + /// The list of dialects. + DialectInfo *dialects; + size_t numDialects; + + /// The list of custom ops. + OpInfo *ops; + size_t numOps; +}; +/// A helper structure for storing information about a pass registered by a +/// plugin. +struct Pass { + Pass(const char *name, AddPassCallback addPass) + : name(name), addPass(addPass) {} + + const char *name; + const AddPassCallback addPass; +}; + +/// A helper structure for storing information about a pass registered by a +/// plugin. +struct Op { + Op(const char *name, AddOpCallback addOp) : name(name), addOp(addOp) {} + + const char *name; + const AddOpCallback addOp; +}; + +/// A loaded Triton plugin. +/// +/// An instance of this class wraps a loaded dialect plugin and gives access +/// to its interface defined by the \c PluginInfo it exposes. +class TritonPlugin { public: - std::runtime_error err2exp(llvm::Error Err); + /// Attempts to load a Triton plugin from a given file. + /// + /// \returns Returns an error if either the library cannot be found or + /// loaded, there is no public entry point, or the plugin implements the + /// wrong API version. + static llvm::Expected load(const std::string &filename); - llvm::Error loadPlugin(); + /// Get the filename of the loaded plugin. + llvm::StringRef getFilename() const { return filename; } - llvm::Expected - getPassHandles(std::vector &handles); + /// Get the plugin name. + llvm::StringRef getPluginName() const { return info->pluginName; } - llvm::Expected - getDialectHandles(std::vector &handles); + /// Get the plugin version. + llvm::StringRef getPluginVersion() const { return info->pluginVersion; } - llvm::Expected - getCustomOpHandles(std::vector &handles); + /// Get the plugin API version. + uint32_t getAPIVersion() const { return info->apiVersion; } - llvm::Expected - addPass(mlir::PassManager *pm, const char *passHandle, - const std::vector &args); + /// List the available passes; this allows us invoke the \c AddPassCallback + /// while knowing the pass name. This function will crash with an LLVM usage + /// error if the plugin provides invalid \c PluginInfo. + const std::vector listPasses() const; - llvm::Expected - addCustomOp(const char *handle, TritonOpBuilder &self, - std::vector &operands); + /// Invoke the \c RegisterPassCallback for each pass registered in this + /// plugin. This function will crash with an LLVM usage + /// error if the plugin provides invalid \c PluginInfo. + void registerPasses() const; - llvm::Expected registerPass(const char *passHandle); + /// Invoke the \c RegisterDialectCallback for each dialect registered in + /// this plugin. This function will crash with an LLVM usage + /// error if the plugin provides invalid \c PluginInfo. + void registerDialects(DialectRegistry &dialectRegistry) const; - llvm::Expected<::mlir::DialectPluginLibraryInfo> - getDialectPluginInfo(const char *dialectName); + /// List the custom operations; this allows us invoke the \c + /// AddOpCallback while knowing the operation name. This function will crash + /// with an LLVM usage error if the plugin provides invalid \c PluginInfo. + const std::vector listOps() const; private: - std::string filename = ""; - mutable llvm::sys::DynamicLibrary library; - EnumeratePyBindHandlesType enumeratePassesAPI; - EnumeratePyBindHandlesType enumerateDialectsAPI; - EnumeratePyBindHandlesType enumerateCustomOpAPI; - AddPassType addPassAPI; - RegisterPassType registerPassAPI; - DialectPluginInfoType dialectPluginInfoAPI; - AddCustomOpType addCustomOpAPI; - bool isLoaded = false; + TritonPlugin(const std::string &filename, + const llvm::sys::DynamicLibrary &library) + : filename(filename), library(library), info() {} + + std::string filename; + llvm::sys::DynamicLibrary library; + PluginInfo *info; }; -void loadPluginDialects(const std::string &filename, - mlir::DialectRegistry ®istry); +/// Load all plugins specified in the `TRITON_PLUGIN_PATHS` environment +/// variable. This variable should contain a colon-separated list of paths to +/// plugin shared libraries. +/// +/// \returns Returns the list of successfully loaded plugins. If any plugin +/// fails to load, it crashes with an LLVM usage error. +const std::vector &loadPlugins(); + +} // namespace mlir::triton::plugin + +/// The public entry point for loading a Triton plugin. +/// +/// When a plugin is loaded by the driver, Triton will call this entry point to +/// obtain information about the plugin and how to load it. This function must +/// to be implemented by the plugin. +/// +/// Triton expects this function to return a pointer to a valid \c PluginInfo +/// struct. Because plugins are loaded in-process permanently, the \c PluginInfo +/// struct has a lifetime spanning the duration of the program; thus, no +/// deallocation function is required from the plugin. As an extra precaution +/// against leaks, return a pointer to a static struct: +/// +/// ``` +/// mlir::triton::plugin::PluginInfo *tritonGetPluginInfo() { +/// static mlir::triton::plugin::PluginInfo info = { ... }; +/// return &info; +/// } +/// ``` +extern "C" mlir::triton::plugin::PluginInfo *LLVM_ATTRIBUTE_WEAK +tritonGetPluginInfo(); #endif // TRITON_PLUGIN_UTILS_H diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index a152436806c3..4cc051c6219f 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -46,7 +46,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_F32_DEFAULT", "TRITON_PREFER_TMEM_16x256_LAYOUT", "TRITON_ENABLE_EXPERIMENTAL_CONSAN", - "TRITON_PASS_PLUGIN_PATH", + "TRITON_PLUGIN_PATHS", "TRITON_PARTITION_SCHEDULING_ENABLE_DUMP_DOT", "TRITON_PARTITION_SCHEDULING_DUMP_DATA_ONLY", "TRITON_PARTITION_SCHEDULING_DUMP_LOOP_ONLY", diff --git a/lib/Tools/PluginUtils.cpp b/lib/Tools/PluginUtils.cpp index c33c6c3123a1..6311907cde91 100644 --- a/lib/Tools/PluginUtils.cpp +++ b/lib/Tools/PluginUtils.cpp @@ -1,236 +1,163 @@ #include "triton/Tools/PluginUtils.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" -llvm::Error TritonPlugin::checkLibraryValid(const std::string &error) const { - if (!library.isValid()) { - return llvm::createStringError( - llvm::Twine("Failed to load plugin library: ") + error); - } - return llvm::Error::success(); -} - -llvm::Expected -TritonPlugin::getAddressOfSymbol(const std::string &symbol) const { - if (auto isValid = checkLibraryValid("not loaded")) - return isValid; - intptr_t getDetailsFn = (intptr_t)library.getAddressOfSymbol(symbol.c_str()); - if (!getDetailsFn) { - return llvm::createStringError(llvm::Twine("Failed to get symbol: ") + - symbol); - } - return getDetailsFn; -} - -llvm::Expected -TritonPlugin::checkAPIResult(TritonPluginResult result, - const char *handle) const { - if (result == TP_SUCCESS) - return TP_SUCCESS; - return llvm::createStringError( - llvm::Twine("Failed to add/register a plugin pass (") + handle + - "), error code: " + std::to_string(result)); -} - -std::runtime_error TritonPlugin::err2exp(llvm::Error Err) { - std::string msg; - llvm::raw_string_ostream os(msg); - os << Err; - return std::runtime_error(msg); -} +#define DEBUG_TYPE "triton-plugins" -llvm::Error TritonPlugin::loadPlugin() { - // Bailing when libtriton symbols are not visible is done to prevent - // crashes caused the loading of plugins (from a set TRITON_PASS_PLUGIN_PATH - // env var path) who will never find their dependent symbols (which are hidden - // by libtriton). -#if !defined(TRITON_EXT_ENABLED) || TRITON_EXT_ENABLED == 0 - // Right now we only support one extension, bump this up if that changes - static llvm::SmallVector printedWarning; - if (llvm::find(printedWarning, filename) == printedWarning.end()) { - llvm::errs() << "\n" - << "\n=================== WARNING =====================\n" - << "Triton will not load the following extension\n" - << "because it is not built with TRITON_EXT_ENABLED:\n" - << filename - << "\n=================================================\n" - << "\n"; - printedWarning.push_back(filename); - } - return llvm::Error::success(); -#endif - - if (isLoaded) - return llvm::Error::success(); +using namespace mlir::triton::plugin; +llvm::Expected TritonPlugin::load(const std::string &filename) { std::string error; - library = + auto library = llvm::sys::DynamicLibrary::getPermanentLibrary(filename.c_str(), &error); - if (auto isValid = checkLibraryValid(error)) - return isValid; - - if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_PASSES)) { - auto enumeratePassesAPIOrErr = - getAPI( - ENUMERATE_PASSES); - auto addPassAPIOrErr = getAPI(ADD_PASS); - auto registerPassAPIOrErr = - getAPI(REGISTER_PASS); - - if (auto Err = enumeratePassesAPIOrErr.takeError()) - return Err; - if (auto Err = addPassAPIOrErr.takeError()) - return Err; - if (auto Err = registerPassAPIOrErr.takeError()) - return Err; - - addPassAPI = *addPassAPIOrErr; - registerPassAPI = *registerPassAPIOrErr; - enumeratePassesAPI = *enumeratePassesAPIOrErr; - } - - if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_DIALECTS)) { - auto enumerateDialectsAPIOrErr = - getAPI( - ENUMERATE_DIALECTS); - auto dialectPluginInfoAPIOrErr = - getAPI( - DIALECT_PLUGININFO); - - if (auto Err = enumerateDialectsAPIOrErr.takeError()) - return Err; - if (auto Err = dialectPluginInfoAPIOrErr.takeError()) - return Err; - enumerateDialectsAPI = *enumerateDialectsAPIOrErr; - dialectPluginInfoAPI = *dialectPluginInfoAPIOrErr; - } - - if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_CUSTOMOPS)) { - auto enumerateCustomOpAPIOrErr = - getAPI( - ENUMERATE_CUSTOMOPS); - auto addCustomOpAPIOrErr = - getAPI(ADD_CUSTOMOP); - - if (auto Err = enumerateCustomOpAPIOrErr.takeError()) - return Err; - if (auto Err = addCustomOpAPIOrErr.takeError()) - return Err; - - enumerateCustomOpAPI = *enumerateCustomOpAPIOrErr; - addCustomOpAPI = *addCustomOpAPIOrErr; - } - - isLoaded = true; - return llvm::Error::success(); + if (!library.isValid()) + return llvm::make_error( + Twine("Could not load library '") + filename + "': " + error, + llvm::inconvertibleErrorCode()); + + TritonPlugin plugin{filename, library}; + + // tritonGetPluginInfo should be resolved to the definition from the + // plugin we are currently loading. + intptr_t getInfoFn = + (intptr_t)library.getAddressOfSymbol("tritonGetPluginInfo"); + if (!getInfoFn) + return llvm::make_error( + Twine("Plugin entry point not found in '") + filename + "'.", + llvm::inconvertibleErrorCode()); + + plugin.info = reinterpret_cast(getInfoFn)(); + + if (plugin.info->apiVersion != TRITON_PLUGIN_API_VERSION) + return llvm::make_error( + Twine("Wrong API version on plugin '") + filename + "'. Got version " + + Twine(plugin.info->apiVersion) + ", supported version is " + + Twine(TRITON_PLUGIN_API_VERSION) + ".", + llvm::inconvertibleErrorCode()); + + return plugin; } -llvm::Expected TritonPlugin::enumeratePyBindHandles( - EnumeratePyBindHandlesType &enumeratePyBindHandles, - std::vector &handles) { - if (auto Err = loadPlugin()) - return Err; - - uint32_t passCount = 0; - handles.clear(); - auto result = enumeratePyBindHandles(&passCount, nullptr); - if (result == TP_SUCCESS) { - if (passCount == 0) - return TP_SUCCESS; - - handles.resize(passCount); - result = enumeratePyBindHandles(&passCount, handles.data()); +const std::vector TritonPlugin::listPasses() const { + if (!info->passes && info->numPasses > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid pass pointer in plugin '") + filename + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Listing " << info->numPasses + << " passes for plugin " << info->pluginName << ":" + << info->pluginVersion << "\n"); + + std::vector passes; + for (auto i = 0; i < info->numPasses; ++i) { + const auto pass = &info->passes[i]; + if (pass->addPass) { + LLVM_DEBUG(llvm::dbgs() << "Listing pass " << pass->name << ":" + << pass->version << "\n"); + passes.push_back(Pass(pass->name, pass->addPass)); + } } - - if (result == TP_SUCCESS) - return TP_SUCCESS; - return llvm::createStringError( - llvm::Twine("Failed to retrieve plugin pass handles, error code: ") + - std::to_string(result)); + return passes; } -llvm::Expected -TritonPlugin::getPassHandles(std::vector &passNames) { - if (auto Err = loadPlugin()) - return Err; - // Do a check to see if the enumerate-passes api symbol is present, bail as - // if there are 0 passes if not - intptr_t isPassPluginSymbolPresent = - (intptr_t)library.getAddressOfSymbol(ENUMERATE_PASSES); - if (!isPassPluginSymbolPresent) - return TP_SUCCESS; - return enumeratePyBindHandles(enumeratePassesAPI, passNames); +void TritonPlugin::registerPasses() const { + if (!info->passes && info->numPasses > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid pass pointer in plugin '") + filename + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Registering " << info->numPasses + << " passes for plugin " << info->pluginName << ":" + << info->pluginVersion << "\n"); + + for (auto i = 0; i < info->numPasses; ++i) { + const auto &pass = info->passes[i]; + if (pass.registerPass) { + LLVM_DEBUG(llvm::dbgs() << "Registering pass " << pass.name << ":" + << pass.version << "\n"); + pass.registerPass(); + } + } } -llvm::Expected -TritonPlugin::getDialectHandles(std::vector &dialectNames) { - if (auto Err = loadPlugin()) - return Err; - // Do a check to see if the enumerate-dialects api symbol is present, bail as - // if there are 0 dialects if not - intptr_t isDialectPluginSymbolPresent = - (intptr_t)library.getAddressOfSymbol(ENUMERATE_DIALECTS); - if (!isDialectPluginSymbolPresent) - return TP_SUCCESS; - return enumeratePyBindHandles(enumerateDialectsAPI, dialectNames); +void TritonPlugin::registerDialects(DialectRegistry &dialectRegistry) const { + if (!info->dialects && info->numDialects > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid dialect pointer in plugin '") + filename + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Registering " << info->numDialects + << " dialects for plugin " << info->pluginName << ":" + << info->pluginVersion << "\n"); + + for (auto i = 0; i < info->numDialects; ++i) { + const auto &dialect = info->dialects[i]; + if (dialect.registerDialect) { + LLVM_DEBUG(llvm::dbgs() << "Registering dialect " << dialect.name << ":" + << dialect.version << "\n"); + dialect.registerDialect(&dialectRegistry); + } + } } -llvm::Expected -TritonPlugin::getCustomOpHandles(std::vector &customOpNames) { - if (auto Err = loadPlugin()) - return Err; - // Do a check to see if the enumerate-custom-ops api symbol is present, bail - // as if there are 0 custom ops if not - intptr_t isCustomOpSymbolPresent = - (intptr_t)library.getAddressOfSymbol(ENUMERATE_CUSTOMOPS); - if (!isCustomOpSymbolPresent) - return TP_SUCCESS; - return enumeratePyBindHandles(enumerateCustomOpAPI, customOpNames); +const std::vector TritonPlugin::listOps() const { + if (!info->ops && info->numOps > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid custom op pointer in plugin '") + filename + + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Listing " << info->numOps + << " custom ops for plugin " << info->pluginName + << ":" << info->pluginVersion << "\n"); + + std::vector ops; + for (auto i = 0; i < info->numOps; ++i) { + const auto op = &info->ops[i]; + if (op->addOp) { + LLVM_DEBUG(llvm::dbgs() << "Listing custom op " << op->name << "\n"); + ops.push_back(Op(op->name, op->addOp)); + } + } + return ops; } -llvm::Expected -TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle, - const std::vector &args) { - if (auto Err = loadPlugin()) - return Err; - return checkAPIResult(addPassAPI(pm, passHandle, args), passHandle); -} +static std::vector plugins; +static bool pluginsLoaded = false; +const std::vector &mlir::triton::plugin::loadPlugins() { + if (pluginsLoaded) + return plugins; -llvm::Expected -TritonPlugin::registerPass(const char *passHandle) { - if (auto Err = loadPlugin()) - return Err; - return checkAPIResult(registerPassAPI(passHandle), passHandle); -} + // Bailing when libtriton symbols are not visible is done to prevent + // crashes caused by loading plugins that will never find their dependent + // symbols (which are hidden by libtriton). +#if !defined(TRITON_EXT_ENABLED) || TRITON_EXT_ENABLED == 0 + bool skipLoading = true; +#else + bool skipLoading = false; +#endif -llvm::Expected<::mlir::DialectPluginLibraryInfo> -TritonPlugin::getDialectPluginInfo(const char *dialectName) { - if (auto Err = loadPlugin()) - return Err; - return dialectPluginInfoAPI(dialectName); -} + if (const char *env = std::getenv("TRITON_PLUGIN_PATHS")) { + llvm::SmallVector paths; + llvm::StringRef(env).split(paths, ':'); + for (const auto &path : paths) { + if (skipLoading) { + llvm::errs() << "\n" + << "\n=================== WARNING =====================\n" + << "Triton will not load the following extension\n" + << "because it is not built with TRITON_EXT_ENABLED:\n" + << path + << "\n=================================================\n" + << "\n"; + continue; + } + + LLVM_DEBUG(llvm::dbgs() << "Loading plugin from path: " << path << "\n"); + auto pluginOrErr = TritonPlugin::load(path.str()); + if (auto err = pluginOrErr.takeError()) { + llvm::Error wrappedErr = llvm::createStringError( + llvm::Twine("Failed to load plugin from path: ") + path + + ". Error: " + llvm::toString(std::move(err))); + llvm::reportFatalUsageError(std::move(wrappedErr)); + } + plugins.push_back(std::move(*pluginOrErr)); + } + } -llvm::Expected -TritonPlugin::addCustomOp(const char *handle, TritonOpBuilder &self, - std::vector &operands) { - if (auto Err = loadPlugin()) - return Err; - addCustomOpAPI(handle, self, operands); - return TP_SUCCESS; + pluginsLoaded = true; + return plugins; } -void loadPluginDialects(const std::string &filename, - mlir::DialectRegistry ®istry) { - TritonPlugin TP(filename); - - std::vector dialectNames; - if (auto result = TP.getDialectHandles(dialectNames); !result) - llvm::report_fatal_error(result.takeError()); - - for (unsigned i = 0; i < dialectNames.size(); ++i) { - const char *dialectName = dialectNames.data()[i]; - auto result = TP.getDialectPluginInfo(dialectName); - if (!result) - llvm::report_fatal_error(result.takeError()); - ::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result; - dialectPluginInfo.registerDialectRegistryCallbacks(®istry); - } -} +#undef DEBUG_TYPE diff --git a/python/src/ir.cc b/python/src/ir.cc index 243410935e5b..79e237bdd0bb 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -231,37 +231,6 @@ py::list getTensorDescMetadata(ModuleOp &mod) { } // anonymous namespace -static void -registerCustomOps(py::class_ &TritonOpBuilderBinding, - const std::string &filename) { - TritonPlugin TP(filename); - std::vector customOpNames; - if (auto result = TP.getCustomOpHandles(customOpNames); !result) - throw TP.err2exp(result.takeError()); - - for (unsigned i = 0; i < customOpNames.size(); ++i) { - const char *customOpName = customOpNames.data()[i]; - - TritonOpBuilderBinding.def( - customOpName, - [customOpName](TritonOpBuilder &self, - std::vector &args) -> mlir::Value { - std::string filename = - mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); - TritonPlugin TP(filename); - - ::mlir::Value dst; - std::vector<::mlir::Value> values = {dst}; - llvm::copy(args, std::back_inserter(values)); - auto result = TP.addCustomOp(customOpName, self, values); - if (!result) - throw TP.err2exp(result.takeError()); - dst = values[0]; - return dst; - }); - } -} - /*****************************************************************************/ /* Python bindings for ir */ /*****************************************************************************/ @@ -369,10 +338,9 @@ void init_triton_ir(py::module &&m) { m.def("load_dialects", [](MLIRContext &context) { DialectRegistry registry; - if (std::string filename = - mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); - !filename.empty()) { - loadPluginDialects(filename, registry); + // Register plugin dialects. + for (const auto &plugin : mlir::triton::plugin::loadPlugins()) { + plugin.registerDialects(registry); } registry.insert args) { + op.addOp(self, args); + }); + } } py::class_(m, "pass_manager", py::module_local()) diff --git a/python/src/passes.cc b/python/src/passes.cc index bd6567030327..a74ba9327f13 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -104,29 +104,15 @@ void init_triton_passes_ttgpuir(py::module &&m) { } void init_plugin_passes(py::module &&m) { - std::string filename = - mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); - if (filename.empty()) - return; - - TritonPlugin TP(filename); - std::vector passNames; - if (auto result = TP.getPassHandles(passNames); !result) - throw TP.err2exp(result.takeError()); - - for (unsigned i = 0; i < passNames.size(); ++i) { - const char *passName = passNames.data()[i]; - - m.def( - passName, - [passName](mlir ::PassManager &pm, std::vector args) { - std::string filename = - mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); - TritonPlugin TP(filename); - if (auto result = TP.addPass(&pm, passName, args); !result) - throw TP.err2exp(result.takeError()); - }, - py::arg("pm"), py::arg("args") = std::vector()); + for (const auto &plugin : mlir::triton::plugin::loadPlugins()) { + for (const auto &pass : plugin.listPasses()) { + m.def( + pass.name, + [pass](mlir::PassManager &pm, std::vector args) { + pass.addPass(&pm, args); + }, + py::arg("pm"), py::arg("args") = std::vector()); + } } } diff --git a/test/Plugins/test-dialect-plugin.mlir b/test/Plugins/test-dialect-plugin.mlir index 41a4d118d764..35ccbf4ef199 100644 --- a/test/Plugins/test-dialect-plugin.mlir +++ b/test/Plugins/test-dialect-plugin.mlir @@ -1,5 +1,5 @@ // RUN: LD_PRELOAD=%shlibdir/../plugins/libtriton.so \ -// RUN: TRITON_PASS_PLUGIN_PATH=%shlibdir/../plugins/libMLIRDialectPlugin.so \ +// RUN: TRITON_PLUGIN_PATHS=%shlibdir/../plugins/libMLIRDialectPlugin.so \ // RUN: triton-opt \ // RUN: -split-input-file --convert-plugin-gpu-to-llvm --convert-triton-gpu-to-llvm %s | \ // RUN: FileCheck %s diff --git a/test/Plugins/test-plugin.mlir b/test/Plugins/test-plugin.mlir index 3d2b0ee50f9c..ec09469dc293 100644 --- a/test/Plugins/test-plugin.mlir +++ b/test/Plugins/test-plugin.mlir @@ -1,10 +1,10 @@ // RUN: LD_PRELOAD=%shlibdir/../plugins/libtriton.so \ -// RUN: TRITON_PASS_PLUGIN_PATH=%shlibdir/../plugins/libTritonPluginsTestLib.so \ +// RUN: TRITON_PLUGIN_PATHS=%shlibdir/../plugins/libTritonPluginsTestLib.so \ // RUN: triton-opt \ // RUN: -split-input-file -tritongpu-plugin %s | FileCheck %s --check-prefix=CHECK-PLUGIN // RUN: LD_PRELOAD=%shlibdir/../plugins/libtriton.so \ -// RUN: TRITON_PASS_PLUGIN_PATH=%shlibdir/../plugins/libTritonPluginsTestLib.so \ +// RUN: TRITON_PLUGIN_PATHS=%shlibdir/../plugins/libTritonPluginsTestLib.so \ // RUN: triton-opt \ // RUN: -split-input-file %s | FileCheck %s -allow-unused-prefixes --check-prefix=CHECK-NOFLAG