Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ static std::vector<const char *> passNamesTable = {ADD_PLUGIN_PASS_NAME};
// Key APIs:

TRITON_PLUGIN_API
tritonAddPluginPass(mlir::PassManager *pm, const char *passName) {
tritonAddPluginPass(mlir::PassManager *pm, const char *passName,
const std::vector<std::string> &args) {
std::string passNameStr(passName);
if (passMap.find(passNameStr) == passMap.end())
return TP_GENERIC_FAILURE;
Expand Down
7 changes: 7 additions & 0 deletions examples/plugins/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,12 @@ include "mlir/Pass/PassBase.td"

def TritonGPUMLIRPlugin : Pass<"tritongpu-plugin", "mlir::ModuleOp"> {
let summary = "Triton MLIR Plugin Pass";

let options = [
Option<"num_warps", "num-warps",
"int32_t", /*default*/"4",
"Number of warps">,
];

}
#endif
35 changes: 27 additions & 8 deletions examples/plugins/TritonPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,27 @@ namespace mlir {
namespace triton {
namespace plugin {

#define GEN_PASS_DECL_TRITONGPUMLIRPLUGIN
#define GEN_PASS_DEF_TRITONGPUMLIRPLUGIN
#include "Passes.h.inc"

struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase<MLIRPluginPass> {
using TritonGPUMLIRPluginBase::TritonGPUMLIRPluginBase;

void runOnOperation() override {

MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

std::string name;
llvm::raw_string_ostream sstr(name);
sstr << "foo";
if (num_warps != 4)
sstr << "_num_warps_" << num_warps;

mod.walk([&](FunctionOpInterface funcOp) {
StringAttr funcNameAttr = funcOp.getNameAttr();
funcOp.setName("foo");
funcOp.setName(name);
});
}
};
Expand All @@ -29,8 +39,16 @@ struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase<MLIRPluginPass> {
} // namespace triton
} // namespace mlir

static void addTritonPluginPass(mlir::PassManager *pm) {
pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin());
static void addTritonPluginPass(mlir::PassManager *pm,
const std::vector<std::string> &args) {
if (args.empty()) {
pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin());
return;
}

mlir::triton::plugin::TritonGPUMLIRPluginOptions opts;
opts.num_warps = std::atoi(args[0].c_str());
pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin((opts)));
}

static void registerTritonPluginPass() {
Expand All @@ -40,20 +58,21 @@ static void registerTritonPluginPass() {
}

static const char *ADD_PLUGIN_PASS_NAME = "add_plugin";
static std::unordered_map<std::string, void (*)(mlir::PassManager *)> passMap =
static std::unordered_map<std::string, decltype(&addTritonPluginPass)> passMap =
{{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}};
static std::unordered_map<std::string, void (*)()> registryMap = {
{ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}};
static std::unordered_map<std::string, decltype(&registerTritonPluginPass)>
registryMap = {{ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}};
static std::vector<const char *> passNamesTable = {ADD_PLUGIN_PASS_NAME};

// Key APIs:

TRITON_PLUGIN_API
tritonAddPluginPass(mlir::PassManager *pm, const char *passName) {
tritonAddPluginPass(mlir::PassManager *pm, const char *passName,
const std::vector<std::string> &args) {
std::string passNameStr(passName);
if (passMap.find(passNameStr) == passMap.end())
return TP_GENERIC_FAILURE;
passMap[passNameStr](pm);
passMap[passNameStr](pm, args);
return TP_SUCCESS;
}

Expand Down
13 changes: 7 additions & 6 deletions include/triton/Tools/PluginUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ struct TritonPlugin {
using EnumeratePyBindHandlesCType = TritonPluginResult (*)(uint32_t *,
const char **);

using AddPassType =
std::function<TritonPluginResult(mlir::PassManager *, const char *)>;
using AddPassCType = TritonPluginResult (*)(mlir::PassManager *,
const char *);
using AddPassType = std::function<TritonPluginResult(
mlir::PassManager *, const char *, const std::vector<std::string> &)>;
using AddPassCType = TritonPluginResult (*)(mlir::PassManager *, const char *,
const std::vector<std::string> &);

using RegisterPassType = std::function<TritonPluginResult(const char *)>;
using RegisterPassCType = TritonPluginResult (*)(const char *);
Expand Down Expand Up @@ -78,8 +78,9 @@ struct TritonPlugin {
llvm::Expected<TritonPluginResult>
getDialectHandles(std::vector<const char *> &handles);

llvm::Expected<TritonPluginResult> addPass(mlir::PassManager *pm,
const char *passHandle);
llvm::Expected<TritonPluginResult>
addPass(mlir::PassManager *pm, const char *passHandle,
const std::vector<std::string> &args);

llvm::Expected<TritonPluginResult> registerPass(const char *passHandle);

Expand Down
5 changes: 3 additions & 2 deletions lib/Tools/PluginUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,11 @@ TritonPlugin::getDialectHandles(std::vector<const char *> &dialectNames) {
}

llvm::Expected<TritonPluginResult>
TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle) {
TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle,
const std::vector<std::string> &args) {
if (auto Err = loadPlugin())
return Err;
return checkAPIResult(addPassAPI(pm, passHandle), passHandle);
return checkAPIResult(addPassAPI(pm, passHandle, args), passHandle);
}

llvm::Expected<TritonPluginResult>
Expand Down
17 changes: 10 additions & 7 deletions python/src/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,16 @@ void init_plugin_passes(py::module &&m) {
for (unsigned i = 0; i < passNames.size(); ++i) {
const char *passName = passNames.data()[i];

m.def(passName, [passName](mlir ::PassManager &pm) {
std::string filename =
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
TritonPlugin TP(filename);
if (auto result = TP.addPass(&pm, passName); !result)
throw TP.err2exp(result.takeError());
});
m.def(
passName,
[passName](mlir ::PassManager &pm, std::vector<std::string> args) {
std::string filename =
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
TritonPlugin TP(filename);
if (auto result = TP.addPass(&pm, passName, args); !result)
throw TP.err2exp(result.takeError());
},
py::arg("pm"), py::arg("args") = std::vector<std::string>());
}
}

Expand Down
7 changes: 6 additions & 1 deletion python/test/unit/plugins/custom_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def get_hash():
# Keep custom pipeline stages in a seperate file from kernels as any change to the file
# will trigger a recompile.

num_warps = 4


def inspect_stages_hook(self=None, stages=None, options=None, language=None, capability=None):
# If the hook is called with no arguments we assume were just after the key and hash and don't want to
Expand All @@ -31,7 +33,10 @@ def make_ttir_wrapper(mod, metadata, opt, capability):
mod = self.make_ttir(mod, metadata, opt, capability)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.plugin.add_plugin(pm)
if num_warps != 4:
passes.plugin.add_plugin(pm, {str(num_warps)})
else:
passes.plugin.add_plugin(pm)
pm.run(mod, 'make_ttir_plugin')
return mod

Expand Down
11 changes: 11 additions & 0 deletions python/test/unit/plugins/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ def kernel2(BLOCK_SIZE: tl.constexpr):
return


@pytest.mark.parametrize(None, [None])
@triton.jit
def kernel3(BLOCK_SIZE: tl.constexpr):
return


def test_op(capfd, device: str):
if os.environ.get('LLVM_BUILD_SHARED_LIBS', '0') == '0':
return
Expand All @@ -41,3 +47,8 @@ def test_op(capfd, device: str):
knobs.runtime.add_stages_inspection_hook = None
h = kernel2[grid](BLOCK_SIZE=1024)
assert "tt.func public @foo" not in h.asm["ttgir"]

knobs.runtime.add_stages_inspection_hook = custom_stages.inspect_stages_hook
custom_stages.num_warps = 8
h = kernel3[grid](BLOCK_SIZE=1024)
assert "tt.func public @foo_num_warps_8" in h.asm["ttgir"]
Loading