Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ test-unit: all
$(PYTEST) python/tutorials/06-fused-attention.py
TRITON_ALWAYS_COMPILE=1 TRITON_DISABLE_LINE_INFO=0 LLVM_PASS_PLUGIN_PATH=python/triton/instrumentation/libGPUInstrumentationTestLib.so \
$(PYTEST) --capture=tee-sys -rfs -vvv python/test/unit/instrumentation/test_gpuhello.py
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libTritonPluginsTestLib.so \
TRITON_PLUGIN_PATHS=python/triton/plugins/libTritonPluginsTestLib.so \
$(PYTEST) -vvv python/test/unit/plugins/test_plugin.py
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \
TRITON_PLUGIN_PATHS=python/triton/plugins/libMLIRDialectPlugin.so \
$(PYTEST) -vvv python/test/unit/plugins/test_dialect_plugin.py
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \
TRITON_PLUGIN_PATHS=python/triton/plugins/libMLIRDialectPlugin.so \
$(PYTEST) -s -vvv python/test/unit/plugins/custom_ops.py

.PHONY: test-gluon
Expand Down
30 changes: 4 additions & 26 deletions bin/RegisterTritonDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,32 +149,10 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
mlir::triton::proton::gpu::registerAddSchedBarriersPass();

// Plugin passes
if (std::string filename =
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
!filename.empty()) {

TritonPlugin TP(filename);
std::vector<const char *> passNames;
if (auto result = TP.getPassHandles(passNames); !result)
llvm::report_fatal_error(result.takeError());

for (const char *passName : passNames)
if (auto result = TP.registerPass(passName); !result)
llvm::report_fatal_error(result.takeError());

std::vector<const char *> dialectNames;
if (auto result = TP.getDialectHandles(dialectNames); !result)
llvm::report_fatal_error(result.takeError());

for (unsigned i = 0; i < dialectNames.size(); ++i) {
const char *dialectName = dialectNames.data()[i];
auto result = TP.getDialectPluginInfo(dialectName);
if (!result)
llvm::report_fatal_error(result.takeError());
::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result;
dialectPluginInfo.registerDialectRegistryCallbacks(&registry);
}
// Register plugin passes and dialects.
for (const auto &plugin : mlir::triton::plugin::loadPlugins()) {
plugin.registerPasses();
plugin.registerDialects(registry);
}

registry.insert<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "DialectPlugin/DialectPluginTypes.h"

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::plugin;

#include "DialectPlugin/DialectPluginOpsDialect.cpp.inc"
Expand Down Expand Up @@ -30,9 +31,13 @@ void DialectPluginDialect::initialize() {
#include "triton/Tools/PluginUtils.h"
#include "llvm/Config/llvm-config.h"

using namespace mlir;
static const char *PLUGIN_NAME = "DialectPlugin";
static const char *DIALECT_NAME = "DialectPlugin";
static const char *PASS_NAME = "plugingpu_conversion";
static const char *VERSION = "0.1.0";

static void addTritonPluginPass(mlir::PassManager *pm) {
static void addTritonPluginPass(mlir::PassManager *pm,
const std::vector<std::string> &args) {
pm->addPass(mlir::triton::plugin::createConvertPluginGPUToLLVMPass());
}

Expand All @@ -42,88 +47,37 @@ static void registerTritonPluginPass() {
});
}

static const char *ADD_PLUGIN_PASS_NAME = "plugingpu_conversion";
static std::unordered_map<std::string, void (*)(mlir::PassManager *)> passMap =
{{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}};
static std::unordered_map<std::string, void (*)()> registryMap = {
{ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}};
static std::vector<const char *> passNamesTable = {ADD_PLUGIN_PASS_NAME};

// Key APIs:

TRITON_PLUGIN_API
tritonAddPluginPass(mlir::PassManager *pm, const char *passName,
const std::vector<std::string> &args) {
std::string passNameStr(passName);
if (passMap.find(passNameStr) == passMap.end())
return TP_GENERIC_FAILURE;
passMap[passNameStr](pm);
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonRegisterPluginPass(const char *passName) {
std::string passNameStr(passName);
if (registryMap.find(passNameStr) == registryMap.end())
return TP_GENERIC_FAILURE;
registryMap[passNameStr]();
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) {
if (!passCount)
return TP_GENERIC_FAILURE;
auto count = passMap.size();
assert(count == registryMap.size() &&
"Expected register and add passes map size to match");
*passCount = count;
if (!passNames)
return TP_SUCCESS;
unsigned i = 0;
for (auto passName : passNamesTable) {
passNames[i] = passName;
}
return TP_SUCCESS;
static void registerTritonPluginDialect(DialectRegistry *registry) {
registry->insert<mlir::triton::plugin::DialectPluginDialect>();
mlir::triton::plugin::registerpluginPasses();
}

TRITON_PLUGIN_API
tritonEnumeratePluginDialects(uint32_t *dialectCount,
const char **dialectNames) {
*dialectCount = 1;
if (!dialectNames)
return TP_SUCCESS;
dialectNames[0] = "DialectPlugin";
return TP_SUCCESS;
}

TRITON_PLUGIN_API_TYPE(DialectPluginLibraryInfo)
tritonGetDialectPluginInfo(const char *name) {
return {MLIR_PLUGIN_API_VERSION, "DialectPlugin", LLVM_VERSION_STRING,
[](DialectRegistry *registry) {
registry->insert<mlir::triton::plugin::DialectPluginDialect>();
mlir::triton::plugin::registerpluginPasses();
}};
}

TRITON_PLUGIN_API
tritonEnumeratePluginCustomOps(uint32_t *count, const char **handles) {
if (!count)
return TP_GENERIC_FAILURE;
*count = 1;
if (!handles)
return TP_SUCCESS;
handles[0] = "create_custom_op";
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonAddPluginCustomOp(const char *handle, TritonOpBuilder &self,
std::vector<mlir::Value> &operands) {
static void addTritonPluginCustomOp(TritonOpBuilder &self,
std::vector<mlir::Value> &operands) {
::mlir::Value &dst = operands[0];
::mlir::Value &src = operands[1];

dst = self.create<arith::AddFOp>(src, src);
operands[0] = dst;
return TP_SUCCESS;
}

TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() {
static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass,
registerTritonPluginPass};
static plugin::PassInfo passes[] = {pass};
static plugin::DialectInfo dialect = {DIALECT_NAME, VERSION,
registerTritonPluginDialect};
static plugin::DialectInfo dialects[] = {dialect};
static plugin::OpInfo op = {"create_custom_op", addTritonPluginCustomOp};
static plugin::OpInfo ops[] = {op};
static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION,
PLUGIN_NAME,
VERSION,
passes,
1,
dialects,
1,
ops,
1};
return &info;
}
10 changes: 5 additions & 5 deletions examples/plugins/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ long as the libtriton.so is linked to the plugin and the Triton include passes a
## Example 1: Developing a custom pass and running triton-opt to inspect the modified IR
``` bash
export TRITON_EXT_ENABLED=1; make dev-install-llvm
TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so triton-opt -tritongpu-plugin test/Plugins/test-plugin.mlir
TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so triton-opt -tritongpu-plugin test/Plugins/test-plugin.mlir
```
``` MLIR
module attributes {"ttg.num-warps" = 4 : i32, ttg.target = "cuda:80"} {
Expand Down Expand Up @@ -85,7 +85,7 @@ Running same code but loading the plugin library also produces the same results
pass manager it is not inserted into the compiler pass pipeline:

``` bash
TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py
TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py
```

``` MLIR
Expand Down Expand Up @@ -151,7 +151,7 @@ if __name__ == '__main__':
h = kernel[grid](BLOCK_SIZE=1024)
print(h.asm["ttgir"])

if "TRITON_PASS_PLUGIN_PATH" in os.environ:
if "TRITON_PLUGIN_PATHS" in os.environ:
knobs.runtime.add_stages_inspection_hook = inspect_stages_hook
h = kernel[grid](BLOCK_SIZE=1024)
print(h.asm["ttgir"])
Expand All @@ -163,7 +163,7 @@ if __name__ == '__main__':
```

``` bash
TRITON_PASS_PLUGIN_PATH=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py
TRITON_PLUGIN_PATHS=/home/triton/python/triton/plugins/libTritonPluginsTestLib.so python test.py
```

Shows the pass ran and modified the kernel name but only after the hook is set. Any kernels before the hook or after the hook is unset are left unchanged.
Expand Down Expand Up @@ -330,7 +330,7 @@ if __name__ == '__main__':
if "add_loop_unroll" in line:
outfile.write("\n passes.plugin.add_plugin(pm)\n")
outfile.write(line)
if "TRITON_PASS_PLUGIN_PATH" in os.environ:
if "TRITON_PLUGIN_PATHS" in os.environ:
knobs.runtime.add_stages_inspection_hook = override_stages
h = kernel2[grid](BLOCK_SIZE=1024)
print(h.asm["ttgir"])
Expand Down
63 changes: 20 additions & 43 deletions examples/plugins/TritonPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,47 +57,24 @@ static void registerTritonPluginPass() {
});
}

static const char *ADD_PLUGIN_PASS_NAME = "add_plugin";
static std::unordered_map<std::string, decltype(&addTritonPluginPass)> passMap =
{{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}};
static std::unordered_map<std::string, decltype(&registerTritonPluginPass)>
registryMap = {{ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}};
static std::vector<const char *> passNamesTable = {ADD_PLUGIN_PASS_NAME};

// Key APIs:

TRITON_PLUGIN_API
tritonAddPluginPass(mlir::PassManager *pm, const char *passName,
const std::vector<std::string> &args) {
std::string passNameStr(passName);
if (passMap.find(passNameStr) == passMap.end())
return TP_GENERIC_FAILURE;
passMap[passNameStr](pm, args);
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonRegisterPluginPass(const char *passName) {
std::string passNameStr(passName);
if (registryMap.find(passNameStr) == registryMap.end())
return TP_GENERIC_FAILURE;
registryMap[passNameStr]();
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) {
if (!passCount)
return TP_GENERIC_FAILURE;
auto count = passMap.size();
assert(count == registryMap.size() &&
"Expected register and add passes map size to match");
*passCount = count;
if (!passNames)
return TP_SUCCESS;
unsigned i = 0;
for (auto passName : passNamesTable) {
passNames[i++] = passName;
}
return TP_SUCCESS;
static const char *PLUGIN_NAME = "TritonPlugin";
static const char *PASS_NAME = "add_plugin";
static const char *VERSION = "0.1.0";

using namespace mlir::triton;

TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() {
static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass,
registerTritonPluginPass};
static plugin::PassInfo passes[] = {pass};
static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION,
PLUGIN_NAME,
VERSION,
passes,
1,
nullptr,
0,
nullptr,
0};
return &info;
}
Loading
Loading