Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
option(TRITON_BUILD_PROTON "Build the Triton Proton profiler" ON)
option(TRITON_BUILD_UT "Build C++ Triton Unit Tests" ON)
option(TRITON_BUILD_WITH_CCACHE "Build with ccache (if available)" ON)
option(LLVM_BUILD_SHARED_LIBS
"Build all libraries as shared libraries instead of static" OFF)
option(TRITON_OFFLINE_BUILD "Build without downloading dependencies" OFF)
option(TRITON_EXT_ENABLED "Build with default visibility for Triton+LLVM symbol exposure to plugin extensions" OFF)
set(TRITON_CODEGEN_BACKENDS "" CACHE STRING "Enable different codegen backends")
set(TRITON_VERSION "" CACHE STRING "Triton version string (passed from setup.py)")

if(TRITON_BUILD_WITH_CCACHE)
find_program(CCACHE_PROGRAM ccache)
Expand Down Expand Up @@ -62,7 +63,6 @@ else()
set(CMAKE_EXE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_MODULE_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(CMAKE_SHARED_LINKER_FLAGS_TRITONRELBUILDWITHASSERTS "/debug:fastlink /INCREMENTAL")
set(LLVM_BUILD_SHARED_LIBS "0")
endif()

# Default build type
Expand All @@ -86,6 +86,12 @@ else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__STDC_FORMAT_MACROS")
endif()

# Regenerate Triton Version Header
# Produces include/triton/Version.h in the build tree from the Version.h.in
# template in the source tree. @ONLY restricts substitution to @VAR@
# placeholders, so any ${...} text in the template is copied through unchanged.
configure_file(
"${CMAKE_CURRENT_SOURCE_DIR}/include/triton/Version.h.in"
"${CMAKE_CURRENT_BINARY_DIR}/include/triton/Version.h"
@ONLY
)

# #########
# LLVM
Expand Down Expand Up @@ -144,7 +150,17 @@ endfunction()
# Disable warnings that show up in external code (gtest;pybind11)
if(NOT MSVC)
set(TRITON_DISABLE_EH_RTTI_FLAGS "$<$<COMPILE_LANGUAGE:CXX>:-fno-exceptions;-fno-rtti>")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")

# Symbol visibility of the C++ code is selected by TRITON_EXT_ENABLED.
if(NOT TRITON_EXT_ENABLED)
# Default build: compile with hidden symbol visibility.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
else()
# Extension build: -fvisibility=hidden is not added, and TRITON_EXT_ENABLED=1
# is defined so the plugin loader can tell that this libtriton exposes its
# symbols. The loader should refuse to load plugins otherwise, because
# loading them into a libtriton built with hidden visibility would crash on
# missing symbols.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTRITON_EXT_ENABLED=1")
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default ")

else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4244 /wd4624 /wd4715 /wd4530")
endif()
Expand Down Expand Up @@ -293,6 +309,13 @@ if(TRITON_BUILD_PYTHON_MODULE)

# Link triton with its dependencies
target_link_libraries(triton PRIVATE ${TRITON_LIBRARIES})


# Do not propagate libraries that libtriton depends on. This ensures that
# targets that link against libtriton do not accidentally link in their own
# copies of core Triton code and LLVM.
set_target_properties(triton PROPERTIES INTERFACE_LINK_LIBRARIES "")

if(WIN32)
target_link_libraries(triton PRIVATE ${CMAKE_DL_LIBS})
set_target_properties(triton PROPERTIES SUFFIX ".pyd")
Expand All @@ -315,9 +338,11 @@ if(TRITON_BUILD_PYTHON_MODULE)
"${TRITON_WHEEL_DIR}/FileCheck"
COPYONLY)

# Build plugins when building libtriton since they depend on libtriton.
add_subdirectory(examples/plugins)
endif()

if (UNIX AND NOT APPLE)
if (UNIX AND NOT APPLE AND NOT TRITON_EXT_ENABLED)
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL")
endif()

Expand All @@ -344,7 +369,6 @@ find_package(Threads REQUIRED)
add_subdirectory(third_party/f2reduce)
add_subdirectory(bin)
add_subdirectory(test)
add_subdirectory(examples)

if(TRITON_BUILD_UT)
add_subdirectory(unittest)
Expand Down
6 changes: 4 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ test-unit: all
$(PYTEST) --tb=short -vs python/examples/gluon/01-attention-forward.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \
TRITON_PLUGIN_PATHS=python/triton/plugins/libTritonPluginsTestLib.so \
$(PYTEST) -vvv python/test/unit/plugins/test_plugin.py
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \
TRITON_PLUGIN_PATHS=python/triton/plugins/libMLIRDialectPlugin.so \
$(PYTEST) -s -vvv python/test/unit/plugins/test_dialect_plugin.py
$(PYTEST) --tb=short -s -n $(NUM_PROCS) python/test/gluon
TRITON_PLUGIN_PATHS=python/triton/plugins/libMLIRDialectPlugin.so \
$(PYTEST) -s -vvv python/test/unit/plugins/custom_ops.py

