Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "DialectPlugin/DialectPluginTypes.h"

using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::plugin;

#include "DialectPlugin/DialectPluginOpsDialect.cpp.inc"
Expand All @@ -25,13 +26,18 @@ void DialectPluginDialect::initialize() {

#include "DialectPlugin/DialectPluginDialect.h"
#include "DialectPlugin/DialectPluginPasses.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Tools/Plugins/PassPlugin.h"
#include "triton/Tools/PluginUtils.h"
#include "llvm/Config/llvm-config.h"

using namespace mlir;
static const char *PLUGIN_NAME = "DialectPlugin";
static const char *DIALECT_NAME = "DialectPlugin";
static const char *PASS_NAME = "plugingpu_conversion";
static const char *VERSION = "0.1.0";

static void addTritonPluginPass(mlir::PassManager *pm) {
static void addTritonPluginPass(mlir::PassManager *pm,
const std::vector<std::string> &args) {
pm->addPass(mlir::triton::plugin::createConvertPluginGPUToLLVMPass());
}

Expand All @@ -41,65 +47,38 @@ static void registerTritonPluginPass() {
});
}

static const char *ADD_PLUGIN_PASS_NAME = "plugingpu_conversion";
static std::unordered_map<std::string, void (*)(mlir::PassManager *)> passMap =
{{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}};
static std::unordered_map<std::string, void (*)()> registryMap = {
{ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}};
static std::vector<const char *> passNamesTable = {ADD_PLUGIN_PASS_NAME};

// Key APIs:

