Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 20 additions & 22 deletions python/test/unit/tools/test_irsource.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import tempfile
import pathlib
import triton
from triton.compiler import IRSource
from triton._C.libtriton import ir

target = triton.runtime.driver.active.get_current_target()


def test_mlir_attribute_parsing() -> None:
def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None:
'''
Tests that MLIR attributes are parsed correctly from input ttir/ttgir.

Expand Down Expand Up @@ -37,21 +37,20 @@ def test_mlir_attribute_parsing() -> None:
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir)
f.flush()
context = ir.context()
src = IRSource(f.name, context)
temp_file = tmp_path / "test_mlir_attribute_parsing0.ttgir"
temp_file.write_text(sample_ttgir)
context = ir.context()
src = IRSource(str(temp_file), context)

# check name and type signature
# should match ty_to_cpp(...)
assert src.signature == \
{0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
assert src.name == "@matmul_kernel"
# check name and type signature
# should match ty_to_cpp(...)
assert src.signature == \
{0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \
4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"}
assert src.name == "@matmul_kernel"

# check num warps
assert src.parse_options()['num_warps'] == 8
# check num warps
assert src.parse_options()['num_warps'] == 8

sample_ttgir_vector_add = r"""
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
Expand Down Expand Up @@ -83,11 +82,10 @@ def test_mlir_attribute_parsing() -> None:
}
}
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(sample_ttgir_vector_add)
f.flush()
context = ir.context()
src = IRSource(f.name, context)
temp_file = tmp_path / "test_mlir_attribute_parsing1.ttgir"
temp_file.write_text(sample_ttgir_vector_add)
context = ir.context()
src = IRSource(str(temp_file), context)

# now test compilation
triton.compile(f.name, target=target)
# now test compilation
triton.compile(str(temp_file), target=target)