diff --git a/python/test/unit/tools/test_irsource.py b/python/test/unit/tools/test_irsource.py index a886ebb457f4..fc0a413c0663 100644 --- a/python/test/unit/tools/test_irsource.py +++ b/python/test/unit/tools/test_irsource.py @@ -1,4 +1,4 @@ -import tempfile +import pathlib import triton from triton.compiler import IRSource from triton._C.libtriton import ir @@ -6,7 +6,7 @@ target = triton.runtime.driver.active.get_current_target() -def test_mlir_attribute_parsing() -> None: +def test_mlir_attribute_parsing(tmp_path: pathlib.Path) -> None: ''' Tests that MLIR attributes are parsed correctly from input ttir/ttgir. @@ -37,21 +37,20 @@ def test_mlir_attribute_parsing() -> None: } } """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(sample_ttgir) - f.flush() - context = ir.context() - src = IRSource(f.name, context) + temp_file = tmp_path / "test_mlir_attribute_parsing0.ttgir" + temp_file.write_text(sample_ttgir) + context = ir.context() + src = IRSource(str(temp_file), context) - # check name and type signature - # should match ty_to_cpp(...) - assert src.signature == \ - {0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \ - 4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"} - assert src.name == "@matmul_kernel" + # check name and type signature + # should match ty_to_cpp(...) + assert src.signature == \ + {0: "*f32", 1: "*f32", 2: "*f32", 3: "i32", \ + 4: "i32", 5: "i32", 6: "i32", 7: "i32", 8: "nvTmaDesc", 9: "nvTmaDesc"} + assert src.name == "@matmul_kernel" - # check num warps - assert src.parse_options()['num_warps'] == 8 + # check num warps + assert src.parse_options()['num_warps'] == 8 sample_ttgir_vector_add = r""" #blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> @@ -83,11 +82,10 @@ def test_mlir_attribute_parsing() -> None: } } """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(sample_ttgir_vector_add) - f.flush() - context = ir.context() - src = IRSource(f.name, context) + temp_file = tmp_path / "test_mlir_attribute_parsing1.ttgir" + temp_file.write_text(sample_ttgir_vector_add) + context = ir.context() + src = IRSource(str(temp_file), context) - # now test compilation - triton.compile(f.name, target=target) + # now test compilation + triton.compile(str(temp_file), target=target)