Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
35a1a64
[triton-ext] Reapply custom-ops changes for Triton Extensions
CRobeck Feb 28, 2026
a199240
turn custom_ops.py into a pytest
plotfi Mar 1, 2026
ad3a1cb
update test
CRobeck Mar 3, 2026
b2afcc3
Add import for os module in custom_ops.py
CRobeck Mar 3, 2026
2d4d995
more clean up
CRobeck Mar 4, 2026
729b5ad
Make custom ops callable with variable arguments
plotfi Mar 4, 2026
fc3f367
lint
CRobeck Mar 4, 2026
d4cfb54
clang-format toggle to prevent excessive changes due to pre-commit
plotfi Mar 5, 2026
9aa0f2e
lint
plotfi Mar 5, 2026
a1d955d
adding add pass arguments for things like num_warps, num_stages etc
plotfi Mar 5, 2026
f7f8af5
lint
plotfi Mar 5, 2026
21a2aae
make 0 param passes look and feel normal for adding to pipeline
plotfi Mar 5, 2026
e743b35
refactor plugin arg intrfaces to have a simple point of truth
plotfi Mar 5, 2026
cbd9e33
drop _count
plotfi Mar 5, 2026
be14d33
fix dialect count bug
plotfi Mar 5, 2026
9dad391
drop builtin since it comes from triton core
plotfi Mar 5, 2026
e3d7a86
clean up custom ops interface
plotfi Mar 5, 2026
11f4cc4
clean up custom ops interface 2
plotfi Mar 5, 2026
96a4381
move opt arg list to stringified list
plotfi Mar 5, 2026
9c0eeef
drop _count
plotfi Mar 5, 2026
4b76d1a
cleanup addCustomOp
plotfi Mar 5, 2026
8927563
oops, fix out of bounds error for no args passed
plotfi Mar 5, 2026
11126d7
clean up formatting
CRobeck Mar 6, 2026
b3bfde6
drop test bloat
plotfi Mar 10, 2026
6224224
undo NFC changes for API interface macros
plotfi Mar 10, 2026
157d710
Fix failure in python/test/unit/plugins/test_plugin.py when custom op…
plotfi Mar 10, 2026
243af4b
further simplification
plotfi Mar 11, 2026
6cb3788
drop getBuilder()
plotfi Mar 11, 2026
99610dd
clang-format off removed + lint
plotfi Mar 11, 2026
49acef8
improve whitespace
plotfi Mar 13, 2026
d671287
Merge branch 'main' into custom_dialect_ops
CRobeck Mar 15, 2026
059fa21
Merge branch 'main' into custom_dialect_ops
CRobeck Mar 16, 2026
e5c6ef6
Remove unused functions from custom_ops.py
CRobeck Mar 18, 2026
1864dc6
address review comments
CRobeck Mar 18, 2026
db91a8d
address review comments
CRobeck Mar 18, 2026
7d03b4c
Merge branch 'main' into custom_dialect_ops
CRobeck Mar 18, 2026
b0665a4
Merge branch 'main' into custom_dialect_ops
plotfi Mar 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ test-unit: all
$(PYTEST) -vvv python/test/unit/plugins/test_plugin.py
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \
$(PYTEST) -vvv python/test/unit/plugins/test_dialect_plugin.py
TRITON_PASS_PLUGIN_PATH=python/triton/plugins/libMLIRDialectPlugin.so \
$(PYTEST) -s -vvv python/test/unit/plugins/custom_ops.py

.PHONY: test-gluon
test-gluon: all
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRDialectPlugin
MLIRPass
LLVMSupport
MLIRSupport
MLIRArithDialect
TritonNVIDIAGPUToLLVM
"$<$<PLATFORM_ID:Darwin>:-undefined dynamic_lookup>"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ void DialectPluginDialect::initialize() {

#include "DialectPlugin/DialectPluginDialect.h"
#include "DialectPlugin/DialectPluginPasses.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Tools/Plugins/PassPlugin.h"
#include "triton/Tools/PluginUtils.h"
#include "llvm/Config/llvm-config.h"
Expand Down Expand Up @@ -104,3 +105,25 @@ tritonGetDialectPluginInfo(const char *name) {
mlir::triton::plugin::registerpluginPasses();
}};
}

TRITON_PLUGIN_API
tritonEnumeratePluginCustomOps(uint32_t *count, const char **handles) {
if (!count)
return TP_GENERIC_FAILURE;
*count = 1;
if (!handles)
return TP_SUCCESS;
handles[0] = "create_custom_op";
return TP_SUCCESS;
}

