diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 38d396cf0a1f..02b9185eab39 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -170,31 +170,34 @@ def kernel_add(a, b, o, N: tl.constexpr): assert bins[0].asm['ttir'] != bins[1].asm['ttir'] -def test_compile_in_subproc() -> None: +instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"]) + + +def compile_fn(config, cc): @triton.jit def kernel_sub(a, b, o, N: tl.constexpr): idx = tl.arange(0, N) - tl.store(o + idx, - tl.load(a + idx) - tl.load(b + idx) * 777) + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + triton.compile( + fn=kernel_sub, + signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + device=0, + constants={3: 32}, + configs=[config], + warm_cache_only=True, + cc=cc, + ) + +def test_compile_in_subproc() -> None: major, minor = torch.cuda.get_device_capability(0) cc = major * 10 + minor - config = namedtuple("instance_descriptor", [ - "divisible_by_16", "equal_to_1"])( - tuple(range(4)), - ()) + config = instance_descriptor(tuple(range(4)), ()) + multiprocessing.set_start_method('spawn') proc = multiprocessing.Process( - target=triton.compile, - kwargs=dict( - fn=kernel_sub, - signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, - device=0, - constants={3: 32}, - configs=[config], - warm_cache_only=True, - cc=cc, - )) + target=compile_fn, + args=(config, cc)) proc.start() proc.join() assert proc.exitcode == 0