diff --git a/python/test/unit/tools/test_irsource.py b/python/test/unit/tools/test_irsource.py index fc0a413c0663..0e7c67cb0328 100644 --- a/python/test/unit/tools/test_irsource.py +++ b/python/test/unit/tools/test_irsource.py @@ -1,9 +1,10 @@ import pathlib import triton -from triton.compiler import IRSource +from triton.compiler import IRSource, make_backend from triton._C.libtriton import ir target = triton.runtime.driver.active.get_current_target() +backend = make_backend(target) def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None: @@ -40,7 +41,7 @@ def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None: temp_file = tmp_path / "test_mlir_attribute_parsing0.ttgir" temp_file.write_text(sample_ttgir) context = ir.context() - src = IRSource(str(temp_file), context) + src = IRSource(str(temp_file), context, backend) # check name and type signature # should match ty_to_cpp(...) @@ -85,7 +86,7 @@ def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None: temp_file = tmp_path / "test_mlir_attribute_parsing1.ttgir" temp_file.write_text(sample_ttgir_vector_add) context = ir.context() - src = IRSource(str(temp_file), context) + src = IRSource(str(temp_file), context, backend) # now test compilation triton.compile(str(temp_file), target=target) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 394a83bfb072..a76cb132ce47 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -91,12 +91,13 @@ def parse_options(self): class IRSource: - def __init__(self, path, context): + def __init__(self, path, context, backend): self.path = path path = Path(path) self.ext = path.suffix[1:] self.src = path.read_text() ir.load_dialects(context) + backend.load_dialects(context) # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now. # TODO - replace with a proper parser @@ -223,7 +224,7 @@ def compile(src, target=None, options=None): if ir_source: assert isinstance(src, str), "source must be either AST or a filepath" context = ir.context() - src = IRSource(src, context) + src = IRSource(src, context, backend) extra_options = src.parse_options() options = backend.parse_options(dict(options or dict(), **extra_options)) @@ -266,14 +267,13 @@ def compile(src, target=None, options=None): if ir_source: first_stage += 1 + # For IRSource, we have already grabbed the context + called both + # ir.load_dialects and backend.load_dialects. if not isinstance(src, IRSource): context = ir.context() ir.load_dialects(context) backend.load_dialects(context) - else: - # For IRSource, we have already grabbed the context + called ir.load_dialects - # just need to load the dialects for the backend. - backend.load_dialects(context) + codegen_fns = backend.get_codegen_implementation() module_map = backend.get_module_map() try: