Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/test/unit/runtime/test_bindings.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,11 @@ def walk_fn(op):
backend = triton.compiler.compiler.make_backend(target)
options = backend.parse_options(dict())
codegen_fns = dict()
module_map = backend.get_module_map()
triton._C.libtriton.ir.load_dialects(context)
backend.load_dialects(context)

ttir_module = src.make_ir(options, codegen_fns, context)
ttir_module = src.make_ir(options, codegen_fns, module_map, context)
ttir_module.walk(walk_fn)


Expand Down
10 changes: 9 additions & 1 deletion python/triton/backends/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

from abc import ABCMeta, abstractmethod, abstractclassmethod
from dataclasses import dataclass
from typing import Union
from typing import Dict, Union
from types import ModuleType


@dataclass(frozen=True)
Expand Down Expand Up @@ -74,3 +75,10 @@ def load_dialects(self, context):
Load additional MLIR dialects into the provided `context`
"""
raise NotImplementedError

@abstractmethod
def get_module_map(self) -> Dict[str, ModuleType]:
"""
Return a map of interface modules to their device-specific implementations.
"""
raise NotImplementedError
26 changes: 20 additions & 6 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def visit_Call(self, node: ast.Call) -> bool:
class CodeGenerator(ast.NodeVisitor):

def __init__(self, context, prototype, gscope, attributes, constants, function_name, jit_fn: JITFunction, options,
codegen_fns, debug=None, module=None, is_kernel=False, function_types: Optional[Dict] = None,
noinline=False, file_name: Optional[str] = None, begin_line=0):
codegen_fns, module_map, debug=None, module=None, is_kernel=False,
function_types: Optional[Dict] = None, noinline=False, file_name: Optional[str] = None, begin_line=0):
self.context = context
self.builder = ir.builder(context)
self.file_name = file_name
Expand All @@ -201,10 +201,23 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n
# Convert custom types not natively supported on HW.
# convert_custom_types(intput_tensor, dtype, fp_downcast_rounding=None, _builder=None)
self.builder.codegen_fns = codegen_fns
self.builder.module_map = {} if module_map is None else module_map
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = {} if function_types is None else function_types
self.prototype = prototype
self.gscope = gscope

self.gscope = {}
for k, v in gscope.items():
if isinstance(v, ModuleType):
self.gscope[k] = module_map.get(v.__name__, v)
continue

module_name = getattr(v, "__module__", "")
if module_name in module_map:
self.gscope[k] = getattr(module_map[module_name], k)
else:
self.gscope[k] = v

self.lscope = {}
self.attributes = attributes
self.constants = constants
Expand Down Expand Up @@ -1049,7 +1062,8 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs):
generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module,
jit_fn=fn, function_name=fn_name, function_types=self.function_ret_types,
noinline=fn.noinline, file_name=file_name, begin_line=begin_line,
options=self.builder.options, codegen_fns=self.builder.codegen_fns, debug=debug)
options=self.builder.options, codegen_fns=self.builder.codegen_fns,
module_map=self.builder.module_map, debug=debug)
try:
generator.visit(fn.parse())
except Exception as e:
Expand Down Expand Up @@ -1252,7 +1266,7 @@ def kernel_suffix(signature, specialization):
return suffix


def ast_to_ttir(fn, specialization, context, options, codegen_fns):
def ast_to_ttir(fn, specialization, context, options, codegen_fns, module_map):
attrs = specialization.attrs
# create kernel prototype
cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i
Expand All @@ -1272,7 +1286,7 @@ def ast_to_ttir(fn, specialization, context, options, codegen_fns):
prototype = language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name,
jit_fn=fn, attributes=new_attrs, is_kernel=True, file_name=file_name,
begin_line=begin_line, options=options, codegen_fns=codegen_fns)
begin_line=begin_line, options=options, codegen_fns=codegen_fns, module_map=module_map)
generator.visit(fn.parse())

ret = generator.module
Expand Down
10 changes: 6 additions & 4 deletions python/triton/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ def hash(self):
key = f"{self.fn.cache_key}-{self.attrs.hash()}-{sorted_sig}-{sorted_constants}"
return hashlib.sha256(key.encode("utf-8")).hexdigest()

def make_ir(self, options, codegen_fns, context):
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
def make_ir(self, options, codegen_fns, module_map, context):
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
module_map=module_map)

def parse_options(self):
return dict()
Expand All @@ -132,7 +133,7 @@ def __init__(self, path):
def hash(self):
return hashlib.sha256(self.src.encode("utf-8")).hexdigest()

def make_ir(self, options, codegen_fns, context):
def make_ir(self, options, codegen_fns, module_map, context):
module = ir.parse_mlir_module(self.path, context)
module.context = context
return module
Expand Down Expand Up @@ -277,8 +278,9 @@ def compile(src, target=None, options=None):
ir.load_dialects(context)
backend.load_dialects(context)
codegen_fns = backend.get_codegen_implementation()
module_map = backend.get_module_map()
try:
module = src.make_ir(options, codegen_fns, context)
module = src.make_ir(options, codegen_fns, module_map, context)
except Exception as e:
filter_traceback(e)
raise
Expand Down
Loading