.PHONY: test-gluon
test-gluon: all
Expand Down
30 changes: 4 additions & 26 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,32 +146,10 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
mlir::triton::proton::gpu::registerAddSchedBarriersPass();

// Plugin passes
if (std::string filename =
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
!filename.empty()) {

TritonPlugin TP(filename);
std::vector<const char *> passNames;
if (auto result = TP.getPassHandles(passNames); !result)
llvm::report_fatal_error(result.takeError());

for (const char *passName : passNames)
if (auto result = TP.registerPass(passName); !result)
llvm::report_fatal_error(result.takeError());

std::vector<const char *> dialectNames;
if (auto result = TP.getDialectHandles(dialectNames); !result)
llvm::report_fatal_error(result.takeError());

for (unsigned i = 0; i < dialectNames.size(); ++i) {
const char *dialectName = dialectNames.data()[i];
auto result = TP.getDialectPluginInfo(dialectName);
if (!result)
llvm::report_fatal_error(result.takeError());
::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result;
dialectPluginInfo.registerDialectRegistryCallbacks(&registry);
}
// Register plugin passes and dialects.
for (const auto &plugin : mlir::triton::plugin::loadPlugins()) {
plugin.registerPasses();
plugin.registerDialects(registry);
}

registry.insert<
Expand Down
1 change: 0 additions & 1 deletion examples/CMakeLists.txt

This file was deleted.

2 changes: 1 addition & 1 deletion examples/plugins/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ foreach( plugin ${TRITON_PLUGIN_PASSES} )
TritonCanonicalizeIncGen
TritonPluginsIncGen
)
target_link_libraries(${plugin} PRIVATE MLIRPass)
target_link_libraries(${plugin} PRIVATE triton)