TRITON_PLUGIN_API
tritonAddPluginCustomOp(const char *handle, TritonOpBuilder &self,
std::vector<mlir::Value> &operands) {
::mlir::Value &dst = operands[0];
::mlir::Value &src = operands[1];

dst = self.create<arith::AddFOp>(src, src);
operands[0] = dst;
return TP_SUCCESS;
}
23 changes: 23 additions & 0 deletions include/triton/Tools/PluginUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "mlir/Pass/PassManager.h"
#include "mlir/Tools/Plugins/DialectPlugin.h"
#include "python/src/ir.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/Error.h"
#include <cstdint>
Expand All @@ -29,6 +30,9 @@ struct TritonPlugin {
static constexpr char DIALECT_PLUGININFO[] = "tritonGetDialectPluginInfo";
static constexpr char ADD_PASS[] = "tritonAddPluginPass";
static constexpr char REGISTER_PASS[] = "tritonRegisterPluginPass";
static constexpr char ENUMERATE_CUSTOMOPS[] =
"tritonEnumeratePluginCustomOps";
static constexpr char ADD_CUSTOMOP[] = "tritonAddPluginCustomOp";

private:
using EnumeratePyBindHandlesType =
Expand All @@ -49,6 +53,13 @@ struct TritonPlugin {
using DialectPluginInfoCType =
::mlir::DialectPluginLibraryInfo (*)(const char *);

using AddCustomOpType = std::function<TritonPluginResult(
const char *handle, TritonOpBuilder &self,
std::vector<mlir::Value> &operands)>;
using AddCustomOpCType =
TritonPluginResult (*)(const char *handle, TritonOpBuilder &self,
std::vector<mlir::Value> &operands);

llvm::Expected<intptr_t> getAddressOfSymbol(const std::string &symbol) const;

template <typename T, typename U>
Expand Down Expand Up @@ -78,10 +89,17 @@ struct TritonPlugin {
llvm::Expected<TritonPluginResult>
getDialectHandles(std::vector<const char *> &handles);

llvm::Expected<TritonPluginResult>
getCustomOpHandles(std::vector<const char *> &handles);

llvm::Expected<TritonPluginResult>
addPass(mlir::PassManager *pm, const char *passHandle,
const std::vector<std::string> &args);

llvm::Expected<TritonPluginResult>
addCustomOp(const char *handle, TritonOpBuilder &self,
std::vector<mlir::Value> &operands);

llvm::Expected<TritonPluginResult> registerPass(const char *passHandle);

llvm::Expected<::mlir::DialectPluginLibraryInfo>
Expand All @@ -92,10 +110,15 @@ struct TritonPlugin {
mutable llvm::sys::DynamicLibrary library;
EnumeratePyBindHandlesType enumeratePassesAPI;
EnumeratePyBindHandlesType enumerateDialectsAPI;
EnumeratePyBindHandlesType enumerateCustomOpAPI;
AddPassType addPassAPI;
RegisterPassType registerPassAPI;
DialectPluginInfoType dialectPluginInfoAPI;
AddCustomOpType addCustomOpAPI;
bool isLoaded = false;
};

void loadPluginDialects(const std::string &filename,
mlir::DialectRegistry &registry);

#endif // TRITON_PLUGIN_UTILS_H
56 changes: 56 additions & 0 deletions lib/Tools/PluginUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,22 @@ llvm::Error TritonPlugin::loadPlugin() {
dialectPluginInfoAPI = *dialectPluginInfoAPIOrErr;
}

if ((intptr_t)library.getAddressOfSymbol(ENUMERATE_CUSTOMOPS)) {
auto enumerateCustomOpAPIOrErr =
getAPI<EnumeratePyBindHandlesType, EnumeratePyBindHandlesCType>(
ENUMERATE_CUSTOMOPS);
auto addCustomOpAPIOrErr =
getAPI<AddCustomOpType, AddCustomOpCType>(ADD_CUSTOMOP);

if (auto Err = enumerateCustomOpAPIOrErr.takeError())
return Err;
if (auto Err = addCustomOpAPIOrErr.takeError())
return Err;

enumerateCustomOpAPI = *enumerateCustomOpAPIOrErr;
addCustomOpAPI = *addCustomOpAPIOrErr;
}

isLoaded = true;
return llvm::Error::success();
}
Expand Down Expand Up @@ -137,6 +153,19 @@ TritonPlugin::getDialectHandles(std::vector<const char *> &dialectNames) {
return enumeratePyBindHandles(enumerateDialectsAPI, dialectNames);
}

llvm::Expected<TritonPluginResult>
TritonPlugin::getCustomOpHandles(std::vector<const char *> &customOpNames) {
if (auto Err = loadPlugin())
return Err;
// Do a check to see if the enumerate-custom-ops api symbol is present, bail
// as if there are 0 custom ops if not
intptr_t isCustomOpSymbolPresent =
(intptr_t)library.getAddressOfSymbol(ENUMERATE_CUSTOMOPS);
if (!isCustomOpSymbolPresent)
return TP_SUCCESS;
return enumeratePyBindHandles(enumerateCustomOpAPI, customOpNames);
}

llvm::Expected<TritonPluginResult>
TritonPlugin::addPass(mlir::PassManager *pm, const char *passHandle,
const std::vector<std::string> &args) {
Expand All @@ -158,3 +187,30 @@ TritonPlugin::getDialectPluginInfo(const char *dialectName) {
return Err;
return dialectPluginInfoAPI(dialectName);
}

llvm::Expected<TritonPluginResult>
TritonPlugin::addCustomOp(const char *handle, TritonOpBuilder &self,
std::vector<mlir::Value> &operands) {
if (auto Err = loadPlugin())
return Err;
addCustomOpAPI(handle, self, operands);
return TP_SUCCESS;
}

void loadPluginDialects(const std::string &filename,
mlir::DialectRegistry &registry) {
TritonPlugin TP(filename);

std::vector<const char *> dialectNames;
if (auto result = TP.getDialectHandles(dialectNames); !result)
llvm::report_fatal_error(result.takeError());

for (unsigned i = 0; i < dialectNames.size(); ++i) {
const char *dialectName = dialectNames.data()[i];
auto result = TP.getDialectPluginInfo(dialectName);
if (!result)
llvm::report_fatal_error(result.takeError());
::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result;
dialectPluginInfo.registerDialectRegistryCallbacks(&registry);
}
}
71 changes: 48 additions & 23 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,40 @@ py::list getTensorDescMetadata(ModuleOp &mod) {

} // anonymous namespace

static void
registerCustomOps(py::class_<TritonOpBuilder> &TritonOpBuilderBinding,
const std::string &filename) {
TritonPlugin TP(filename);
std::vector<const char *> customOpNames;
if (auto result = TP.getCustomOpHandles(customOpNames); !result)
throw TP.err2exp(result.takeError());

for (unsigned i = 0; i < customOpNames.size(); ++i) {
const char *customOpName = customOpNames.data()[i];

TritonOpBuilderBinding.def(
customOpName,
[customOpName](TritonOpBuilder &self,
std::vector<mlir::Value> &args) -> mlir::Value {
std::string filename =
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
TritonPlugin TP(filename);

::mlir::Value dst;
std::vector<::mlir::Value> values = {dst};
llvm::copy(args, std::back_inserter(values));
auto result = TP.addCustomOp(customOpName, self, values);
if (!result)
throw TP.err2exp(result.takeError());
dst = values[0];
return dst;
});
}
}

/*****************************************************************************/
/* Python bindings for ir */
/*****************************************************************************/

void init_triton_ir(py::module &&m) {
using ret = py::return_value_policy;
using namespace pybind11::literals;
Expand Down Expand Up @@ -342,20 +372,7 @@ void init_triton_ir(py::module &&m) {
if (std::string filename =
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
!filename.empty()) {
TritonPlugin TP(filename);

std::vector<const char *> dialectNames;
if (auto result = TP.getDialectHandles(dialectNames); !result)
llvm::report_fatal_error(result.takeError());

for (unsigned i = 0; i < dialectNames.size(); ++i) {
const char *dialectName = dialectNames.data()[i];
auto result = TP.getDialectPluginInfo(dialectName);
if (!result)
throw TP.err2exp(result.takeError());
::mlir::DialectPluginLibraryInfo dialectPluginInfo = *result;
dialectPluginInfo.registerDialectRegistryCallbacks(&registry);
}
loadPluginDialects(filename, registry);
}

registry.insert<TritonDialect, ::mlir::triton::gpu::TritonGPUDialect,
Expand Down Expand Up @@ -823,9 +840,10 @@ void init_triton_ir(py::module &&m) {

py::class_<OpBuilder::InsertPoint>(m, "InsertPoint", py::module_local());

py::class_<TritonOpBuilder>(m, "builder", py::module_local(),
py::dynamic_attr())
.def(py::init<MLIRContext *>())
py::class_<TritonOpBuilder> TritonOpBuilderBinding =
py::class_<TritonOpBuilder>(m, "builder", py::module_local(),
py::dynamic_attr());
TritonOpBuilderBinding.def(py::init<MLIRContext *>())
.def("get_op_builder", &TritonOpBuilder::getBuilder, ret::reference)
// getters
.def("create_module",
Expand Down Expand Up @@ -1174,7 +1192,8 @@ void init_triton_ir(py::module &&m) {
})

// Cast instructions
// Conversions for custom FP types (FP8 and non-standard rounding modes)
// Conversions for custom FP types (FP8 and non-standard rounding
// modes)
.def("create_fp_to_fp",
[](TritonOpBuilder &self, Value &src, Type &dstType,
std::optional<RoundingMode> roundingMode) -> Value {
Expand Down Expand Up @@ -1315,8 +1334,8 @@ void init_triton_ir(py::module &&m) {
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MinUIOp>(lhs, rhs));
})
// minimumf follows the torch.minimum convention and returns NaN if either
// operand is NaN
// minimumf follows the torch.minimum convention and returns NaN if
// either operand is NaN
.def("create_minimumf",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MinimumFOp>(lhs, rhs));
Expand All @@ -1335,8 +1354,8 @@ void init_triton_ir(py::module &&m) {
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MaxUIOp>(lhs, rhs));
})
// maximumf follows the torch.maximum convention and returns NaN if either
// operand is NaN
// maximumf follows the torch.maximum convention and returns NaN if
// either operand is NaN
.def("create_maximumf",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MaximumFOp>(lhs, rhs));
Expand Down Expand Up @@ -1844,6 +1863,12 @@ void init_triton_ir(py::module &&m) {
paddingOption);
});

if (std::string filename =
mlir::triton::tools::getStrEnv("TRITON_PASS_PLUGIN_PATH");
!filename.empty()) {
registerCustomOps(TritonOpBuilderBinding, filename);
}

py::class_<PassManager>(m, "pass_manager", py::module_local())
.def(py::init<MLIRContext *>())
.def("enable_debug",
Expand Down
2 changes: 2 additions & 0 deletions python/src/ir.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#pragma once
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectRegistry.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include <memory>
#include <string>

// A custom op builder that keeps track of the last location
class TritonOpBuilder {
Expand Down
72 changes: 72 additions & 0 deletions python/test/unit/plugins/custom_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import torch

import triton
import triton.language as tl
from triton._C.libtriton import ir
from triton.language.core import builtin
from typing import TypeVar, Type
import builtins
import os
import pathlib
from triton.compiler.code_generator import flatten_values_to_ir

T = TypeVar('T')
TensorTy = TypeVar('TensorTy')

triton.language.__all__.append("custom_op")
tensor: Type[TensorTy] = tl.tensor
builder: ir.builder

TRITON_BUILTIN = "__triton_builtin__"


def _unwrap_if_constexpr(o):
if isinstance(o, list):
return [_unwrap_if_constexpr(x) for x in o]
if isinstance(o, builtins.tuple):
return builtins.tuple(_unwrap_if_constexpr(x) for x in o)
if isinstance(o, tuple):
return tuple(_unwrap_if_constexpr(x) for x in o)
return o.value if isinstance(o, tl.constexpr) else o


DEVICE = triton.runtime.driver.active.get_active_torch_device()


@builtin
def custom_op(x, sanitize_overflow: tl.constexpr = True, _semantic=None):
x = _unwrap_if_constexpr(x)
builder = _semantic.builder
arg_handles = []
arg_handles.extend(flatten_values_to_ir([x]))
return tl.tensor(builder.create_custom_op(arg_handles), x.type)


@triton.jit
def add_kernel(
x_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
output = custom_op(x)
tl.store(output_ptr + offsets, output, mask=mask)
Comment thread
CRobeck marked this conversation as resolved.


def test_custom_ops(tmp_path: pathlib.Path):
if os.environ.get('LLVM_BUILD_SHARED_LIBS', '0') == '0':
return
size = 8
x = torch.zeros(size, device=DEVICE, dtype=torch.float32)
output_triton = torch.empty_like(x)
n_elements = output_triton.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
h = add_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=32)

src = h.asm["source"]
assert "arith.addf" in src
Loading