diff --git a/Makefile b/Makefile index a2eb1a5911ed..4952736ed224 100644 --- a/Makefile +++ b/Makefile @@ -44,6 +44,8 @@ test-unit: all $(PYTEST) -vvv python/test/unit/plugins/test_plugin.py TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \ $(PYTEST) -vvv python/test/unit/plugins/test_dialect_plugin.py + TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \ + $(PYTEST) -s -vvv python/test/unit/plugins/custom_ops.py .PHONY: test-gluon test-gluon: all diff --git a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt index 2e0271800053..14e3077d565b 100644 --- a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt +++ b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/CMakeLists.txt @@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRDialectPlugin MLIRPass LLVMSupport MLIRSupport + MLIRArithDialect TritonNVIDIAGPUToLLVM "$<$:-undefined dynamic_lookup>" ) diff --git a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp index 4748f8fbe848..b605b92de1b1 100644 --- a/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp +++ b/examples/plugins/DialectPlugins/DialectPlugin/lib/DialectPlugin/DialectPluginDialect.cpp @@ -25,6 +25,7 @@ void DialectPluginDialect::initialize() { #include "DialectPlugin/DialectPluginDialect.h" #include "DialectPlugin/DialectPluginPasses.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Tools/Plugins/PassPlugin.h" #include "triton/Tools/PluginUtils.h" #include "llvm/Config/llvm-config.h" @@ -104,3 +105,25 @@ tritonGetDialectPluginInfo(const char *name) { mlir::triton::plugin::registerpluginPasses(); }}; } + +TRITON_PLUGIN_API +tritonEnumeratePluginCustomOps(uint32_t *count, const char **handles) { + if (!count) + return TP_GENERIC_FAILURE; + *count = 1; + if (!handles) + return TP_SUCCESS; + handles[0] = "create_custom_op"; + return TP_SUCCESS; +} + +TRITON_PLUGIN_API +tritonAddPluginCustomOp(const char *handle, TritonOpBuilder &self, + std::vector &operands) { + ::mlir::Value &dst = operands[0]; + ::mlir::Value &src = operands[1]; + + dst = self.create(src, src); + operands[0] = dst; + return TP_SUCCESS; +} diff --git a/include/triton/Tools/PluginUtils.h b/include/triton/Tools/PluginUtils.h index 77b921ae6b48..86536948105b 100644 --- a/include/triton/Tools/PluginUtils.h +++ b/include/triton/Tools/PluginUtils.h @@ -3,6 +3,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Tools/Plugins/DialectPlugin.h" +#include "python/src/ir.h" #include "llvm/Support/DynamicLibrary.h" #include "llvm/Support/Error.h" #include @@ -29,6 +30,9 @@ struct TritonPlugin { static constexpr char DIALECT_PLUGININFO[] = "tritonGetDialectPluginInfo"; static constexpr char ADD_PASS[] = "tritonAddPluginPass"; static constexpr char REGISTER_PASS[] = "tritonRegisterPluginPass"; + static constexpr char ENUMERATE_CUSTOMOPS[] = + "tritonEnumeratePluginCustomOps"; + static constexpr char ADD_CUSTOMOP[] = "tritonAddPluginCustomOp"; private: using EnumeratePyBindHandlesType = @@ -49,6 +53,13 @@ struct TritonPlugin { using DialectPluginInfoCType = ::mlir::DialectPluginLibraryInfo (*)(const char *); + using AddCustomOpType = std::function &operands)>; + using AddCustomOpCType = + TritonPluginResult (*)(const char *handle, TritonOpBuilder &self, + std::vector &operands); + llvm::Expected getAddressOfSymbol(const std::string &symbol) const; template @@ -78,10 +89,17 @@ struct TritonPlugin { llvm::Expected getDialectHandles(std::vector &handles); + llvm::Expected + getCustomOpHandles(std::vector &handles); + llvm::Expected addPass(mlir::PassManager *pm, const char *passHandle, const std::vector &args); + llvm::Expected + addCustomOp(const char *handle, TritonOpBuilder &self, + std::vector &operands); + llvm::Expected registerPass(const char *passHandle); llvm::Expected<::mlir::DialectPluginLibraryInfo> @@ -92,10 +110,15 @@ struct TritonPlugin { mutable llvm::sys::DynamicLibrary library; EnumeratePyBindHandlesType enumeratePassesAPI; EnumeratePyBindHandlesType enumerateDialectsAPI; + EnumeratePyBindHandlesType enumerateCustomOpAPI; AddPassType addPassAPI; RegisterPassType registerPassAPI; DialectPluginInfoType dialectPluginInfoAPI; + AddCustomOpType addCustomOpAPI; bool isLoaded = false; }; +void loadPluginDialects(const std::string &filename, + mlir::DialectRegistry ®istry); + #endif // TRITON_PLUGIN_UTILS_H diff --git a/lib/Tools/PluginUtils.cpp b/lib/Tools/PluginUtils.cpp index f9ef78d9f2e8..19f362801f35 100644 --- a/lib/Tools/PluginUtils.cpp +++ b/lib/Tools/PluginUtils.cpp @@ -83,6 +83,22 @@ llvm::Error TritonPlugin::loadPlugin() { dialectPluginInfoAPI = *dialectPluginInfoAPIOrErr; } + if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_CUSTOMOPS)) { + auto enumerateCustomOpAPIOrErr = + getAPI( + ENUMERATE_CUSTOMOPS); + auto addCustomOpAPIOrErr = + getAPI(ADD_CUSTOMOP); + + if (auto Err = enumerateCustomOpAPIOrErr.takeError()) + return Err; + if (auto Err = addCustomOpAPIOrErr.takeError()) + return Err; + + enumerateCustomOpAPI = *enumerateCustomOpAPIOrErr; + addCustomOpAPI = *addCustomOpAPIOrErr; + } + isLoaded = true; return llvm::Error::success(); } @@ -137,6 +153,19 @@ TritonPlugin::getDialectHandles(std::vector &dialectNames) { return enumeratePyBindHandles(enumerateDialectsAPI, dialectNames); } +llvm::Expected +TritonPlugin::getCustomOpHandles(std::vector &customOpNames) { + if (auto Err = loadPlugin()) + return Err; + // Do a check to see if the enumerate-custom-ops api symbol is present, bail + // as if there are 0 custom ops if not + intptr_t isCustomOpSymbolPresent = + (intptr_t)library.getAddressOfSymbol(ENUMERATE_CUSTOMOPS); + if (!isCustomOpSymbolPresent) + return TP_SUCCESS; + return enumeratePyBindHandles(enumerateCustomOpAPI, customOpNames); +} + llvm::Expected TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle, const std::vector &args) { @@ -158,3 +187,30 @@ TritonPlugin::getDialectPluginInfo(const char *dialectName) { return Err; return dialectPluginInfoAPI(dialectName); } + +llvm::Expected +TritonPlugin::addCustomOp(const char *handle, TritonOpBuilder &self, + std::vector &operands) { + if (auto Err = loadPlugin()) + return Err; + addCustomOpAPI(handle, self, operands); + return TP_SUCCESS; +} + +void loadPluginDialects(const std::string &filename, + mlir::DialectRegistry ®istry) { + TritonPlugin TP(filename); + + std::vector dialectNames; + if (auto result = TP.getDialectHandles(dialectNames); !result) + llvm::report_fatal_error(result.takeError()); + + for (unsigned i = 0; i < dialectNames.size(); ++i) { + const char *dialectName = dialectNames.data()[i]; + auto result = TP.getDialectPluginInfo(dialectName); + if (!result) + llvm::report_fatal_error(result.takeError()); + ::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result; + dialectPluginInfo.registerDialectRegistryCallbacks(®istry); + } +} diff --git a/python/src/ir.cc b/python/src/ir.cc index 74a62573ec21..243410935e5b 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -231,10 +231,40 @@ py::list getTensorDescMetadata(ModuleOp &mod) { } // anonymous namespace +static void +registerCustomOps(py::class_ &TritonOpBuilderBinding, + const std::string &filename) { + TritonPlugin TP(filename); + std::vector customOpNames; + if (auto result = TP.getCustomOpHandles(customOpNames); !result) + throw TP.err2exp(result.takeError()); + + for (unsigned i = 0; i < customOpNames.size(); ++i) { + const char *customOpName = customOpNames.data()[i]; + + TritonOpBuilderBinding.def( + customOpName, + [customOpName](TritonOpBuilder &self, + std::vector &args) -> mlir::Value { + std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + TritonPlugin TP(filename); + + ::mlir::Value dst; + std::vector<::mlir::Value> values = {dst}; + llvm::copy(args, std::back_inserter(values)); + auto result = TP.addCustomOp(customOpName, self, values); + if (!result) + throw TP.err2exp(result.takeError()); + dst = values[0]; + return dst; + }); + } +} + /*****************************************************************************/ /* Python bindings for ir */ /*****************************************************************************/ - void init_triton_ir(py::module &&m) { using ret = py::return_value_policy; using namespace pybind11::literals; @@ -342,20 +372,7 @@ void init_triton_ir(py::module &&m) { if (std::string filename = mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); !filename.empty()) { - TritonPlugin TP(filename); - - std::vector dialectNames; - if (auto result = TP.getDialectHandles(dialectNames); !result) - llvm::report_fatal_error(result.takeError()); - - for (unsigned i = 0; i < dialectNames.size(); ++i) { - const char *dialectName = dialectNames.data()[i]; - auto result = TP.getDialectPluginInfo(dialectName); - if (!result) - throw TP.err2exp(result.takeError()); - ::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result; - dialectPluginInfo.registerDialectRegistryCallbacks(®istry); - } + loadPluginDialects(filename, registry); } registry.insert(m, "InsertPoint", py::module_local()); - py::class_(m, "builder", py::module_local(), - py::dynamic_attr()) - .def(py::init()) + py::class_ TritonOpBuilderBinding = + py::class_(m, "builder", py::module_local(), + py::dynamic_attr()); + TritonOpBuilderBinding.def(py::init()) .def("get_op_builder", &TritonOpBuilder::getBuilder, ret::reference) // getters .def("create_module", @@ -1174,7 +1192,8 @@ void init_triton_ir(py::module &&m) { }) // Cast instructions - // Conversions for custom FP types (FP8 and non-standard rounding modes) + // Conversions for custom FP types (FP8 and non-standard rounding + // modes) .def("create_fp_to_fp", [](TritonOpBuilder &self, Value &src, Type &dstType, std::optional roundingMode) -> Value { @@ -1315,8 +1334,8 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { return Value(self.create(lhs, rhs)); }) - // minimumf follows the torch.minimum convention and returns NaN if either - // operand is NaN + // minimumf follows the torch.minimum convention and returns NaN if + // either operand is NaN .def("create_minimumf", [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { return Value(self.create(lhs, rhs)); @@ -1335,8 +1354,8 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { return Value(self.create(lhs, rhs)); }) - // maximumf follows the torch.maximum convention and returns NaN if either - // operand is NaN + // maximumf follows the torch.maximum convention and returns NaN if + // either operand is NaN .def("create_maximumf", [](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value { return Value(self.create(lhs, rhs)); @@ -1844,6 +1863,12 @@ void init_triton_ir(py::module &&m) { paddingOption); }); + if (std::string filename = + mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH"); + !filename.empty()) { + registerCustomOps(TritonOpBuilderBinding, filename); + } + py::class_(m, "pass_manager", py::module_local()) .def(py::init()) .def("enable_debug", diff --git a/python/src/ir.h b/python/src/ir.h index 499dd9e8a9f6..f8dd9b2941ac 100644 --- a/python/src/ir.h +++ b/python/src/ir.h @@ -1,7 +1,9 @@ #pragma once #include "mlir/IR/Builders.h" +#include "mlir/IR/DialectRegistry.h" #include "triton/Tools/Sys/GetEnv.hpp" #include +#include // A custom op builder that keeps track of the last location class TritonOpBuilder { diff --git a/python/test/unit/plugins/custom_ops.py b/python/test/unit/plugins/custom_ops.py new file mode 100644 index 000000000000..5ce6c4fa9237 --- /dev/null +++ b/python/test/unit/plugins/custom_ops.py @@ -0,0 +1,72 @@ +import torch + +import triton +import triton.language as tl +from triton._C.libtriton import ir +from triton.language.core import builtin +from typing import TypeVar, Type +import builtins +import os +import pathlib +from triton.compiler.code_generator import flatten_values_to_ir + +T = TypeVar('T') +TensorTy = TypeVar('TensorTy') + +triton.language.__all__.append("custom_op") +tensor: Type[TensorTy] = tl.tensor +builder: ir.builder + +TRITON_BUILTIN = "__triton_builtin__" + + +def _unwrap_if_constexpr(o): + if isinstance(o, list): + return [_unwrap_if_constexpr(x) for x in o] + if isinstance(o, builtins.tuple): + return builtins.tuple(_unwrap_if_constexpr(x) for x in o) + if isinstance(o, tuple): + return tuple(_unwrap_if_constexpr(x) for x in o) + return o.value if isinstance(o, tl.constexpr) else o + + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@builtin +def custom_op(x, sanitize_overflow: tl.constexpr = True, _semantic=None): + x = _unwrap_if_constexpr(x) + builder = _semantic.builder + arg_handles = [] + arg_handles.extend(flatten_values_to_ir([x])) + return tl.tensor(builder.create_custom_op(arg_handles), x.type) + + +@triton.jit +def add_kernel( + x_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + output = custom_op(x) + tl.store(output_ptr + offsets, output, mask=mask) + + +def test_custom_ops(tmp_path: pathlib.Path): + if os.environ.get('LLVM_BUILD_SHARED_LIBS', '0') == '0': + return + size = 8 + x = torch.zeros(size, device=DEVICE, dtype=torch.float32) + output_triton = torch.empty_like(x) + n_elements = output_triton.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + h = add_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=32) + + src = h.asm["source"] + assert "arith.addf" in src