diff --git a/CMakeLists.txt b/CMakeLists.txt index 538f3332c4a7..f731916d8254 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,6 +24,7 @@ option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON) option(TRITON_OFFLINE_BUILD "Build without downloading dependencies" OFF) option(TRITON_EXT_ENABLED "Build with default visibility for Triton+LLVM symbol exposure to plugin extensions" OFF) set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends") +set(TRITON_VERSION "" CACHE STRING "Triton version string (passed from setup.py)") set(TRITON_CACHE_PATH "" CACHE PATH "Path to triton cache") set(TRITON_LLVM_SYSTEM_SUFFIX "" CACHE STRING "Path to LLVM system suffix") @@ -96,6 +97,13 @@ if(NOT "${JSON_SYSPATH}" STREQUAL "" AND NOT DEFINED JSON_INCLUDE_DIR) set(JSON_INCLUDE_DIR "${JSON_SYSPATH}/include") endif() +# Regenerate Triton Version Header +configure_file( + "${CMAKE_CURRENT_SOURCE_DIR}/include/triton/Version.h.in" + "${CMAKE_CURRENT_BINARY_DIR}/include/triton/Version.h" + @ONLY +) + # Regenerate configure outputs during `cmake --build` when helper inputs change. set_property( DIRECTORY diff --git a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp index 0349a8a9a1a5..bd4687fb91c7 100644 --- a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp +++ b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp @@ -78,6 +78,7 @@ TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() { dialects, 1, ops, - 1}; + 1, + TRITON_VERSION}; return &info; } diff --git a/examples/plugins/TritonPlugin.cpp b/examples/plugins/TritonPlugin.cpp index 6c18f0c3a073..33a252f94b81 100644 --- a/examples/plugins/TritonPlugin.cpp +++ b/examples/plugins/TritonPlugin.cpp @@ -75,6 +75,7 @@ TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() { nullptr, 0, nullptr, - 0}; + 0, + TRITON_VERSION}; return &info; } diff --git a/include/triton/Tools/PluginUtils.h b/include/triton/Tools/PluginUtils.h index fd1c149942f7..6cde31819cdc 100644 --- a/include/triton/Tools/PluginUtils.h +++ b/include/triton/Tools/PluginUtils.h @@ -16,6 +16,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Tools/Plugins/DialectPlugin.h" #include "python/src/ir.h" +#include "triton/Version.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/DynamicLibrary.h" #include "llvm/Support/Error.h" @@ -30,7 +31,7 @@ /// /// [MLIR_PLUGIN_API_VERSION]: /// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/PassPlugin.h#L32 -#define TRITON_PLUGIN_API_VERSION 1 +#define TRITON_PLUGIN_API_VERSION 2 /// Use this helper macro on the public entry point for a Triton plugin. #define TRITON_PLUGIN_API extern "C" __attribute__((visibility("default"))) @@ -87,6 +88,9 @@ struct PluginInfo { /// The list of custom ops. OpInfo *ops; size_t numOps; + + /// Triton Version + const char *tritonVersion; }; /// A helper structure for storing information about a pass registered by a diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 4cc051c6219f..8c0d8611cf2f 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -47,6 +47,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_PREFER_TMEM_16x256_LAYOUT", "TRITON_ENABLE_EXPERIMENTAL_CONSAN", "TRITON_PLUGIN_PATHS", + "TRITON_PLUGIN_VERSION_CHECK", "TRITON_PARTITION_SCHEDULING_ENABLE_DUMP_DOT", "TRITON_PARTITION_SCHEDULING_DUMP_DATA_ONLY", "TRITON_PARTITION_SCHEDULING_DUMP_LOOP_ONLY", diff --git a/include/triton/Version.h.in b/include/triton/Version.h.in new file mode 100644 index 000000000000..66219ab3a6ea --- /dev/null +++ b/include/triton/Version.h.in @@ -0,0 +1,6 @@ +#ifndef TRITON_VERSION_H +#define TRITON_VERSION_H + +#define TRITON_VERSION "@TRITON_VERSION@" + +#endif // TRITON_VERSION_H diff --git a/lib/Tools/PluginUtils.cpp b/lib/Tools/PluginUtils.cpp index 6311907cde91..05bbcc89fb09 100644 --- a/lib/Tools/PluginUtils.cpp +++ b/lib/Tools/PluginUtils.cpp @@ -1,4 +1,5 @@ #include "triton/Tools/PluginUtils.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/Support/Debug.h" #include "llvm/Support/Error.h" @@ -6,6 +7,27 @@ using namespace mlir::triton::plugin; +static bool isTritonAndPluginsVersionsMatch(const std::string &pluginVersion) { + // Here, if TRITON_PLUGIN_VERSION_CHECK is unset, then we simply do a default + // version check. However, if it is set then we either do a full (git hash) + // check or we skip all checking. + auto doCheck = + mlir::triton::tools::isEnvValueBool("TRITON_PLUGIN_VERSION_CHECK"); + + // Skip check when TRITON_PLUGIN_VERSION_CHECK is set false + if (doCheck.has_value() && !doCheck.value()) + return true; + + // Check full version string when TRITON_PLUGIN_VERSION_CHECK is set true + if (doCheck.has_value() && doCheck.value()) + return pluginVersion == TRITON_VERSION; + + // Do partial release version check when TRITON_PLUGIN_VERSION_CHECK unset + assert(!doCheck.has_value() && "Expected TRITON_PLUGIN_VERSION_CHECK unset"); + return llvm::StringRef(pluginVersion).split('+').first == + llvm::StringRef(TRITON_VERSION).split('+').first; +} + llvm::Expected TritonPlugin::load(const std::string &filename) { std::string error; auto library = @@ -35,6 +57,13 @@ llvm::Expected TritonPlugin::load(const std::string &filename) { Twine(TRITON_PLUGIN_API_VERSION) + ".", llvm::inconvertibleErrorCode()); + if (!isTritonAndPluginsVersionsMatch(plugin.info->tritonVersion)) + return llvm::make_error( + Twine("Wrong TRITON version on plugin '") + filename + + "'. Got version " + Twine(plugin.info->tritonVersion) + + ", supported version is " + Twine(TRITON_VERSION) + ".", + llvm::inconvertibleErrorCode()); + return plugin; } diff --git a/setup.py b/setup.py index ed1900db4294..9e5f0f8c219a 100644 --- a/setup.py +++ b/setup.py @@ -285,6 +285,7 @@ def build_extension(self, ext): "-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]), "-DTRITON_WHEEL_DIR=" + wheeldir, f"-DTRITON_CACHE_PATH={get_triton_cache_path()}", + f"-DTRITON_VERSION={TRITON_VERSION}", ] if lit_dir is not None: cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)