From a8f4a9dd4013f5ffabb99a40d9ee399ee46fc485 Mon Sep 17 00:00:00 2001
From: Jokeren
Date: Thu, 27 Apr 2023 19:45:23 -0400
Subject: [PATCH 1/2] Update

---
 python/test/unit/runtime/test_cache.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index 38d396cf0a1f..45414c1b46d4 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -170,6 +170,9 @@ def kernel_add(a, b, o, N: tl.constexpr):
     assert bins[0].asm['ttir'] != bins[1].asm['ttir']
 
 
+instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])
+
+
 def test_compile_in_subproc() -> None:
     @triton.jit
     def kernel_sub(a, b, o, N: tl.constexpr):
@@ -179,10 +182,7 @@ def kernel_sub(a, b, o, N: tl.constexpr):
 
     major, minor = torch.cuda.get_device_capability(0)
     cc = major * 10 + minor
-    config = namedtuple("instance_descriptor", [
-        "divisible_by_16", "equal_to_1"])(
-        tuple(range(4)),
-        ())
+    config = instance_descriptor(tuple(range(4)), ())
 
     proc = multiprocessing.Process(
         target=triton.compile,

From dd38a01e8e8f82967bde7be7bc2ae94b77015224 Mon Sep 17 00:00:00 2001
From: Jokeren
Date: Thu, 27 Apr 2023 22:28:17 -0400
Subject: [PATCH 2/2] Update

---
 python/test/unit/runtime/test_cache.py | 29 ++++++++++++++------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index 45414c1b46d4..02b9185eab39 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -173,28 +173,31 @@ def kernel_add(a, b, o, N: tl.constexpr):
 instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])
 
 
-def test_compile_in_subproc() -> None:
+def compile_fn(config, cc):
     @triton.jit
     def kernel_sub(a, b, o, N: tl.constexpr):
         idx = tl.arange(0, N)
-        tl.store(o + idx,
-                 tl.load(a + idx) - tl.load(b + idx) * 777)
+        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)
+    triton.compile(
+        fn=kernel_sub,
+        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
+        device=0,
+        constants={3: 32},
+        configs=[config],
+        warm_cache_only=True,
+        cc=cc,
+    )
 
+
+def test_compile_in_subproc() -> None:
     major, minor = torch.cuda.get_device_capability(0)
     cc = major * 10 + minor
     config = instance_descriptor(tuple(range(4)), ())
 
+    multiprocessing.set_start_method('spawn')
     proc = multiprocessing.Process(
-        target=triton.compile,
-        kwargs=dict(
-            fn=kernel_sub,
-            signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
-            device=0,
-            constants={3: 32},
-            configs=[config],
-            warm_cache_only=True,
-            cc=cc,
-        ))
+        target=compile_fn,
+        args=(config, cc))
     proc.start()
     proc.join()
     assert proc.exitcode == 0