diff --git a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp index a2fa1baec953..bd4687fb91c7 100644 --- a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp +++ b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp @@ -3,6 +3,7 @@ #include "DialectPlugin/DialectPluginTypes.h" using namespace mlir; +using namespace mlir::triton; using namespace mlir::triton::plugin; #include "DialectPlugin/DialectPluginOpsDialect.cpp.inc" @@ -25,13 +26,18 @@ void DialectPluginDialect::initialize() { #include "DialectPlugin/DialectPluginDialect.h" #include "DialectPlugin/DialectPluginPasses.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Tools/Plugins/PassPlugin.h" #include "triton/Tools/PluginUtils.h" #include "llvm/Config/llvm-config.h" -using namespace mlir; +static const char *PLUGIN_NAME = "DialectPlugin"; +static const char *DIALECT_NAME = "DialectPlugin"; +static const char *PASS_NAME = "plugingpu_conversion"; +static const char *VERSION = "0.1.0"; -static void addTritonPluginPass(mlir::PassManager *pm) { +static void addTritonPluginPass(mlir::PassManager *pm, + const std::vector &args) { pm->addPass(mlir::triton::plugin::createConvertPluginGPUToLLVMPass()); } @@ -41,65 +47,38 @@ static void registerTritonPluginPass() { }); } -static const char *ADD_PLUGIN_PASS_NAME = "plugingpu_conversion"; -static std::unordered_map passMap = - {{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}}; -static std::unordered_map registryMap = { - {ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}}; -static std::vector passNamesTable = {ADD_PLUGIN_PASS_NAME}; - -// Key APIs: - -TRITON_PLUGIN_API -tritonAddPluginPass(mlir::PassManager *pm, const char *passName) { - std::string passNameStr(passName); - if (passMap.find(passNameStr) == passMap.end()) - return TP_GENERIC_FAILURE; - passMap[passNameStr](pm); - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonRegisterPluginPass(const char *passName) { - std::string passNameStr(passName); - if (registryMap.find(passNameStr) == registryMap.end()) - return TP_GENERIC_FAILURE; - registryMap[passNameStr](); - return TP_SUCCESS; +static void registerTritonPluginDialect(DialectRegistry *registry) { + registry->insert(); + mlir::triton::plugin::registerpluginPasses(); } -TRITON_PLUGIN_API -tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) { - if (!passCount) - return TP_GENERIC_FAILURE; - auto count = passMap.size(); - assert(count == registryMap.size() && - "Expected register and add passes map size to match"); - *passCount = count; - if (!passNames) - return TP_SUCCESS; - unsigned i = 0; - for (auto passName : passNamesTable) { - passNames[i] = passName; - } - return TP_SUCCESS; -} +static void addTritonPluginCustomOp(TritonOpBuilder &self, + std::vector &operands) { + ::mlir::Value &dst = operands[0]; + ::mlir::Value &src = operands[1]; -TRITON_PLUGIN_API -tritonEnumeratePluginDialects(uint32_t *dialectCount, - const char **dialectNames) { - *dialectCount = 1; - if (!dialectNames) - return TP_SUCCESS; - dialectNames[0] = "DialectPlugin"; - return TP_SUCCESS; + dst = self.create(src, src); + operands[0] = dst; } -TRITON_PLUGIN_API_TYPE(DialectPluginLibraryInfo) -tritonGetDialectPluginInfo(const char *name) { - return {MLIR_PLUGIN_API_VERSION, "DialectPlugin", LLVM_VERSION_STRING, - [](DialectRegistry *registry) { - registry->insert(); - mlir::triton::plugin::registerpluginPasses(); - }}; +TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() { + static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass, + registerTritonPluginPass}; + static plugin::PassInfo passes[] = {pass}; + static plugin::DialectInfo dialect = {DIALECT_NAME, VERSION, + registerTritonPluginDialect}; + static plugin::DialectInfo dialects[] = {dialect}; + static plugin::OpInfo op = {"create_custom_op", addTritonPluginCustomOp}; + static plugin::OpInfo ops[] = {op}; + static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION, + PLUGIN_NAME, + VERSION, + passes, + 1, + dialects, + 1, + ops, + 1, + TRITON_VERSION}; + return &info; } diff --git a/examples/plugins/TritonPlugin.cpp b/examples/plugins/TritonPlugin.cpp index a714d18618f5..33a252f94b81 100644 --- a/examples/plugins/TritonPlugin.cpp +++ b/examples/plugins/TritonPlugin.cpp @@ -10,17 +10,27 @@ namespace mlir { namespace triton { namespace plugin { +#define GEN_PASS_DECL_TRITONGPUMLIRPLUGIN #define GEN_PASS_DEF_TRITONGPUMLIRPLUGIN #include "Passes.h.inc" struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase { + using TritonGPUMLIRPluginBase::TritonGPUMLIRPluginBase; + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); + + std::string name; + llvm::raw_string_ostream sstr(name); + sstr << "foo"; + if (num_warps != 4) + sstr << "_num_warps_" << num_warps; + mod.walk([&](FunctionOpInterface funcOp) { StringAttr funcNameAttr = funcOp.getNameAttr(); - funcOp.setName("foo"); + funcOp.setName(name); }); } }; @@ -29,8 +39,16 @@ struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase { } // namespace triton } // namespace mlir -static void addTritonPluginPass(mlir::PassManager *pm) { - pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin()); +static void addTritonPluginPass(mlir::PassManager *pm, + const std::vector &args) { + if (args.empty()) { + pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin()); + return; + } + + mlir::triton::plugin::TritonGPUMLIRPluginOptions opts; + opts.num_warps = std::atoi(args[0].c_str()); + pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin((opts))); } static void registerTritonPluginPass() { @@ -39,46 +57,25 @@ static void registerTritonPluginPass() { }); } -static const char *ADD_PLUGIN_PASS_NAME = "add_plugin"; -static std::unordered_map passMap = - {{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}}; -static std::unordered_map registryMap = { - {ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}}; -static std::vector passNamesTable = {ADD_PLUGIN_PASS_NAME}; +static const char *PLUGIN_NAME = "TritonPlugin"; +static const char *PASS_NAME = "add_plugin"; +static const char *VERSION = "0.1.0"; -// Key APIs: +using namespace mlir::triton; -TRITON_PLUGIN_API -tritonAddPluginPass(mlir::PassManager *pm, const char *passName) { - std::string passNameStr(passName); - if (passMap.find(passNameStr) == passMap.end()) - return TP_GENERIC_FAILURE; - passMap[passNameStr](pm); - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonRegisterPluginPass(const char *passName) { - std::string passNameStr(passName); - if (registryMap.find(passNameStr) == registryMap.end()) - return TP_GENERIC_FAILURE; - registryMap[passNameStr](); - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) { - if (!passCount) - return TP_GENERIC_FAILURE; - auto count = passMap.size(); - assert(count == registryMap.size() && - "Expected register and add passes map size to match"); - *passCount = count; - if (!passNames) - return TP_SUCCESS; - unsigned i = 0; - for (auto passName : passNamesTable) { - passNames[i++] = passName; - } - return TP_SUCCESS; +TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() { + static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass, + registerTritonPluginPass}; + static plugin::PassInfo passes[] = {pass}; + static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION, + PLUGIN_NAME, + VERSION, + passes, + 1, + nullptr, + 0, + nullptr, + 0, + TRITON_VERSION}; + return &info; } diff --git a/include/triton/Tools/PluginUtils.h b/include/triton/Tools/PluginUtils.h index 5878af01bb91..6cde31819cdc 100644 --- a/include/triton/Tools/PluginUtils.h +++ b/include/triton/Tools/PluginUtils.h @@ -1,100 +1,201 @@ +// Defines the external and internal interface for Triton plugins. +// +// This is largely meant to follow the plugin pattern outlined in upstream MLIR +// ([DialectPlugin], [PassPlugin]); use those as references for further +// additions. +// +// [DialectPlugin]: +// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/DialectPlugin.h +// [PassPlugin]: +// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/PassPlugin.h + #ifndef TRITON_PLUGIN_UTILS_H #define TRITON_PLUGIN_UTILS_H +#include "mlir/IR/DialectRegistry.h" #include "mlir/Pass/PassManager.h" #include "mlir/Tools/Plugins/DialectPlugin.h" +#include "python/src/ir.h" +#include "triton/Version.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/DynamicLibrary.h" #include "llvm/Support/Error.h" #include +#include + +/// Identifies the API version understood by this plugin. +/// +/// This version should be incremented for ABI-breaking changes in the structs +/// below; we check this version when loading a new \c TritonPlugin. See +/// similar: [MLIR_PLUGIN_API_VERSION]. +/// +/// [MLIR_PLUGIN_API_VERSION]: +/// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/PassPlugin.h#L32 +#define TRITON_PLUGIN_API_VERSION 2 + +/// Use this helper macro on the public entry point for a Triton plugin. +#define TRITON_PLUGIN_API extern "C" __attribute__((visibility("default"))) + +namespace mlir::triton::plugin { + +// Types for plugin callback functions. +using AddPassCallback = void (*)(mlir::PassManager *, + const std::vector &); +using RegisterPassCallback = void (*)(); +using RegisterDialectCallback = void (*)(mlir::DialectRegistry *); +using AddOpCallback = void (*)(TritonOpBuilder &, std::vector &); + +/// Information provided by a plugin for loading its passes. +struct PassInfo { + const char *name; + const char *version; + AddPassCallback addPass; + RegisterPassCallback registerPass; +}; -extern "C" { -enum TritonPluginResult { - TP_SUCCESS = 0, - TP_GENERIC_FAILURE = 1, +/// Information provided by a plugin for loading its dialects. +struct DialectInfo { + const char *name; + const char *version; + RegisterDialectCallback registerDialect; }; + +/// Information provided by a plugin for loading its custom ops. +struct OpInfo { + const char *name; + AddOpCallback addOp; }; -#define TRITON_PLUGIN_API \ - extern "C" __attribute__((visibility("default"))) TritonPluginResult -#define TRITON_PLUGIN_API_TYPE(_TYPE) \ - extern "C" __attribute__((visibility("default"))) _TYPE -struct TritonPlugin { - TritonPlugin() = delete; - TritonPlugin(std::string filename) : filename(filename) {} +/// Container for all plugin information; this is returned by the plugin +/// library's public entry point, @ref tritonGetPluginInfo. +struct PluginInfo { + /// The API version used by this plugin, see \c TRITON_PLUGIN_API_VERSION. + uint32_t apiVersion; -public: - llvm::Error checkLibraryValid(const std::string &error) const; - static constexpr char ENUMERATE_PASSES[] = "tritonEnumeratePluginPasses"; - static constexpr char ENUMERATE_DIALECTS[] = "tritonEnumeratePluginDialects"; - static constexpr char DIALECT_PLUGININFO[] = "tritonGetDialectPluginInfo"; - static constexpr char ADD_PASS[] = "tritonAddPluginPass"; - static constexpr char REGISTER_PASS[] = "tritonRegisterPluginPass"; + /// A meaningful name of the plugin. + const char *pluginName; + /// The version of the plugin. + const char *pluginVersion; -private: - using EnumeratePyBindHandlesType = - std::function; - using EnumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *, - const char **); - - using AddPassType = - std::function; - using AddPassCType = TritonPluginResult (*)(mlir::PassManager *, - const char *); - - using RegisterPassType = std::function; - using RegisterPassCType = TritonPluginResult (*)(const char *); - - using DialectPluginInfoType = - std::function<::mlir::DialectPluginLibraryInfo(const char *)>; - using DialectPluginInfoCType = - ::mlir::DialectPluginLibraryInfo (*)(const char *); - - llvm::Expected getAddressOfSymbol(const std::string &symbol) const; - - template - llvm::Expected getAPI(const std::string &symbol) const { - llvm::Expected getDetailsFn = getAddressOfSymbol(symbol); - if (auto Err = getDetailsFn.takeError()) { - return Err; - } - auto func = reinterpret_cast(*getDetailsFn); - return func; - } - - llvm::Expected checkAPIResult(TritonPluginResult result, - const char *handle) const; - llvm::Expected - enumeratePyBindHandles(EnumeratePyBindHandlesType &enumeratePyBindHandles, - std::vector &passNames); + /// The list of passes. + PassInfo *passes; + size_t numPasses; + + /// The list of dialects. + DialectInfo *dialects; + size_t numDialects; + + /// The list of custom ops. + OpInfo *ops; + size_t numOps; + + /// Triton Version + const char *tritonVersion; +}; + +/// A helper structure for storing information about a pass registered by a +/// plugin. +struct Pass { + Pass(const char *name, AddPassCallback addPass) + : name(name), addPass(addPass) {} + const char *name; + const AddPassCallback addPass; +}; + +/// A helper structure for storing information about a pass registered by a +/// plugin. +struct Op { + Op(const char *name, AddOpCallback addOp) : name(name), addOp(addOp) {} + + const char *name; + const AddOpCallback addOp; +}; + +/// A loaded Triton plugin. +/// +/// An instance of this class wraps a loaded dialect plugin and gives access +/// to its interface defined by the \c PluginInfo it exposes. +class TritonPlugin { public: - std::runtime_error err2exp(llvm::Error Err); + /// Attempts to load a Triton plugin from a given file. + /// + /// \returns Returns an error if either the library cannot be found or + /// loaded, there is no public entry point, or the plugin implements the + /// wrong API version. + static llvm::Expected load(const std::string &filename); - llvm::Error loadPlugin(); + /// Get the filename of the loaded plugin. + llvm::StringRef getFilename() const { return filename; } - llvm::Expected - getPassHandles(std::vector &handles); + /// Get the plugin name. + llvm::StringRef getPluginName() const { return info->pluginName; } - llvm::Expected - getDialectHandles(std::vector &handles); + /// Get the plugin version. + llvm::StringRef getPluginVersion() const { return info->pluginVersion; } - llvm::Expected addPass(mlir::PassManager *pm, - const char *passHandle); + /// Get the plugin API version. + uint32_t getAPIVersion() const { return info->apiVersion; } - llvm::Expected registerPass(const char *passHandle); + /// List the available passes; this allows us invoke the \c AddPassCallback + /// while knowing the pass name. This function will crash with an LLVM usage + /// error if the plugin provides invalid \c PluginInfo. + const std::vector listPasses() const; - llvm::Expected<::mlir::DialectPluginLibraryInfo> - getDialectPluginInfo(const char *dialectName); + /// Invoke the \c RegisterPassCallback for each pass registered in this + /// plugin. This function will crash with an LLVM usage + /// error if the plugin provides invalid \c PluginInfo. + void registerPasses() const; + + /// Invoke the \c RegisterDialectCallback for each dialect registered in + /// this plugin. This function will crash with an LLVM usage + /// error if the plugin provides invalid \c PluginInfo. + void registerDialects(DialectRegistry &dialectRegistry) const; + + /// List the custom operations; this allows us invoke the \c + /// AddOpCallback while knowing the operation name. This function will crash + /// with an LLVM usage error if the plugin provides invalid \c PluginInfo. + const std::vector listOps() const; private: - std::string filename = ""; - mutable llvm::sys::DynamicLibrary library; - EnumeratePyBindHandlesType enumeratePassesAPI; - EnumeratePyBindHandlesType enumerateDialectsAPI; - AddPassType addPassAPI; - RegisterPassType registerPassAPI; - DialectPluginInfoType dialectPluginInfoAPI; - bool isLoaded = false; + TritonPlugin(const std::string &filename, + const llvm::sys::DynamicLibrary &library) + : filename(filename), library(library), info() {} + + std::string filename; + llvm::sys::DynamicLibrary library; + PluginInfo *info; }; +/// Load all plugins specified in the `TRITON_PLUGIN_PATHS` environment +/// variable. This variable should contain a colon-separated list of paths to +/// plugin shared libraries. +/// +/// \returns Returns the list of successfully loaded plugins. If any plugin +/// fails to load, it crashes with an LLVM usage error. +const std::vector &loadPlugins(); + +} // namespace mlir::triton::plugin + +/// The public entry point for loading a Triton plugin. +/// +/// When a plugin is loaded by the driver, Triton will call this entry point to +/// obtain information about the plugin and how to load it. This function must +/// to be implemented by the plugin. +/// +/// Triton expects this function to return a pointer to a valid \c PluginInfo +/// struct. Because plugins are loaded in-process permanently, the \c PluginInfo +/// struct has a lifetime spanning the duration of the program; thus, no +/// deallocation function is required from the plugin. As an extra precaution +/// against leaks, return a pointer to a static struct: +/// +/// ``` +/// mlir::triton::plugin::PluginInfo *tritonGetPluginInfo() { +/// static mlir::triton::plugin::PluginInfo info = { ... }; +/// return &info; +/// } +/// ``` +extern "C" mlir::triton::plugin::PluginInfo *LLVM_ATTRIBUTE_WEAK +tritonGetPluginInfo(); + #endif // TRITON_PLUGIN_UTILS_H diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 4952a3c95d6f..8c0d8611cf2f 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -45,7 +46,8 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_F32_DEFAULT", "TRITON_PREFER_TMEM_16x256_LAYOUT", "TRITON_ENABLE_EXPERIMENTAL_CONSAN", - "TRITON_PASS_PLUGIN_PATH", + "TRITON_PLUGIN_PATHS", + "TRITON_PLUGIN_VERSION_CHECK", "TRITON_PARTITION_SCHEDULING_ENABLE_DUMP_DOT", "TRITON_PARTITION_SCHEDULING_DUMP_DATA_ONLY", "TRITON_PARTITION_SCHEDULING_DUMP_LOOP_ONLY", diff --git a/include/triton/Version.h.in b/include/triton/Version.h.in new file mode 100644 index 000000000000..66219ab3a6ea --- /dev/null +++ b/include/triton/Version.h.in @@ -0,0 +1,6 @@ +#ifndef TRITON_VERSION_H +#define TRITON_VERSION_H + +#define TRITON_VERSION "@TRITON_VERSION@" + +#endif // TRITON_VERSION_H diff --git a/lib/Tools/PluginUtils.cpp b/lib/Tools/PluginUtils.cpp index 64349900575a..05bbcc89fb09 100644 --- a/lib/Tools/PluginUtils.cpp +++ b/lib/Tools/PluginUtils.cpp @@ -1,162 +1,192 @@ #include "triton/Tools/PluginUtils.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" -llvm::Error TritonPlugin::checkLibraryValid(const std::string &error) const { - if (!library.isValid()) { - auto msg = llvm::Twine("Failed to load plugin library: " + error + "\n"); - return llvm::createStringError(msg); - } - return llvm::Error::success(); -} +#define DEBUG_TYPE "triton-plugins" -llvm::Expected -TritonPlugin::getAddressOfSymbol(const std::string &symbol) const { - if (auto isValid = checkLibraryValid("not loaded")) - return isValid; - intptr_t getDetailsFn = (intptr_t)library.getAddressOfSymbol(symbol.c_str()); - if (!getDetailsFn) { - auto msg = llvm::Twine("Failed to get symbol: " + symbol + "\n"); - return llvm::createStringError(msg); - } - return getDetailsFn; -} +using namespace mlir::triton::plugin; -llvm::Expected -TritonPlugin::checkAPIResult(TritonPluginResult result, - const char *handle) const { - if (result == TP_SUCCESS) - return TP_SUCCESS; - std::string msg; - llvm::raw_string_ostream os(msg); - os << "Failed to add/register plugin pass (" << handle - << ") to pass manager, error code: " << result; - return llvm::createStringError(msg); -} +static bool isTritonAndPluginsVersionsMatch(const std::string &pluginVersion) { + // Here, if TRITON_PLUGIN_VERSION_CHECK is unset, then we simply do a default + // version check. However, if it is set then we either do a full (git hash) + // check or we skip all checking. + auto doCheck = + mlir::triton::tools::isEnvValueBool("TRITON_PLUGIN_VERSION_CHECK"); -std::runtime_error TritonPlugin::err2exp(llvm::Error Err) { - std::string msg; - llvm::raw_string_ostream os(msg); - os << Err; - return std::runtime_error(msg); -} + // Skip check when TRITON_PLUGIN_VERSION_CHECK is set false + if (doCheck.has_value() && !doCheck.value()) + return true; + + // Check full version string when TRITON_PLUGIN_VERSION_CHECK is set true + if (doCheck.has_value() && doCheck.value()) + return pluginVersion == TRITON_VERSION; -llvm::Error TritonPlugin::loadPlugin() { - if (isLoaded) - return llvm::Error::success(); + // Do partial release version check when TRITON_PLUGIN_VERSION_CHECK unset + assert(!doCheck.has_value() && "Expected TRITON_PLUGIN_VERSION_CHECK unset"); + return llvm::StringRef(pluginVersion).split('+').first == + llvm::StringRef(TRITON_VERSION).split('+').first; +} +llvm::Expected TritonPlugin::load(const std::string &filename) { std::string error; - library = + auto library = llvm::sys::DynamicLibrary::getPermanentLibrary(filename.c_str(), &error); - if (auto isValid = checkLibraryValid(error)) - return isValid; - - if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_PASSES)) { - auto enumeratePassesAPIOrErr = - getAPI( - ENUMERATE_PASSES); - auto addPassAPIOrErr = getAPI(ADD_PASS); - auto registerPassAPIOrErr = - getAPI(REGISTER_PASS); - - if (auto Err = enumeratePassesAPIOrErr.takeError()) - return Err; - if (auto Err = addPassAPIOrErr.takeError()) - return Err; - if (auto Err = registerPassAPIOrErr.takeError()) - return Err; - - addPassAPI = *addPassAPIOrErr; - registerPassAPI = *registerPassAPIOrErr; - enumeratePassesAPI = *enumeratePassesAPIOrErr; - } + if (!library.isValid()) + return llvm::make_error( + Twine("Could not load library '") + filename + "': " + error, + llvm::inconvertibleErrorCode()); + + TritonPlugin plugin{filename, library}; + + // tritonGetPluginInfo should be resolved to the definition from the + // plugin we are currently loading. + intptr_t getInfoFn = + (intptr_t)library.getAddressOfSymbol("tritonGetPluginInfo"); + if (!getInfoFn) + return llvm::make_error( + Twine("Plugin entry point not found in '") + filename + "'.", + llvm::inconvertibleErrorCode()); + + plugin.info = reinterpret_cast(getInfoFn)(); + + if (plugin.info->apiVersion != TRITON_PLUGIN_API_VERSION) + return llvm::make_error( + Twine("Wrong API version on plugin '") + filename + "'. Got version " + + Twine(plugin.info->apiVersion) + ", supported version is " + + Twine(TRITON_PLUGIN_API_VERSION) + ".", + llvm::inconvertibleErrorCode()); + + if (!isTritonAndPluginsVersionsMatch(plugin.info->tritonVersion)) + return llvm::make_error( + Twine("Wrong TRITON version on plugin '") + filename + + "'. Got version " + Twine(plugin.info->tritonVersion) + + ", supported version is " + Twine(TRITON_VERSION) + ".", + llvm::inconvertibleErrorCode()); + + return plugin; +} - if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_DIALECTS)) { - auto enumerateDialectsAPIOrErr = - getAPI( - ENUMERATE_DIALECTS); - auto dialectPluginInfoAPIOrErr = - getAPI( - DIALECT_PLUGININFO); - - if (auto Err = enumerateDialectsAPIOrErr.takeError()) - return Err; - if (auto Err = dialectPluginInfoAPIOrErr.takeError()) - return Err; - enumerateDialectsAPI = *enumerateDialectsAPIOrErr; - dialectPluginInfoAPI = *dialectPluginInfoAPIOrErr; +const std::vector TritonPlugin::listPasses() const { + if (!info->passes && info->numPasses > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid pass pointer in plugin '") + filename + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Listing " << info->numPasses + << " passes for plugin " << info->pluginName << ":" + << info->pluginVersion << "\n"); + + std::vector passes; + for (auto i = 0; i < info->numPasses; ++i) { + const auto pass = &info->passes[i]; + if (pass->addPass) { + LLVM_DEBUG(llvm::dbgs() << "Listing pass " << pass->name << ":" + << pass->version << "\n"); + passes.push_back(Pass(pass->name, pass->addPass)); + } } - - isLoaded = true; - return llvm::Error::success(); + return passes; } -llvm::Expected TritonPlugin::enumeratePyBindHandles( - EnumeratePyBindHandlesType &enumeratePyBindHandles, - std::vector &handles) { - if (auto Err = loadPlugin()) - return Err; - - uint32_t passCount = 0; - handles.clear(); - auto result = enumeratePyBindHandles(&passCount, nullptr); - if (result == TP_SUCCESS) { - if (passCount == 0) - return TP_SUCCESS; - - handles.resize(passCount); - result = enumeratePyBindHandles(&passCount, handles.data()); +void TritonPlugin::registerPasses() const { + if (!info->passes && info->numPasses > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid pass pointer in plugin '") + filename + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Registering " << info->numPasses + << " passes for plugin " << info->pluginName << ":" + << info->pluginVersion << "\n"); + + for (auto i = 0; i < info->numPasses; ++i) { + const auto &pass = info->passes[i]; + if (pass.registerPass) { + LLVM_DEBUG(llvm::dbgs() << "Registering pass " << pass.name << ":" + << pass.version << "\n"); + pass.registerPass(); + } } - - if (result == TP_SUCCESS) - return TP_SUCCESS; - std::string msg; - llvm::raw_string_ostream os(msg); - os << "Failed to retrive plugin pass handles, error code: " << result; - return llvm::createStringError(msg); } -llvm::Expected -TritonPlugin::getPassHandles(std::vector &passNames) { - if (auto Err = loadPlugin()) - return Err; - // Do a check to see if the enumerate-passes api symbol is present, bail as - // if there are 0 passes if not - intptr_t isPassPluginSymbolPresent = - (intptr_t)library.getAddressOfSymbol(ENUMERATE_PASSES); - if (!isPassPluginSymbolPresent) - return TP_SUCCESS; - return enumeratePyBindHandles(enumeratePassesAPI, passNames); +void TritonPlugin::registerDialects(DialectRegistry &dialectRegistry) const { + if (!info->dialects && info->numDialects > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid dialect pointer in plugin '") + filename + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Registering " << info->numDialects + << " dialects for plugin " << info->pluginName << ":" + << info->pluginVersion << "\n"); + + for (auto i = 0; i < info->numDialects; ++i) { + const auto &dialect = info->dialects[i]; + if (dialect.registerDialect) { + LLVM_DEBUG(llvm::dbgs() << "Registering dialect " << dialect.name << ":" + << dialect.version << "\n"); + dialect.registerDialect(&dialectRegistry); + } + } } -llvm::Expected -TritonPlugin::getDialectHandles(std::vector &dialectNames) { - if (auto Err = loadPlugin()) - return Err; - // Do a check to see if the enumerate-dialects api symbol is present, bail as - // if there are 0 dialects if not - intptr_t isDialectPluginSymbolPresent = - (intptr_t)library.getAddressOfSymbol(ENUMERATE_DIALECTS); - if (!isDialectPluginSymbolPresent) - return TP_SUCCESS; - return enumeratePyBindHandles(enumerateDialectsAPI, dialectNames); +const std::vector TritonPlugin::listOps() const { + if (!info->ops && info->numOps > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid custom op pointer in plugin '") + filename + + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Listing " << info->numOps + << " custom ops for plugin " << info->pluginName + << ":" << info->pluginVersion << "\n"); + + std::vector ops; + for (auto i = 0; i < info->numOps; ++i) { + const auto op = &info->ops[i]; + if (op->addOp) { + LLVM_DEBUG(llvm::dbgs() << "Listing custom op " << op->name << "\n"); + ops.push_back(Op(op->name, op->addOp)); + } + } + return ops; } -llvm::Expected -TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle) { - if (auto Err = loadPlugin()) - return Err; - return checkAPIResult(addPassAPI(pm, passHandle), passHandle); -} +static std::vector plugins; +static bool pluginsLoaded = false; +const std::vector &mlir::triton::plugin::loadPlugins() { + if (pluginsLoaded) + return plugins; + + // Bailing when libtriton symbols are not visible is done to prevent + // crashes caused by loading plugins that will never find their dependent + // symbols (which are hidden by libtriton). +#if !defined(TRITON_EXT_ENABLED) || TRITON_EXT_ENABLED == 0 + bool skipLoading = true; +#else + bool skipLoading = false; +#endif + + if (const char *env = std::getenv("TRITON_PLUGIN_PATHS")) { + llvm::SmallVector paths; + llvm::StringRef(env).split(paths, ':'); + for (const auto &path : paths) { + if (skipLoading) { + llvm::errs() << "\n" + << "\n=================== WARNING =====================\n" + << "Triton will not load the following extension\n" + << "because it is not built with TRITON_EXT_ENABLED:\n" + << path + << "\n=================================================\n" + << "\n"; + continue; + } + + LLVM_DEBUG(llvm::dbgs() << "Loading plugin from path: " << path << "\n"); + auto pluginOrErr = TritonPlugin::load(path.str()); + if (auto err = pluginOrErr.takeError()) { + llvm::Error wrappedErr = llvm::createStringError( + llvm::Twine("Failed to load plugin from path: ") + path + + ". Error: " + llvm::toString(std::move(err))); + llvm::reportFatalUsageError(std::move(wrappedErr)); + } + plugins.push_back(std::move(*pluginOrErr)); + } + } -llvm::Expected -TritonPlugin::registerPass(const char *passHandle) { - if (auto Err = loadPlugin()) - return Err; - return checkAPIResult(registerPassAPI(passHandle), passHandle); + pluginsLoaded = true; + return plugins; } -llvm::Expected<::mlir::DialectPluginLibraryInfo> -TritonPlugin::getDialectPluginInfo(const char *dialectName) { - if (auto Err = loadPlugin()) - return Err; - return dialectPluginInfoAPI(dialectName); -} +#undef DEBUG_TYPE