Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
option(TRITON_OFFLINE_BUILD "Build without downloading dependencies" OFF)
option(TRITON_EXT_ENABLED "Build with default visibility for Triton+LLVM symbol exposure to plugin extensions" OFF)
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
set(TRITON_VERSION "" CACHE STRING "Triton version string (passed from setup.py)")

set(TRITON_CACHE_PATH "" CACHE PATH "Path to triton cache")
set(TRITON_LLVM_SYSTEM_SUFFIX "" CACHE STRING "Path to LLVM system suffix")
Expand Down Expand Up @@ -96,6 +97,13 @@ if(NOT "${JSON_SYSPATH}" STREQUAL "" AND NOT DEFINED JSON_INCLUDE_DIR)
set(JSON_INCLUDE_DIR "${JSON_SYSPATH}/include")
endif()

# Regenerate Triton Version Header
configure_file(
"${CMAKE_CURRENT_SOURCE_DIR}/include/triton/Version.h.in"
"${CMAKE_CURRENT_BINARY_DIR}/include/triton/Version.h"
@ONLY
)

# Regenerate configure outputs during `cmake --build` when helper inputs change.
set_property(
DIRECTORY
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() {
dialects,
1,
ops,
1};
1,
TRITON_VERSION};
return &info;
}
3 changes: 2 additions & 1 deletion examples/plugins/TritonPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() {
nullptr,
0,
nullptr,
0};
0,
TRITON_VERSION};
return &info;
}
6 changes: 5 additions & 1 deletion include/triton/Tools/PluginUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Tools/Plugins/DialectPlugin.h"
#include "python/src/ir.h"
#include "triton/Version.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/Error.h"
Expand All @@ -30,7 +31,7 @@
///
/// [MLIR_PLUGIN_API_VERSION]:
/// https://github.com/llvm/llvm-project/blob/80d6e0b8/mlir/include/mlir/Tools/Plugins/PassPlugin.h#L32
#define TRITON_PLUGIN_API_VERSION 1
#define TRITON_PLUGIN_API_VERSION 2

/// Use this helper macro on the public entry point for a Triton plugin.
#define TRITON_PLUGIN_API extern "C" __attribute__((visibility("default")))
Expand Down Expand Up @@ -87,6 +88,9 @@ struct PluginInfo {
/// The list of custom ops.
OpInfo *ops;
size_t numOps;

/// Triton Version
const char *tritonVersion;
};

/// A helper structure for storing information about a pass registered by a
Expand Down
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_PREFER_TMEM_16x256_LAYOUT",
"TRITON_ENABLE_EXPERIMENTAL_CONSAN",
"TRITON_PLUGIN_PATHS",
"TRITON_PLUGIN_VERSION_CHECK",
"TRITON_PARTITION_SCHEDULING_ENABLE_DUMP_DOT",
"TRITON_PARTITION_SCHEDULING_DUMP_DATA_ONLY",
"TRITON_PARTITION_SCHEDULING_DUMP_LOOP_ONLY",
Expand Down
6 changes: 6 additions & 0 deletions include/triton/Version.h.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#ifndef TRITON_VERSION_H
#define TRITON_VERSION_H

#define TRITON_VERSION "@TRITON_VERSION@"

#endif // TRITON_VERSION_H
29 changes: 29 additions & 0 deletions lib/Tools/PluginUtils.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
#include "triton/Tools/PluginUtils.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"

#define DEBUG_TYPE "triton-plugins"

using namespace mlir::triton::plugin;

static bool isTritonAndPluginsVersionsMatch(const std::string &pluginVersion) {
// Here, if TRITON_PLUGIN_VERSION_CHECK is unset, then we simply do a default
// version check. However, if it is set then we either do a full (git hash)
// check or we skip all checking.
auto doCheck =
mlir::triton::tools::isEnvValueBool("TRITON_PLUGIN_VERSION_CHECK");

// Skip check when TRITON_PLUGIN_VERSION_CHECK is set false
if (doCheck.has_value() && !doCheck.value())
return true;

// Check full version string when TRITON_PLUGIN_VERSION_CHECK is set true
if (doCheck.has_value() && doCheck.value())
return pluginVersion == TRITON_VERSION;

// Do partial release version check when TRITON_PLUGIN_VERSION_CHECK unset
assert(!doCheck.has_value() && "Expected TRITON_PLUGIN_VERSION_CHECK unset");
return llvm::StringRef(pluginVersion).split('+').first ==
llvm::StringRef(TRITON_VERSION).split('+').first;
}

llvm::Expected<TritonPlugin> TritonPlugin::load(const std::string &filename) {
std::string error;
auto library =
Expand Down Expand Up @@ -35,6 +57,13 @@ llvm::Expected<TritonPlugin> TritonPlugin::load(const std::string &filename) {
Twine(TRITON_PLUGIN_API_VERSION) + ".",
llvm::inconvertibleErrorCode());

if (!isTritonAndPluginsVersionsMatch(plugin.info->tritonVersion))
return llvm::make_error<llvm::StringError>(
Twine("Wrong TRITON version on plugin '") + filename +
"'. Got version " + Twine(plugin.info->tritonVersion) +
", supported version is " + Twine(TRITON_VERSION) + ".",
llvm::inconvertibleErrorCode());

return plugin;
}

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def build_extension(self, ext):
"-DTRITON_PLUGIN_DIRS=" + ';'.join([b.src_dir for b in backends if b.is_external]),
"-DTRITON_WHEEL_DIR=" + wheeldir,
f"-DTRITON_CACHE_PATH={get_triton_cache_path()}",
f"-DTRITON_VERSION={TRITON_VERSION}",
]
if lit_dir is not None:
cmake_args.append("-DLLVM_EXTERNAL_LIT=" + lit_dir)
Expand Down
Loading