TRITON_PLUGIN_API
tritonAddPluginPass(mlir::PassManager *pm, const char *passName) {
std::string passNameStr(passName);
if (passMap.find(passNameStr) == passMap.end())
return TP_GENERIC_FAILURE;
passMap[passNameStr](pm);
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonRegisterPluginPass(const char *passName) {
std::string passNameStr(passName);
if (registryMap.find(passNameStr) == registryMap.end())
return TP_GENERIC_FAILURE;
registryMap[passNameStr]();
return TP_SUCCESS;
static void registerTritonPluginDialect(DialectRegistry *registry) {
registry->insert<mlir::triton::plugin::DialectPluginDialect>();
mlir::triton::plugin::registerpluginPasses();
}

TRITON_PLUGIN_API
tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) {
if (!passCount)
return TP_GENERIC_FAILURE;
auto count = passMap.size();
assert(count == registryMap.size() &&
"Expected register and add passes map size to match");
*passCount = count;
if (!passNames)
return TP_SUCCESS;
unsigned i = 0;
for (auto passName : passNamesTable) {
passNames[i] = passName;
}
return TP_SUCCESS;
}
static void addTritonPluginCustomOp(TritonOpBuilder &self,
std::vector<mlir::Value> &operands) {
::mlir::Value &dst = operands[0];
::mlir::Value &src = operands[1];

TRITON_PLUGIN_API
tritonEnumeratePluginDialects(uint32_t *dialectCount,
const char **dialectNames) {
*dialectCount = 1;
if (!dialectNames)
return TP_SUCCESS;
dialectNames[0] = "DialectPlugin";
return TP_SUCCESS;
dst = self.create<arith::AddFOp>(src, src);
operands[0] = dst;
}

TRITON_PLUGIN_API_TYPE(DialectPluginLibraryInfo)
tritonGetDialectPluginInfo(const char *name) {
return {MLIR_PLUGIN_API_VERSION, "DialectPlugin", LLVM_VERSION_STRING,
[](DialectRegistry *registry) {
registry->insert<mlir::triton::plugin::DialectPluginDialect>();
mlir::triton::plugin::registerpluginPasses();
}};
TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() {
static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass,
registerTritonPluginPass};
static plugin::PassInfo passes[] = {pass};
static plugin::DialectInfo dialect = {DIALECT_NAME, VERSION,
registerTritonPluginDialect};
static plugin::DialectInfo dialects[] = {dialect};
static plugin::OpInfo op = {"create_custom_op", addTritonPluginCustomOp};
static plugin::OpInfo ops[] = {op};
static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION,
PLUGIN_NAME,
VERSION,
passes,
1,
dialects,
1,
ops,
1,
TRITON_VERSION};
return &info;
}
83 changes: 40 additions & 43 deletions examples/plugins/TritonPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,27 @@ namespace mlir {
namespace triton {
namespace plugin {

#define GEN_PASS_DECL_TRITONGPUMLIRPLUGIN
#define GEN_PASS_DEF_TRITONGPUMLIRPLUGIN
#include "Passes.h.inc"

struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase<MLIRPluginPass> {
using TritonGPUMLIRPluginBase::TritonGPUMLIRPluginBase;

void runOnOperation() override {

MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

std::string name;
llvm::raw_string_ostream sstr(name);
sstr << "foo";
if (num_warps != 4)
sstr << "_num_warps_" << num_warps;

mod.walk([&](FunctionOpInterface funcOp) {
StringAttr funcNameAttr = funcOp.getNameAttr();
funcOp.setName("foo");
funcOp.setName(name);
});
}
};
Expand All @@ -29,8 +39,16 @@ struct MLIRPluginPass : public impl::TritonGPUMLIRPluginBase<MLIRPluginPass> {
} // namespace triton
} // namespace mlir

static void addTritonPluginPass(mlir::PassManager *pm) {
pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin());
static void addTritonPluginPass(mlir::PassManager *pm,
const std::vector<std::string> &args) {
if (args.empty()) {
pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin());
return;
}

mlir::triton::plugin::TritonGPUMLIRPluginOptions opts;
opts.num_warps = std::atoi(args[0].c_str());
pm->addPass(mlir::triton::plugin::createTritonGPUMLIRPlugin((opts)));
}

static void registerTritonPluginPass() {
Expand All @@ -39,46 +57,25 @@ static void registerTritonPluginPass() {
});
}

static const char *ADD_PLUGIN_PASS_NAME = "add_plugin";
static std::unordered_map<std::string, void (*)(mlir::PassManager *)> passMap =
{{ADD_PLUGIN_PASS_NAME, addTritonPluginPass}};
static std::unordered_map<std::string, void (*)()> registryMap = {
{ADD_PLUGIN_PASS_NAME, registerTritonPluginPass}};
static std::vector<const char *> passNamesTable = {ADD_PLUGIN_PASS_NAME};
static const char *PLUGIN_NAME = "TritonPlugin";
static const char *PASS_NAME = "add_plugin";
static const char *VERSION = "0.1.0";

// Key APIs:
using namespace mlir::triton;

TRITON_PLUGIN_API
tritonAddPluginPass(mlir::PassManager *pm, const char *passName) {
std::string passNameStr(passName);
if (passMap.find(passNameStr) == passMap.end())
return TP_GENERIC_FAILURE;
passMap[passNameStr](pm);
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonRegisterPluginPass(const char *passName) {
std::string passNameStr(passName);
if (registryMap.find(passNameStr) == registryMap.end())
return TP_GENERIC_FAILURE;
registryMap[passNameStr]();
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonEnumeratePluginPasses(uint32_t *passCount, const char **passNames) {
if (!passCount)
return TP_GENERIC_FAILURE;
auto count = passMap.size();
assert(count == registryMap.size() &&
"Expected register and add passes map size to match");
*passCount = count;
if (!passNames)
return TP_SUCCESS;
unsigned i = 0;
for (auto passName : passNamesTable) {
passNames[i++] = passName;
}
return TP_SUCCESS;
TRITON_PLUGIN_API plugin::PluginInfo *tritonGetPluginInfo() {
static plugin::PassInfo pass = {PASS_NAME, VERSION, addTritonPluginPass,
registerTritonPluginPass};
static plugin::PassInfo passes[] = {pass};
static plugin::PluginInfo info = {TRITON_PLUGIN_API_VERSION,
PLUGIN_NAME,
VERSION,
passes,
1,
nullptr,
0,
nullptr,
0,
TRITON_VERSION};
return &info;
}
Loading
Loading