diff --git a/CMakeLists.txt b/CMakeLists.txt index ac153df137be..c47bb0126ccc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,9 +20,10 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF) option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON) option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON) option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON) -option(LLVM_BUILD_SHARED_LIBS - "Build all libraries as shared libraries instead of static" OFF) +option(TRITON_OFFLINE_BUILD "Build without downloading dependencies" OFF) +option(TRITON_EXT_ENABLED "Build with default visibility for Triton+LLVM symbol exposure to plugin extensions" OFF) set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") +set(TRITON_VERSION "" CACHE STRING "Triton version string (passed from setup.py)") if(TRITON_BUILD_WITH_CCACHE) find_program(CCACHE_PROGRAM ccache) @@ -62,7 +63,6 @@ else() set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL") - set(LLVM_BUILD_SHARED_LIBS "0") endif() # Default build type @@ -86,6 +86,12 @@ else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS") endif() +# Regenerate Triton Version Header +configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/include/triton/Version.h.in" + "${CMAKE_CURRENT_BINARY_DIR}/include/triton/Version.h" + @ONLY +) # ######### # LLVM @@ -144,7 +150,17 @@ endfunction() # Disable warnings that show up in external code (gtest;pybind11) if(NOT MSVC) set(TRITON_DISABLE_EH_RTTI_FLAGS "$<$:-fno-exceptions;-fno-rtti>") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden") + + if(NOT TRITON_EXT_ENABLED) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") + else() + # Inform plugin loader if that 
libtriton is compiled with visibility + # so that they will not proceed with loading plugins if that will + # crash a Triton not compiled with visibility due to missing symbols. + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTRITON_EXT_ENABLED=1") + endif() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default ") + else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4244 /wd4624 /wd4715 /wd4530") endif() @@ -293,6 +309,13 @@ if(TRITON_BUILD_PYTHON_MODULE) # Link triton with its dependencies target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES}) + + + # Do not propagate libraries that libtriton depends on. This ensures that + # targets that link against libtriton do not accidentally link in their own + # copies of core Triton code and LLVM. + set_target_properties(triton PROPERTIES INTERFACE_LINK_LIBRARIES "") + if(WIN32) target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS}) set_target_properties(triton PROPERTIES SUFFIX ".pyd") @@ -315,9 +338,11 @@ if(TRITON_BUILD_PYTHON_MODULE) "${TRITON_WHEEL_DIR}/FileCheck" COPYONLY) + # Build plugins when building libtriton since they depend on libtriton. 
+ add_subdirectory(examples/plugins) endif() -if (UNIX AND NOT APPLE) +if (UNIX AND NOT APPLE AND NOT TRITON_EXT_ENABLED) set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL") endif() @@ -344,7 +369,6 @@ find_package(Threads REQUIRED) add_subdirectory(third_party/f2reduce) add_subdirectory(bin) add_subdirectory(test) -add_subdirectory(examples) if(TRITON_BUILD_UT) add_subdirectory(unittest) diff --git a/Makefile b/Makefile index dfcf1f545388..a81081ecb69b 100644 --- a/Makefile +++ b/Makefile @@ -43,11 +43,13 @@ test-unit: all $(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \ $(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py - TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \ + TRITON_PLUGIN_PATHS=python/triton/plugins/libTritonPluginsTestLib.so \ $(PYTEST) -vvv python/test/unit/plugins/test_plugin.py - TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \ + TRITON_PLUGIN_PATHS=python/triton/plugins/libMLIRDialectPlugin.so \ $(PYTEST) -s -vvv python/test/unit/plugins/test_dialect_plugin.py $(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon + TRITON_PLUGIN_PATHS=python/triton/plugins/libMLIRDialectPlugin.so \ + $(PYTEST) -s -vvv python/test/unit/plugins/custom_ops.py .PHONY: test-gluon test-gluon: all diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index a7d60d5dff66..b64da9cc9fab 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -146,32 +146,10 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::proton::gpu::registerScheduleBufferStorePass(); mlir::triton::proton::gpu::registerAddSchedBarriersPass(); - // Plugin passes - if (std::string filename = - mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); - 
!filename.empty()) { - - TritonPlugin TP(filename); - std::vector passNames; - if (auto result = TP.getPassHandles(passNames); !result) - llvm::report_fatal_error(result.takeError()); - - for (const char *passName : passNames) - if (auto result = TP.registerPass(passName); !result) - llvm::report_fatal_error(result.takeError()); - - std::vector dialectNames; - if (auto result = TP.getDialectHandles(dialectNames); !result) - llvm::report_fatal_error(result.takeError()); - - for (unsigned i = 0; i < dialectNames.size(); ++i) { - const char *dialectName = dialectNames.data()[i]; - auto result = TP.getDialectPluginInfo(dialectName); - if (!result) - llvm::report_fatal_error(result.takeError()); - ::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result; - dialectPluginInfo.registerDialectRegistryCallbacks(&registry); - } + // Register plugin passes and dialects. + for (const auto &plugin : mlir::triton::plugin::loadPlugins()) { + plugin.registerPasses(); + plugin.registerDialects(registry); } registry.insert< diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt deleted file mode 100644 index 0e89371e07e6..000000000000 --- a/examples/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(plugins) diff --git a/examples/plugins/CMakeLists.txt b/examples/plugins/CMakeLists.txt index 6b4a14952b7b..e89bd44b41eb 100644 --- a/examples/plugins/CMakeLists.txt +++ b/examples/plugins/CMakeLists.txt @@ -24,7 +24,7 @@ foreach( plugin ${TRITON_PLUGIN_PASSES} ) TritonCanonicalizeIncGen TritonPluginsIncGen ) - target_link_libraries(${plugin} PRIVATE MLIRPass) + target_link_libraries(${plugin} PRIVATE triton) # CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python # build. 
It is empty if building directly from the root diff --git a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt index 2e0271800053..ff7c66a35070 100644 --- a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt +++ b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt @@ -20,10 +20,8 @@ add_mlir_dialect_library(MLIRDialectPlugin MLIRDialectPluginPassesIncGen LINK_LIBS PUBLIC - MLIRPass - LLVMSupport - MLIRSupport - TritonNVIDIAGPUToLLVM + triton + "$<$:-undefined dynamic_lookup>" ) diff --git a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp index a2fa1baec953..bd4687fb91c7 100644 --- a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp +++ b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp @@ -3,6 +3,7 @@ #include "DialectPlugin/DialectPluginTypes.h" using namespace mlir; +using namespace mlir::triton; using namespace mlir::triton::plugin; #include "DialectPlugin/DialectPluginOpsDialect.cpp.inc" @@ -25,13 +26,18 @@ void DialectPluginDialect::initialize() { #include "DialectPlugin/DialectPluginDialect.h" #include "DialectPlugin/DialectPluginPasses.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Tools/Plugins/PassPlugin.h" #include "triton/Tools/PluginUtils.h" #include "llvm/Config/llvm-config.h" -using namespace mlir; +static const char *PLUGIN_NAME = "DialectPlugin"; +static const char *DIALECT_NAME = "DialectPlugin"; +static const char *PASS_NAME = "plugingpu_conversion"; +static const char *VERSION = "0.1.0"; -static void addTritonPluginPass(mlir::PassManager *pm) { +static void addTritonPluginPass(mlir::PassManager *pm, + const std::vector &args) { 
pm->addPass(mlir::triton::plugin::createConvertPluginGPUToLLVMPass()); } @@ -41,65 +47,38 @@ static void registerTritonPluginPass() { }); } -static const char *ADD_PLUGIN_PASS_NAME = "plugingpu_conversion"; -static std::unordered_map passMap = - {{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}}; -static std::unordered_map registryMap = { - {ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}}; -static std::vector passNamesTable = {ADD_PLUGIN_PASS_NAME}; - -// Key APIs: - -TRITON_PLUGIN_API -tritonAddPluginPass(mlir::PassManager *pm, const char *passName) { - std::string passNameStr(passName); - if (passMap.find(passNameStr) == passMap.end()) - return TP_GENERIC_FAILURE; - passMap[passNameStr](pm); - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonRegisterPluginPass(const char *passName) { - std::string passNameStr(passName); - if (registryMap.find(passNameStr) == registryMap.end()) - return TP_GENERIC_FAILURE; - registryMap[passNameStr](); - return TP_SUCCESS; +static void registerTritonPluginDialect(DialectRegistry *registry) { + registry->insert(); + mlir::triton::plugin::registerpluginPasses(); } -TRITON_PLUGIN_API -tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) { - if (!passCount) - return TP_GENERIC_FAILURE; - auto count = passMap.size(); - assert(count == registryMap.size() && - "Expected register and add passes map size to match"); - *passCount = count; - if (!passNames) - return TP_SUCCESS; - unsigned i = 0; - for (auto passName : passNamesTable) { - passNames[i] = passName; - } - return TP_SUCCESS; -} +static void addTritonPluginCustomOp(TritonOpBuilder &self, + std::vector &operands) { + ::mlir::Value &dst = operands[0]; + ::mlir::Value &src = operands[1]; -TRITON_PLUGIN_API -tritonEnumeratePluginDialects(uint32_t *dialectCount, - const char **dialectNames) { - *dialectCount = 1; - if (!dialectNames) - return TP_SUCCESS; - dialectNames[0] = "DialectPlugin"; - return TP_SUCCESS; + dst = self.create(src, src); + operands[0] = dst; } 
-TRITON_PLUGIN_API_TYPE(DialectPluginLibraryInfo) -tritonGetDialectPluginInfo(const char *name) { - return {MLIR_PLUGIN_API_VERSION, "DialectPlugin", LLVM_VERSION_STRING, - [](DialectRegistry *registry) { - registry->insert(); - mlir::triton::plugin::registerpluginPasses(); - }}; +TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() { + static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass, + registerTritonPluginPass}; + static plugin::PassInfo passes[] = {pass}; + static plugin::DialectInfo dialect = {DIALECT_NAME, VERSION, + registerTritonPluginDialect}; + static plugin::DialectInfo dialects[] = {dialect}; + static plugin::OpInfo op = {"create_custom_op", addTritonPluginCustomOp}; + static plugin::OpInfo ops[] = {op}; + static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION, + PLUGIN_NAME, + VERSION, + passes, + 1, + dialects, + 1, + ops, + 1, + TRITON_VERSION}; + return &info; } diff --git a/examples/plugins/Passes.td b/examples/plugins/Passes.td index a8007a09e84b..8971e5bdc678 100644 --- a/examples/plugins/Passes.td +++ b/examples/plugins/Passes.td @@ -5,5 +5,12 @@ include "mlir/Pass/PassBase.td" def TritonGPUMLIRPlugin : Pass<"tritongpu-plugin", "mlir::ModuleOp"> { let summary = "Triton MLIR Plugin Pass"; + + let options = [ + Option<"num_warps", "num-warps", + "int32_t", /*default*/"4", + "Number of warps">, + ]; + } #endif diff --git a/examples/plugins/README.md b/examples/plugins/README.md index c5615e9ae929..ebd511bef5b0 100644 --- a/examples/plugins/README.md +++ b/examples/plugins/README.md @@ -17,8 +17,8 @@ long as the libtriton.so is linked to the plugin and the Triton include passes a ## Example 1: Developing a custom pass and running triton-opt to inspect the modified IR ``` bash -export LLVM_BUILD_SHARED_LIBS=1; make dev-install-llvm -TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so triton-opt -tritongpu-plugin test/Plugins/test-plugin.mlir +export TRITON_EXT_ENABLED=1; make 
dev-install-llvm +TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so triton-opt -tritongpu-plugin test/Plugins/test-plugin.mlir ``` ``` MLIR module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} { @@ -85,7 +85,7 @@ Running same code but loading the plugin library also produces the same results pass manager it is not inserted into the compiler pass pipeline: ``` bash -TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py +TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py ``` ``` MLIR @@ -151,7 +151,7 @@ if __name__ == '__main__': h = kernel[grid](BLOCK_SIZE=1024) print(h.asm["ttgir"]) - if "TRITON_PASS_PLUGIN_PATH" in os.environ: + if "TRITON_PLUGIN_PATHS" in os.environ: knobs.runtime.add_stages_inspection_hook = inspect_stages_hook h = kernel[grid](BLOCK_SIZE=1024) print(h.asm["ttgir"]) @@ -163,7 +163,7 @@ if __name__ == '__main__': ``` ``` bash -TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py +TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py ``` Shows the pass ran and modified the kernel name but only after the hook is set. Any kernels before the hook or after the hook is unset are left unchanged. 
@@ -330,7 +330,7 @@ if __name__ == '__main__': if "add_loop_unroll" in line: outfile.write("\n passes.plugin.add_plugin(pm)\n") outfile.write(line) - if "TRITON_PASS_PLUGIN_PATH" in os.environ: + if "TRITON_PLUGIN_PATHS" in os.environ: knobs.runtime.add_stages_inspection_hook = override_stages h = kernel2[grid](BLOCK_SIZE=1024) print(h.asm["ttgir"]) diff --git a/examples/plugins/TritonPlugin.cpp b/examples/plugins/TritonPlugin.cpp index a714d18618f5..33a252f94b81 100644 --- a/examples/plugins/TritonPlugin.cpp +++ b/examples/plugins/TritonPlugin.cpp @@ -10,17 +10,27 @@ namespace mlir { namespace triton { namespace plugin { +#define GEN_PASS_DECL_TRITONGPUMLIRPLUGIN #define GEN_PASS_DEF_TRITONGPUMLIRPLUGIN #include "Passes.h.inc" struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase { + using TritonGPUMLIRPluginBase::TritonGPUMLIRPluginBase; + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); + + std::string name; + llvm::raw_string_ostream sstr(name); + sstr << "foo"; + if (num_warps != 4) + sstr << "_num_warps_" << num_warps; + mod.walk([&](FunctionOpInterface funcOp) { StringAttr funcNameAttr = funcOp.getNameAttr(); - funcOp.setName("foo"); + funcOp.setName(name); }); } }; @@ -29,8 +39,16 @@ struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase { } // namespace triton } // namespace mlir -static void addTritonPluginPass(mlir::PassManager *pm) { - pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin()); +static void addTritonPluginPass(mlir::PassManager *pm, + const std::vector &args) { + if (args.empty()) { + pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin()); + return; + } + + mlir::triton::plugin::TritonGPUMLIRPluginOptions opts; + opts.num_warps = std::atoi(args[0].c_str()); + pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin((opts))); } static void registerTritonPluginPass() { @@ -39,46 +57,25 @@ static void registerTritonPluginPass() { }); } -static const char 
*ADD_PLUGIN_PASS_NAME = "add_plugin"; -static std::unordered_map passMap = - {{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}}; -static std::unordered_map registryMap = { - {ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}}; -static std::vector passNamesTable = {ADD_PLUGIN_PASS_NAME}; +static const char *PLUGIN_NAME = "TritonPlugin"; +static const char *PASS_NAME = "add_plugin"; +static const char *VERSION = "0.1.0"; -// Key APIs: +using namespace mlir::triton; -TRITON_PLUGIN_API -tritonAddPluginPass(mlir::PassManager *pm, const char *passName) { - std::string passNameStr(passName); - if (passMap.find(passNameStr) == passMap.end()) - return TP_GENERIC_FAILURE; - passMap[passNameStr](pm); - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonRegisterPluginPass(const char *passName) { - std::string passNameStr(passName); - if (registryMap.find(passNameStr) == registryMap.end()) - return TP_GENERIC_FAILURE; - registryMap[passNameStr](); - return TP_SUCCESS; -} - -TRITON_PLUGIN_API -tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) { - if (!passCount) - return TP_GENERIC_FAILURE; - auto count = passMap.size(); - assert(count == registryMap.size() && - "Expected register and add passes map size to match"); - *passCount = count; - if (!passNames) - return TP_SUCCESS; - unsigned i = 0; - for (auto passName : passNamesTable) { - passNames[i++] = passName; - } - return TP_SUCCESS; +TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() { + static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass, + registerTritonPluginPass}; + static plugin::PassInfo passes[] = {pass}; + static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION, + PLUGIN_NAME, + VERSION, + passes, + 1, + nullptr, + 0, + nullptr, + 0, + TRITON_VERSION}; + return &info; } diff --git a/include/triton/Tools/PluginUtils.h b/include/triton/Tools/PluginUtils.h index 5878af01bb91..6cde31819cdc 100644 --- a/include/triton/Tools/PluginUtils.h +++ 
b/include/triton/Tools/PluginUtils.h @@ -1,100 +1,201 @@ +// Defines the external and internal interface for Triton plugins. +// +// This is largely meant to follow the plugin pattern outlined in upstream MLIR +// ([DialectPlugin], [PassPlugin]); use those as references for further +// additions. +// +// [DialectPlugin]: +// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/DialectPlugin.h +// [PassPlugin]: +// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/PassPlugin.h + #ifndef TRITON_PLUGIN_UTILS_H #define TRITON_PLUGIN_UTILS_H +#include "mlir/IR/DialectRegistry.h" #include "mlir/Pass/PassManager.h" #include "mlir/Tools/Plugins/DialectPlugin.h" +#include "python/src/ir.h" +#include "triton/Version.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/DynamicLibrary.h" #include "llvm/Support/Error.h" #include +#include + +/// Identifies the API version understood by this plugin. +/// +/// This version should be incremented for ABI-breaking changes in the structs +/// below; we check this version when loading a new \c TritonPlugin. See +/// similar: [MLIR_PLUGIN_API_VERSION]. +/// +/// [MLIR_PLUGIN_API_VERSION]: +/// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/PassPlugin.h#L32 +#define TRITON_PLUGIN_API_VERSION 2 + +/// Use this helper macro on the public entry point for a Triton plugin. +#define TRITON_PLUGIN_API extern "C" __attribute__((visibility("default"))) + +namespace mlir::triton::plugin { + +// Types for plugin callback functions. +using AddPassCallback = void (*)(mlir::PassManager *, + const std::vector &); +using RegisterPassCallback = void (*)(); +using RegisterDialectCallback = void (*)(mlir::DialectRegistry *); +using AddOpCallback = void (*)(TritonOpBuilder &, std::vector &); + +/// Information provided by a plugin for loading its passes. 
+struct PassInfo { + const char *name; + const char *version; + AddPassCallback addPass; + RegisterPassCallback registerPass; +}; -extern "C" { -enum TritonPluginResult { - TP_SUCCESS = 0, - TP_GENERIC_FAILURE = 1, +/// Information provided by a plugin for loading its dialects. +struct DialectInfo { + const char *name; + const char *version; + RegisterDialectCallback registerDialect; }; + +/// Information provided by a plugin for loading its custom ops. +struct OpInfo { + const char *name; + AddOpCallback addOp; }; -#define TRITON_PLUGIN_API \ - extern "C" __attribute__((visibility("default"))) TritonPluginResult -#define TRITON_PLUGIN_API_TYPE(_TYPE) \ - extern "C" __attribute__((visibility("default"))) _TYPE -struct TritonPlugin { - TritonPlugin() = delete; - TritonPlugin(std::string filename) : filename(filename) {} +/// Container for all plugin information; this is returned by the plugin +/// library's public entry point, @ref tritonGetPluginInfo. +struct PluginInfo { + /// The API version used by this plugin, see \c TRITON_PLUGIN_API_VERSION. + uint32_t apiVersion; -public: - llvm::Error checkLibraryValid(const std::string &error) const; - static constexpr char ENUMERATE_PASSES[] = "tritonEnumeratePluginPasses"; - static constexpr char ENUMERATE_DIALECTS[] = "tritonEnumeratePluginDialects"; - static constexpr char DIALECT_PLUGININFO[] = "tritonGetDialectPluginInfo"; - static constexpr char ADD_PASS[] = "tritonAddPluginPass"; - static constexpr char REGISTER_PASS[] = "tritonRegisterPluginPass"; + /// A meaningful name of the plugin. + const char *pluginName; + /// The version of the plugin. 
+ const char *pluginVersion; -private: - using EnumeratePyBindHandlesType = - std::function; - using EnumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *, - const char **); - - using AddPassType = - std::function; - using AddPassCType = TritonPluginResult (*)(mlir::PassManager *, - const char *); - - using RegisterPassType = std::function; - using RegisterPassCType = TritonPluginResult (*)(const char *); - - using DialectPluginInfoType = - std::function<::mlir::DialectPluginLibraryInfo(const char *)>; - using DialectPluginInfoCType = - ::mlir::DialectPluginLibraryInfo (*)(const char *); - - llvm::Expected getAddressOfSymbol(const std::string &symbol) const; - - template - llvm::Expected getAPI(const std::string &symbol) const { - llvm::Expected getDetailsFn = getAddressOfSymbol(symbol); - if (auto Err = getDetailsFn.takeError()) { - return Err; - } - auto func = reinterpret_cast(*getDetailsFn); - return func; - } - - llvm::Expected checkAPIResult(TritonPluginResult result, - const char *handle) const; - llvm::Expected - enumeratePyBindHandles(EnumeratePyBindHandlesType &enumeratePyBindHandles, - std::vector &passNames); + /// The list of passes. + PassInfo *passes; + size_t numPasses; + + /// The list of dialects. + DialectInfo *dialects; + size_t numDialects; + + /// The list of custom ops. + OpInfo *ops; + size_t numOps; + + /// Triton Version + const char *tritonVersion; +}; + +/// A helper structure for storing information about a pass registered by a +/// plugin. +struct Pass { + Pass(const char *name, AddPassCallback addPass) + : name(name), addPass(addPass) {} + const char *name; + const AddPassCallback addPass; +}; + +/// A helper structure for storing information about a custom op registered by +/// a plugin. +struct Op { + Op(const char *name, AddOpCallback addOp) : name(name), addOp(addOp) {} + + const char *name; + const AddOpCallback addOp; +}; + +/// A loaded Triton plugin. 
+/// +/// An instance of this class wraps a loaded Triton plugin and gives access +/// to its interface defined by the \c PluginInfo it exposes. +class TritonPlugin { public: - std::runtime_error err2exp(llvm::Error Err); + /// Attempts to load a Triton plugin from a given file. + /// + /// \returns An error if either the library cannot be found or + /// loaded, there is no public entry point, or the plugin implements the + /// wrong API version. + static llvm::Expected load(const std::string &filename); - llvm::Error loadPlugin(); + /// Get the filename of the loaded plugin. + llvm::StringRef getFilename() const { return filename; } - llvm::Expected - getPassHandles(std::vector &handles); + /// Get the plugin name. + llvm::StringRef getPluginName() const { return info->pluginName; } - llvm::Expected - getDialectHandles(std::vector &handles); + /// Get the plugin version. + llvm::StringRef getPluginVersion() const { return info->pluginVersion; } - llvm::Expected addPass(mlir::PassManager *pm, - const char *passHandle); + /// Get the plugin API version. + uint32_t getAPIVersion() const { return info->apiVersion; } - llvm::Expected registerPass(const char *passHandle); + /// List the available passes; this allows us to invoke the \c AddPassCallback + /// while knowing the pass name. This function will crash with an LLVM usage + /// error if the plugin provides invalid \c PluginInfo. + const std::vector listPasses() const; - llvm::Expected<::mlir::DialectPluginLibraryInfo> - getDialectPluginInfo(const char *dialectName); + /// Invoke the \c RegisterPassCallback for each pass registered in this + /// plugin. This function will crash with an LLVM usage + /// error if the plugin provides invalid \c PluginInfo. + void registerPasses() const; + + /// Invoke the \c RegisterDialectCallback for each dialect registered in + /// this plugin. This function will crash with an LLVM usage + /// error if the plugin provides invalid \c PluginInfo. 
+ void registerDialects(DialectRegistry &dialectRegistry) const; + + /// List the custom operations; this allows us to invoke the \c + /// AddOpCallback while knowing the operation name. This function will crash + /// with an LLVM usage error if the plugin provides invalid \c PluginInfo. + const std::vector listOps() const; private: - std::string filename = ""; - mutable llvm::sys::DynamicLibrary library; - EnumeratePyBindHandlesType enumeratePassesAPI; - EnumeratePyBindHandlesType enumerateDialectsAPI; - AddPassType addPassAPI; - RegisterPassType registerPassAPI; - DialectPluginInfoType dialectPluginInfoAPI; - bool isLoaded = false; + TritonPlugin(const std::string &filename, + const llvm::sys::DynamicLibrary &library) + : filename(filename), library(library), info() {} + + std::string filename; + llvm::sys::DynamicLibrary library; + PluginInfo *info; }; +/// Load all plugins specified in the `TRITON_PLUGIN_PATHS` environment +/// variable. This variable should contain a colon-separated list of paths to +/// plugin shared libraries. +/// +/// \returns The list of successfully loaded plugins. If any plugin +/// fails to load, it crashes with an LLVM usage error. +const std::vector &loadPlugins(); + +} // namespace mlir::triton::plugin + +/// The public entry point for loading a Triton plugin. +/// +/// When a plugin is loaded by the driver, Triton will call this entry point to +/// obtain information about the plugin and how to load it. This function must +/// be implemented by the plugin. +/// +/// Triton expects this function to return a pointer to a valid \c PluginInfo +/// struct. Because plugins are loaded in-process permanently, the \c PluginInfo +/// struct has a lifetime spanning the duration of the program; thus, no +/// deallocation function is required from the plugin. 
As an extra precaution +/// against leaks, return a pointer to a static struct: +/// +/// ``` +/// mlir::triton::plugin::PluginInfo *tritonGetPluginInfo() { +/// static mlir::triton::plugin::PluginInfo info = { ... }; +/// return &info; +/// } +/// ``` +extern "C" mlir::triton::plugin::PluginInfo *LLVM_ATTRIBUTE_WEAK +tritonGetPluginInfo(); + #endif // TRITON_PLUGIN_UTILS_H diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 4952a3c95d6f..d83d445c5601 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -45,7 +45,8 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_F32_DEFAULT", "TRITON_PREFER_TMEM_16x256_LAYOUT", "TRITON_ENABLE_EXPERIMENTAL_CONSAN", - "TRITON_PASS_PLUGIN_PATH", + "TRITON_PLUGIN_PATHS", + "TRITON_PLUGIN_VERSION_CHECK", "TRITON_PARTITION_SCHEDULING_ENABLE_DUMP_DOT", "TRITON_PARTITION_SCHEDULING_DUMP_DATA_ONLY", "TRITON_PARTITION_SCHEDULING_DUMP_LOOP_ONLY", diff --git a/include/triton/Version.h.in b/include/triton/Version.h.in new file mode 100644 index 000000000000..66219ab3a6ea --- /dev/null +++ b/include/triton/Version.h.in @@ -0,0 +1,6 @@ +#ifndef TRITON_VERSION_H +#define TRITON_VERSION_H + +#define TRITON_VERSION "@TRITON_VERSION@" + +#endif // TRITON_VERSION_H diff --git a/lib/Tools/PluginUtils.cpp b/lib/Tools/PluginUtils.cpp index 64349900575a..05bbcc89fb09 100644 --- a/lib/Tools/PluginUtils.cpp +++ b/lib/Tools/PluginUtils.cpp @@ -1,162 +1,192 @@ #include "triton/Tools/PluginUtils.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" -llvm::Error TritonPlugin::checkLibraryValid(const std::string &error) const { - if (!library.isValid()) { - auto msg = llvm::Twine("Failed to load plugin library: " + error + "\n"); - return llvm::createStringError(msg); - } - return llvm::Error::success(); -} +#define DEBUG_TYPE "triton-plugins" -llvm::Expected -TritonPlugin::getAddressOfSymbol(const std::string 
&symbol) const { - if (auto isValid = checkLibraryValid("not loaded")) - return isValid; - intptr_t getDetailsFn = (intptr_t)library.getAddressOfSymbol(symbol.c_str()); - if (!getDetailsFn) { - auto msg = llvm::Twine("Failed to get symbol: " + symbol + "\n"); - return llvm::createStringError(msg); - } - return getDetailsFn; -} +using namespace mlir::triton::plugin; -llvm::Expected -TritonPlugin::checkAPIResult(TritonPluginResult result, - const char *handle) const { - if (result == TP_SUCCESS) - return TP_SUCCESS; - std::string msg; - llvm::raw_string_ostream os(msg); - os << "Failed to add/register plugin pass (" << handle - << ") to pass manager, error code: " << result; - return llvm::createStringError(msg); -} +static bool isTritonAndPluginsVersionsMatch(const std::string &pluginVersion) { + // Here, if TRITON_PLUGIN_VERSION_CHECK is unset, then we simply do a default + // version check. However, if it is set then we either do a full (git hash) + // check or we skip all checking. + auto doCheck = + mlir::triton::tools::isEnvValueBool("TRITON_PLUGIN_VERSION_CHECK"); -std::runtime_error TritonPlugin::err2exp(llvm::Error Err) { - std::string msg; - llvm::raw_string_ostream os(msg); - os << Err; - return std::runtime_error(msg); -} + // Skip check when TRITON_PLUGIN_VERSION_CHECK is set false + if (doCheck.has_value() && !doCheck.value()) + return true; + + // Check full version string when TRITON_PLUGIN_VERSION_CHECK is set true + if (doCheck.has_value() && doCheck.value()) + return pluginVersion == TRITON_VERSION; -llvm::Error TritonPlugin::loadPlugin() { - if (isLoaded) - return llvm::Error::success(); + // Do partial release version check when TRITON_PLUGIN_VERSION_CHECK unset + assert(!doCheck.has_value() && "Expected TRITON_PLUGIN_VERSION_CHECK unset"); + return llvm::StringRef(pluginVersion).split('+').first == + llvm::StringRef(TRITON_VERSION).split('+').first; +} +llvm::Expected TritonPlugin::load(const std::string &filename) { std::string error; - library 
= + auto library = llvm::sys::DynamicLibrary::getPermanentLibrary(filename.c_str(), &error); - if (auto isValid = checkLibraryValid(error)) - return isValid; - - if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_PASSES)) { - auto enumeratePassesAPIOrErr = - getAPI( - ENUMERATE_PASSES); - auto addPassAPIOrErr = getAPI(ADD_PASS); - auto registerPassAPIOrErr = - getAPI(REGISTER_PASS); - - if (auto Err = enumeratePassesAPIOrErr.takeError()) - return Err; - if (auto Err = addPassAPIOrErr.takeError()) - return Err; - if (auto Err = registerPassAPIOrErr.takeError()) - return Err; - - addPassAPI = *addPassAPIOrErr; - registerPassAPI = *registerPassAPIOrErr; - enumeratePassesAPI = *enumeratePassesAPIOrErr; - } + if (!library.isValid()) + return llvm::make_error( + Twine("Could not load library '") + filename + "': " + error, + llvm::inconvertibleErrorCode()); + + TritonPlugin plugin{filename, library}; + + // tritonGetPluginInfo should be resolved to the definition from the + // plugin we are currently loading. + intptr_t getInfoFn = + (intptr_t)library.getAddressOfSymbol("tritonGetPluginInfo"); + if (!getInfoFn) + return llvm::make_error( + Twine("Plugin entry point not found in '") + filename + "'.", + llvm::inconvertibleErrorCode()); + + plugin.info = reinterpret_cast(getInfoFn)(); + + if (plugin.info->apiVersion != TRITON_PLUGIN_API_VERSION) + return llvm::make_error( + Twine("Wrong API version on plugin '") + filename + "'. Got version " + + Twine(plugin.info->apiVersion) + ", supported version is " + + Twine(TRITON_PLUGIN_API_VERSION) + ".", + llvm::inconvertibleErrorCode()); + + if (!isTritonAndPluginsVersionsMatch(plugin.info->tritonVersion)) + return llvm::make_error( + Twine("Wrong TRITON version on plugin '") + filename + + "'. 
Got version " + Twine(plugin.info->tritonVersion) + + ", supported version is " + Twine(TRITON_VERSION) + ".", + llvm::inconvertibleErrorCode()); + + return plugin; +} - if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_DIALECTS)) { - auto enumerateDialectsAPIOrErr = - getAPI( - ENUMERATE_DIALECTS); - auto dialectPluginInfoAPIOrErr = - getAPI( - DIALECT_PLUGININFO); - - if (auto Err = enumerateDialectsAPIOrErr.takeError()) - return Err; - if (auto Err = dialectPluginInfoAPIOrErr.takeError()) - return Err; - enumerateDialectsAPI = *enumerateDialectsAPIOrErr; - dialectPluginInfoAPI = *dialectPluginInfoAPIOrErr; +const std::vector TritonPlugin::listPasses() const { + if (!info->passes && info->numPasses > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid pass pointer in plugin '") + filename + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Listing " << info->numPasses + << " passes for plugin " << info->pluginName << ":" + << info->pluginVersion << "\n"); + + std::vector passes; + for (auto i = 0; i < info->numPasses; ++i) { + const auto pass = &info->passes[i]; + if (pass->addPass) { + LLVM_DEBUG(llvm::dbgs() << "Listing pass " << pass->name << ":" + << pass->version << "\n"); + passes.push_back(Pass(pass->name, pass->addPass)); + } } - - isLoaded = true; - return llvm::Error::success(); + return passes; } -llvm::Expected TritonPlugin::enumeratePyBindHandles( - EnumeratePyBindHandlesType &enumeratePyBindHandles, - std::vector &handles) { - if (auto Err = loadPlugin()) - return Err; - - uint32_t passCount = 0; - handles.clear(); - auto result = enumeratePyBindHandles(&passCount, nullptr); - if (result == TP_SUCCESS) { - if (passCount == 0) - return TP_SUCCESS; - - handles.resize(passCount); - result = enumeratePyBindHandles(&passCount, handles.data()); +void TritonPlugin::registerPasses() const { + if (!info->passes && info->numPasses > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid pass pointer in plugin '") + 
filename + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Registering " << info->numPasses + << " passes for plugin " << info->pluginName << ":" + << info->pluginVersion << "\n"); + + for (auto i = 0; i < info->numPasses; ++i) { + const auto &pass = info->passes[i]; + if (pass.registerPass) { + LLVM_DEBUG(llvm::dbgs() << "Registering pass " << pass.name << ":" + << pass.version << "\n"); + pass.registerPass(); + } } - - if (result == TP_SUCCESS) - return TP_SUCCESS; - std::string msg; - llvm::raw_string_ostream os(msg); - os << "Failed to retrive plugin pass handles, error code: " << result; - return llvm::createStringError(msg); } -llvm::Expected -TritonPlugin::getPassHandles(std::vector &passNames) { - if (auto Err = loadPlugin()) - return Err; - // Do a check to see if the enumerate-passes api symbol is present, bail as - // if there are 0 passes if not - intptr_t isPassPluginSymbolPresent = - (intptr_t)library.getAddressOfSymbol(ENUMERATE_PASSES); - if (!isPassPluginSymbolPresent) - return TP_SUCCESS; - return enumeratePyBindHandles(enumeratePassesAPI, passNames); +void TritonPlugin::registerDialects(DialectRegistry &dialectRegistry) const { + if (!info->dialects && info->numDialects > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid dialect pointer in plugin '") + filename + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Registering " << info->numDialects + << " dialects for plugin " << info->pluginName << ":" + << info->pluginVersion << "\n"); + + for (auto i = 0; i < info->numDialects; ++i) { + const auto &dialect = info->dialects[i]; + if (dialect.registerDialect) { + LLVM_DEBUG(llvm::dbgs() << "Registering dialect " << dialect.name << ":" + << dialect.version << "\n"); + dialect.registerDialect(&dialectRegistry); + } + } } -llvm::Expected -TritonPlugin::getDialectHandles(std::vector &dialectNames) { - if (auto Err = loadPlugin()) - return Err; - // Do a check to see if the enumerate-dialects api symbol is present, bail as - // if there are 0 
dialects if not - intptr_t isDialectPluginSymbolPresent = - (intptr_t)library.getAddressOfSymbol(ENUMERATE_DIALECTS); - if (!isDialectPluginSymbolPresent) - return TP_SUCCESS; - return enumeratePyBindHandles(enumerateDialectsAPI, dialectNames); +const std::vector TritonPlugin::listOps() const { + if (!info->ops && info->numOps > 0) + llvm::reportFatalUsageError(llvm::createStringError( + llvm::Twine("Invalid custom op pointer in plugin '") + filename + + "'.")); + LLVM_DEBUG(llvm::dbgs() << "Listing " << info->numOps + << " custom ops for plugin " << info->pluginName + << ":" << info->pluginVersion << "\n"); + + std::vector ops; + for (auto i = 0; i < info->numOps; ++i) { + const auto op = &info->ops[i]; + if (op->addOp) { + LLVM_DEBUG(llvm::dbgs() << "Listing custom op " << op->name << "\n"); + ops.push_back(Op(op->name, op->addOp)); + } + } + return ops; } -llvm::Expected -TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle) { - if (auto Err = loadPlugin()) - return Err; - return checkAPIResult(addPassAPI(pm, passHandle), passHandle); -} +static std::vector plugins; +static bool pluginsLoaded = false; +const std::vector &mlir::triton::plugin::loadPlugins() { + if (pluginsLoaded) + return plugins; + + // Bailing when libtriton symbols are not visible is done to prevent + // crashes caused by loading plugins that will never find their dependent + // symbols (which are hidden by libtriton). 
+#if !defined(TRITON_EXT_ENABLED) || TRITON_EXT_ENABLED == 0 + bool skipLoading = true; +#else + bool skipLoading = false; +#endif + + if (const char *env = std::getenv("TRITON_PLUGIN_PATHS")) { + llvm::SmallVector paths; + llvm::StringRef(env).split(paths, ':'); + for (const auto &path : paths) { + if (skipLoading) { + llvm::errs() << "\n" + << "\n=================== WARNING =====================\n" + << "Triton will not load the following extension\n" + << "because it is not built with TRITON_EXT_ENABLED:\n" + << path + << "\n=================================================\n" + << "\n"; + continue; + } + + LLVM_DEBUG(llvm::dbgs() << "Loading plugin from path: " << path << "\n"); + auto pluginOrErr = TritonPlugin::load(path.str()); + if (auto err = pluginOrErr.takeError()) { + llvm::Error wrappedErr = llvm::createStringError( + llvm::Twine("Failed to load plugin from path: ") + path + + ". Error: " + llvm::toString(std::move(err))); + llvm::reportFatalUsageError(std::move(wrappedErr)); + } + plugins.push_back(std::move(*pluginOrErr)); + } + } -llvm::Expected -TritonPlugin::registerPass(const char *passHandle) { - if (auto Err = loadPlugin()) - return Err; - return checkAPIResult(registerPassAPI(passHandle), passHandle); + pluginsLoaded = true; + return plugins; } -llvm::Expected<::mlir::DialectPluginLibraryInfo> -TritonPlugin::getDialectPluginInfo(const char *dialectName) { - if (auto Err = loadPlugin()) - return Err; - return dialectPluginInfoAPI(dialectName); -} +#undef DEBUG_TYPE diff --git a/python/src/ir.cc b/python/src/ir.cc index 2c27b5a5aec1..4024b9178382 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -262,7 +262,6 @@ py::list getTensorDescMetadata(ModuleOp &mod) { /*****************************************************************************/ /* Python bindings for ir */ /*****************************************************************************/ - void init_triton_ir(py::module &&m) { using ret = py::return_value_policy; using namespace 
pybind11::literals; @@ -367,23 +366,9 @@ void init_triton_ir(py::module &&m) { m.def("load_dialects", [](MLIRContext &context) { DialectRegistry registry; - if (std::string filename = - mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); - !filename.empty()) { - TritonPlugin TP(filename); - - std::vector dialectNames; - if (auto result = TP.getDialectHandles(dialectNames); !result) - llvm::report_fatal_error(result.takeError()); - - for (unsigned i = 0; i < dialectNames.size(); ++i) { - const char *dialectName = dialectNames.data()[i]; - auto result = TP.getDialectPluginInfo(dialectName); - if (!result) - throw TP.err2exp(result.takeError()); - ::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result; - dialectPluginInfo.registerDialectRegistryCallbacks(®istry); - } + // Register plugin dialects. + for (const auto &plugin : mlir::triton::plugin::loadPlugins()) { + plugin.registerDialects(registry); } registry.insert(m, "InsertPoint", py::module_local()); - py::class_(m, "builder", py::module_local(), - py::dynamic_attr()) - .def(py::init()) + py::class_ TritonOpBuilderBinding = + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()); + TritonOpBuilderBinding.def(py::init()) .def("get_op_builder", &TritonOpBuilder::getBuilder, ret::reference) // getters .def("create_module", @@ -1176,7 +1162,8 @@ void init_triton_ir(py::module &&m) { }) // Cast instructions - // Conversions for custom FP types (FP8 and non-standard rounding modes) + // Conversions for custom FP types (FP8 and non-standard rounding + // modes) .def("create_fp_to_fp", [](TritonOpBuilder &self, Value &src, Type &dstType, std::optional roundingMode) -> Value { @@ -1317,8 +1304,8 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { return Value(self.create(lhs, rhs)); }) - // minimumf follows the torch.minimum convention and returns NaN if either - // operand is NaN + // minimumf follows the torch.minimum convention and returns NaN 
if + // either operand is NaN .def("create_minimumf", [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { return Value(self.create(lhs, rhs)); @@ -1337,8 +1324,8 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { return Value(self.create(lhs, rhs)); }) - // maximumf follows the torch.maximum convention and returns NaN if either - // operand is NaN + // maximumf follows the torch.maximum convention and returns NaN if + // either operand is NaN .def("create_maximumf", [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { return Value(self.create(lhs, rhs)); @@ -1878,6 +1865,16 @@ void init_triton_ir(py::module &&m) { paddingOption); }); + // Add custom operations. + for (const auto &plugin : mlir::triton::plugin::loadPlugins()) { + for (const auto &op : plugin.listOps()) { + TritonOpBuilderBinding.def( + op.name, [op](TritonOpBuilder &self, std::vector args) { + op.addOp(self, args); + }); + } + } + py::class_(m, "pass_manager", py::module_local()) .def(py::init()) .def("enable_debug", diff --git a/python/src/ir.h b/python/src/ir.h index 499dd9e8a9f6..f8dd9b2941ac 100644 --- a/python/src/ir.h +++ b/python/src/ir.h @@ -1,7 +1,9 @@ #pragma once #include "mlir/IR/Builders.h" +#include "mlir/IR/DialectRegistry.h" #include "triton/Tools/Sys/GetEnv.hpp" #include +#include // A custom op builder that keeps track of the last location class TritonOpBuilder { diff --git a/python/src/passes.cc b/python/src/passes.cc index 8977b59913a4..c29fac2d35c3 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -100,26 +100,15 @@ void init_triton_passes_ttgpuir(py::module &&m) { } void init_plugin_passes(py::module &&m) { - std::string filename = - mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); - if (filename.empty()) - return; - - TritonPlugin TP(filename); - std::vector passNames; - if (auto result = TP.getPassHandles(passNames); !result) - throw TP.err2exp(result.takeError()); - - for (unsigned i 
= 0; i < passNames.size(); ++i) { - const char *passName = passNames.data()[i]; - - m.def(passName, [passName](mlir ::PassManager &pm) { - std::string filename = - mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); - TritonPlugin TP(filename); - if (auto result = TP.addPass(&pm, passName); !result) - throw TP.err2exp(result.takeError()); - }); + for (const auto &plugin : mlir::triton::plugin::loadPlugins()) { + for (const auto &pass : plugin.listPasses()) { + m.def( + pass.name, + [pass](mlir::PassManager &pm, std::vector args) { + pass.addPass(&pm, args); + }, + py::arg("pm"), py::arg("args") = std::vector()); + } } } diff --git a/python/test/unit/plugins/custom_ops.py b/python/test/unit/plugins/custom_ops.py new file mode 100644 index 000000000000..8821c13393f2 --- /dev/null +++ b/python/test/unit/plugins/custom_ops.py @@ -0,0 +1,72 @@ +import torch + +import triton +import triton.language as tl +from triton._C.libtriton import ir +from triton.language.core import builtin +from typing import TypeVar, Type +import builtins +import os +import pathlib +from triton.compiler.code_generator import flatten_values_to_ir + +T = TypeVar('T') +TensorTy = TypeVar('TensorTy') + +triton.language.__all__.append("custom_op") +tensor: Type[TensorTy] = tl.tensor +builder: ir.builder + +TRITON_BUILTIN = "__triton_builtin__" + + +def _unwrap_if_constexpr(o): + if isinstance(o, list): + return [_unwrap_if_constexpr(x) for x in o] + if isinstance(o, builtins.tuple): + return builtins.tuple(_unwrap_if_constexpr(x) for x in o) + if isinstance(o, tuple): + return tuple(_unwrap_if_constexpr(x) for x in o) + return o.value if isinstance(o, tl.constexpr) else o + + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@builtin +def custom_op(x, sanitize_overflow: tl.constexpr = True, _semantic=None): + x = _unwrap_if_constexpr(x) + builder = _semantic.builder + arg_handles = [] + arg_handles.extend(flatten_values_to_ir([x])) + return 
tl.tensor(builder.create_custom_op(arg_handles), x.type) + + +@triton.jit +def add_kernel( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + output = custom_op(x) + tl.store(output_ptr + offsets, output, mask=mask) + + +def test_custom_ops(tmp_path: pathlib.Path): + if os.environ.get('TRITON_EXT_ENABLED', '0') == '0': + return + size = 8 + x = torch.zeros(size, device=DEVICE, dtype=torch.float32) + output_triton = torch.empty_like(x) + n_elements = output_triton.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + h = add_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=32) + + src = h.asm["source"] + assert "arith.addf" in src diff --git a/python/test/unit/plugins/custom_stages.py b/python/test/unit/plugins/custom_stages.py index f3fda7c4db50..de52de3bf289 100644 --- a/python/test/unit/plugins/custom_stages.py +++ b/python/test/unit/plugins/custom_stages.py @@ -20,6 +20,8 @@ def get_hash(): # Keep custom pipeline stages in a seperate file from kernels as any change to the file # will trigger a recompile. 
+num_warps = 4 + def inspect_stages_hook(self=None, stages=None, options=None, language=None, capability=None): # If the hook is called with no arguments we assume were just after the key and hash and don't want to @@ -31,7 +33,10 @@ def make_ttir_wrapper(mod, metadata, opt, capability): mod = self.make_ttir(mod, metadata, opt, capability) pm = ir.pass_manager(mod.context) pm.enable_debug() - passes.plugin.add_plugin(pm) + if num_warps != 4: + passes.plugin.add_plugin(pm, {str(num_warps)}) + else: + passes.plugin.add_plugin(pm) pm.run(mod, 'make_ttir_plugin') return mod diff --git a/python/test/unit/plugins/test_dialect_plugin.py b/python/test/unit/plugins/test_dialect_plugin.py index 55ba6c36950f..5c4289534544 100644 --- a/python/test/unit/plugins/test_dialect_plugin.py +++ b/python/test/unit/plugins/test_dialect_plugin.py @@ -9,7 +9,7 @@ def test_override(tmp_path: pathlib.Path): - if os.environ.get('LLVM_BUILD_SHARED_LIBS', '0') == '0': + if os.environ.get('TRITON_EXT_ENABLED', '0') == '0': return dir_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/python/test/unit/plugins/test_plugin.py b/python/test/unit/plugins/test_plugin.py index 9a895174b1b8..4caaeef5bf2f 100644 --- a/python/test/unit/plugins/test_plugin.py +++ b/python/test/unit/plugins/test_plugin.py @@ -21,8 +21,14 @@ def kernel2(BLOCK_SIZE: tl.constexpr): return +@pytest.mark.parametrize(None, [None]) +@triton.jit +def kernel3(BLOCK_SIZE: tl.constexpr): + return + + def test_op(capfd, device: str): - if os.environ.get('LLVM_BUILD_SHARED_LIBS', '0') == '0': + if os.environ.get('TRITON_EXT_ENABLED', '0') == '0': return size = 98432 @@ -41,3 +47,8 @@ def test_op(capfd, device: str): knobs.runtime.add_stages_inspection_hook = None h = kernel2[grid](BLOCK_SIZE=1024) assert "tt.func public @foo" not in h.asm["ttgir"] + + knobs.runtime.add_stages_inspection_hook = custom_stages.inspect_stages_hook + custom_stages.num_warps = 8 + h = kernel3[grid](BLOCK_SIZE=1024) + assert "tt.func public 
@foo_num_warps_8" in h.asm["ttgir"] diff --git a/scripts/build-llvm-project.sh b/scripts/build-llvm-project.sh index 48e24ab4d8f9..a9921ee1fa8c 100755 --- a/scripts/build-llvm-project.sh +++ b/scripts/build-llvm-project.sh @@ -5,7 +5,6 @@ REPO_ROOT="$(git rev-parse --show-toplevel)" LLVM_TARGETS=${LLVM_TARGETS:-Native;NVPTX;AMDGPU} LLVM_PROJECTS=${LLVM_PROJECTS:-mlir;llvm;lld} LLVM_BUILD_TYPE=${LLVM_BUILD_TYPE:-RelWithDebInfo} -LLVM_BUILD_SHARED_LIBS=${LLVM_BUILD_SHARED_LIBS:-OFF} LLVM_COMMIT_HASH=${LLVM_COMMIT_HASH:-$(cat "$REPO_ROOT/cmake/llvm-hash.txt")} LLVM_PROJECT_PATH=${LLVM_PROJECT_PATH:-"$REPO_ROOT/llvm-project"} LLVM_BUILD_PATH=${LLVM_BUILD_PATH:-"$LLVM_PROJECT_PATH/build"} @@ -22,7 +21,6 @@ if [ -z "$CMAKE_ARGS" ]; then -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DLLVM_ENABLE_LLD=ON - -DBUILD_SHARED_LIBS="$LLVM_BUILD_SHARED_LIBS" -DLLVM_OPTIMIZED_TABLEGEN=ON -DMLIR_ENABLE_BINDINGS_PYTHON=OFF -DLLVM_ENABLE_ZSTD=OFF diff --git a/setup.py b/setup.py index 8cccc31a10fb..0388a93d6d8a 100644 --- a/setup.py +++ b/setup.py @@ -465,7 +465,8 @@ def build_extension(self, ext): "-DPython3_EXECUTABLE:FILEPATH=" + sys.executable, "-DPython3_INCLUDE_DIR=" + python_include_dir, "-DTRITON_CODEGEN_BACKENDS=" + ';'.join([b.name for b in backends if not b.is_external]), "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]), - "-DTRITON_WHEEL_DIR=" + wheeldir + "-DTRITON_WHEEL_DIR=" + wheeldir, + f"-DTRITON_VERSION={TRITON_VERSION}", ] if lit_dir is not None: cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir) @@ -492,10 +493,10 @@ def build_extension(self, ext): "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld", ] - if check_env_flag("LLVM_BUILD_SHARED_LIBS"): - cmake_args += ["-DLLVM_BUILD_SHARED_LIBS=1"] + if check_env_flag("TRITON_EXT_ENABLED"): + cmake_args += ["-DTRITON_EXT_ENABLED=1"] else: - cmake_args += ["-DLLVM_BUILD_SHARED_LIBS=0"] + cmake_args += ["-DTRITON_EXT_ENABLED=0"] # Note that asan doesn't work with binaries that use the 
GPU, so this is # only useful for tools like triton-opt that don't run code on the GPU. diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 793cdd718794..cc5654f7cb8b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -2,7 +2,7 @@ add_subdirectory(lib) llvm_canonicalize_cmake_booleans( MLIR_ENABLE_BINDINGS_PYTHON - LLVM_BUILD_SHARED_LIBS + TRITON_EXT_ENABLED ) configure_lit_site_cfg( diff --git a/test/Plugins/test-dialect-plugin.mlir b/test/Plugins/test-dialect-plugin.mlir index 5751e94b5ccc..35ccbf4ef199 100644 --- a/test/Plugins/test-dialect-plugin.mlir +++ b/test/Plugins/test-dialect-plugin.mlir @@ -1,7 +1,11 @@ -// RUN: TRITON_PASS_PLUGIN_PATH=%shlibdir/../plugins/libMLIRDialectPlugin.so triton-opt -split-input-file --convert-plugin-gpu-to-llvm --convert-triton-gpu-to-llvm %s | \ +// RUN: LD_PRELOAD=%shlibdir/../plugins/libtriton.so \ +// RUN: TRITON_PLUGIN_PATHS=%shlibdir/../plugins/libMLIRDialectPlugin.so \ +// RUN: triton-opt \ +// RUN: -split-input-file --convert-plugin-gpu-to-llvm --convert-triton-gpu-to-llvm %s | \ // RUN: FileCheck %s -// REQUIRES: shared-libs +// REQUIRES: triton-ext-enabled +// XFAIL: * #blocked0 = #ttg.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> module attributes {"ttg.num-warps" = 8 : i32} { diff --git a/test/Plugins/test-plugin.mlir b/test/Plugins/test-plugin.mlir index f16ef0788240..ec09469dc293 100644 --- a/test/Plugins/test-plugin.mlir +++ b/test/Plugins/test-plugin.mlir @@ -1,8 +1,17 @@ -// RUN: TRITON_PASS_PLUGIN_PATH=%shlibdir/../plugins/libTritonPluginsTestLib.so triton-opt -split-input-file -tritongpu-plugin %s | FileCheck %s --check-prefix=CHECK-PLUGIN -// RUN: TRITON_PASS_PLUGIN_PATH=%shlibdir/../plugins/libTritonPluginsTestLib.so triton-opt -split-input-file %s | FileCheck %s -allow-unused-prefixes --check-prefix=CHECK-NOFLAG +// RUN: LD_PRELOAD=%shlibdir/../plugins/libtriton.so \ +// RUN: TRITON_PLUGIN_PATHS=%shlibdir/../plugins/libTritonPluginsTestLib.so \ 
+// RUN: triton-opt \ +// RUN: -split-input-file -tritongpu-plugin %s | FileCheck %s --check-prefix=CHECK-PLUGIN + +// RUN: LD_PRELOAD=%shlibdir/../plugins/libtriton.so \ +// RUN: TRITON_PLUGIN_PATHS=%shlibdir/../plugins/libTritonPluginsTestLib.so \ +// RUN: triton-opt \ +// RUN: -split-input-file %s | FileCheck %s -allow-unused-prefixes --check-prefix=CHECK-NOFLAG + // RUN: triton-opt -split-input-file %s | FileCheck %s -allow-unused-prefixes --check-prefix=CHECK-BASE -// REQUIRES: shared-libs +// REQUIRES: triton-ext-enabled +// XFAIL: * module attributes {"ttg.num-warps" = 4 : i32, "ttg.target" = "cuda:80"} { // CHECK-PLUGIN: func @foo() diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 09044bebe457..1dfc0c5cf570 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -63,9 +63,9 @@ ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'), ] -# Static libraries are not built if LLVM_BUILD_SHARED_LIBS is ON. -if config.build_shared_libs: - config.available_features.add("shared-libs") +# Plugin tests need Triton built with TRITON_EXT_ENABLED (symbols exposed to plugin extensions). +if config.triton_ext_enabled: + config.available_features.add("triton-ext-enabled") llvm_config.add_tool_substitutions(tools, tool_dirs) diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index 59b212a4d227..90e4ae71ef52 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -14,7 +14,7 @@ config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" config.mlir_binary_dir = "@MLIR_BINARY_DIR@" config.python_executable = "@Python3_EXECUTABLE@" config.enable_bindings_python = @MLIR_ENABLE_BINDINGS_PYTHON@ -config.build_shared_libs = @LLVM_BUILD_SHARED_LIBS@ +config.triton_ext_enabled = @TRITON_EXT_ENABLED@ import lit.llvm