# CMAKE_LIBRARY_OUTPUT_DIRECTORY is only set during the Python
# build. It is empty if building directly from the root
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ add_mlir_dialect_library(MLIRDialectPlugin
MLIRDialectPluginPassesIncGen

LINK_LIBS PUBLIC
MLIRPass
LLVMSupport
MLIRSupport
TritonNVIDIAGPUToLLVM
triton

"$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "DialectPlugin/DialectPluginTypes.h"

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::plugin;

#include "DialectPlugin/DialectPluginOpsDialect.cpp.inc"
Expand All @@ -25,13 +26,18 @@ void DialectPluginDialect::initialize() {

#include "DialectPlugin/DialectPluginDialect.h"
#include "DialectPlugin/DialectPluginPasses.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Tools/Plugins/PassPlugin.h"
#include "triton/Tools/PluginUtils.h"
#include "llvm/Config/llvm-config.h"

using namespace mlir;
static const char *PLUGIN_NAME = "DialectPlugin";
static const char *DIALECT_NAME = "DialectPlugin";
static const char *PASS_NAME = "plugingpu_conversion";
static const char *VERSION = "0.1.0";

static void addTritonPluginPass(mlir::PassManager *pm) {
static void addTritonPluginPass(mlir::PassManager *pm,
const std::vector<std::string> &args) {
pm->addPass(mlir::triton::plugin::createConvertPluginGPUToLLVMPass());
}

Expand All @@ -41,65 +47,38 @@ static void registerTritonPluginPass() {
});
}

static const char *ADD_PLUGIN_PASS_NAME = "plugingpu_conversion";
static std::unordered_map<std::string, void (*)(mlir::PassManager *)> passMap =
{{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}};
static std::unordered_map<std::string, void (*)()> registryMap = {
{ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}};
static std::vector<const char *> passNamesTable = {ADD_PLUGIN_PASS_NAME};

// Key APIs:

TRITON_PLUGIN_API
tritonAddPluginPass(mlir::PassManager *pm, const char *passName) {
std::string passNameStr(passName);
if (passMap.find(passNameStr) == passMap.end())
return TP_GENERIC_FAILURE;
passMap[passNameStr](pm);
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonRegisterPluginPass(const char *passName) {
std::string passNameStr(passName);
if (registryMap.find(passNameStr) == registryMap.end())
return TP_GENERIC_FAILURE;
registryMap[passNameStr]();
return TP_SUCCESS;
// Dialect registration callback for this plugin: inserts DialectPluginDialect
// into the given registry and then registers the plugin's passes via
// registerpluginPasses().
static void registerTritonPluginDialect(DialectRegistry *registry) {
registry->insert<mlir::triton::plugin::DialectPluginDialect>();
mlir::triton::plugin::registerpluginPasses();
}

TRITON_PLUGIN_API
tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) {
if (!passCount)
return TP_GENERIC_FAILURE;
auto count = passMap.size();
assert(count == registryMap.size() &&
"Expected register and add passes map size to match");
*passCount = count;
if (!passNames)
return TP_SUCCESS;
unsigned i = 0;
for (auto passName : passNamesTable) {
passNames[i] = passName;
}
return TP_SUCCESS;
}
static void addTritonPluginCustomOp(TritonOpBuilder &self,
std::vector<mlir::Value> &operands) {
::mlir::Value &dst = operands[0];
::mlir::Value &src = operands[1];

TRITON_PLUGIN_API
tritonEnumeratePluginDialects(uint32_t *dialectCount,
const char **dialectNames) {
*dialectCount = 1;
if (!dialectNames)
return TP_SUCCESS;
dialectNames[0] = "DialectPlugin";
return TP_SUCCESS;
dst = self.create<arith::AddFOp>(src, src);
operands[0] = dst;
}

TRITON_PLUGIN_API_TYPE(DialectPluginLibraryInfo)
tritonGetDialectPluginInfo(const char *name) {
return {MLIR_PLUGIN_API_VERSION, "DialectPlugin", LLVM_VERSION_STRING,
[](DialectRegistry *registry) {
registry->insert<mlir::triton::plugin::DialectPluginDialect>();
mlir::triton::plugin::registerpluginPasses();
}};
// Plugin entry point: returns a pointer to a function-local static PluginInfo
// describing this plugin. The descriptor carries one pass, one dialect and one
// custom op, each stored in a single-element static array.
TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() {
// Pass entry: name, version, and the add-to-pipeline / register callbacks.
static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass,
registerTritonPluginPass};
static plugin::PassInfo passes[] = {pass};
// Dialect entry: name, version, and the registry callback.
static plugin::DialectInfo dialect = {DIALECT_NAME, VERSION,
registerTritonPluginDialect};
static plugin::DialectInfo dialects[] = {dialect};
// Custom-op entry: the op is published under the name "create_custom_op".
static plugin::OpInfo op = {"create_custom_op", addTritonPluginCustomOp};
static plugin::OpInfo ops[] = {op};
// NOTE(review): the literal 1 after each array is presumably that array's
// element count -- confirm against the PluginInfo declaration, and keep the
// counts in sync if entries are added.
static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION,
PLUGIN_NAME,
VERSION,
passes,
1,
dialects,
1,
ops,
1,
TRITON_VERSION};
return &info;
}
7 changes: 7 additions & 0 deletions examples/plugins/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,12 @@ include "mlir/Pass/PassBase.td"

// Example plugin pass, registered under the argument "tritongpu-plugin" and
// anchored on mlir::ModuleOp.
def TritonGPUMLIRPlugin : Pass<"tritongpu-plugin", "mlir::ModuleOp"> {
let summary = "Triton MLIR Plugin Pass";

// Single pass option: C++ member num_warps, command-line name "num-warps",
// type int32_t, default 4.
let options = [
Option<"num_warps", "num-warps",
"int32_t", /*default*/"4",
"Number of warps">,
];

}
#endif
12 changes: 6 additions & 6 deletions examples/plugins/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ long as the libtriton.so is linked to the plugin and the Triton include passes a

## Example 1: Developing a custom pass and running triton-opt to inspect the modified IR
``` bash
export LLVM_BUILD_SHARED_LIBS=1; make dev-install-llvm
TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so triton-opt -tritongpu-plugin test/Plugins/test-plugin.mlir
export TRITON_EXT_ENABLED=1; make dev-install-llvm
TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so triton-opt -tritongpu-plugin test/Plugins/test-plugin.mlir
```
``` MLIR
module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} {
Expand Down Expand Up @@ -85,7 +85,7 @@ Running same code but loading the plugin library also produces the same results
pass manager it is not inserted into the compiler pass pipeline:

``` bash
TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py
TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py
```

``` MLIR
Expand Down Expand Up @@ -151,7 +151,7 @@ if __name__ == '__main__':
h = kernel[grid](BLOCK_SIZE=1024)
print(h.asm["ttgir"])

if "TRITON_PASS_PLUGIN_PATH" in os.environ:
if "TRITON_PLUGIN_PATHS" in os.environ:
knobs.runtime.add_stages_inspection_hook = inspect_stages_hook
h = kernel[grid](BLOCK_SIZE=1024)
print(h.asm["ttgir"])
Expand All @@ -163,7 +163,7 @@ if __name__ == '__main__':
```

``` bash
TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py
TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py
```

The output shows that the pass ran and modified the kernel name, but only after the hook was set. Kernels compiled before the hook is set or after it is unset are left unchanged.
Expand Down Expand Up @@ -330,7 +330,7 @@ if __name__ == '__main__':
if "add_loop_unroll" in line:
outfile.write("\n passes.plugin.add_plugin(pm)\n")
outfile.write(line)
if "TRITON_PASS_PLUGIN_PATH" in os.environ:
if "TRITON_PLUGIN_PATHS" in os.environ:
knobs.runtime.add_stages_inspection_hook = override_stages
h = kernel2[grid](BLOCK_SIZE=1024)
print(h.asm["ttgir"])
Expand Down
Loading
Loading