diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index cd589fa920f5..339dc25e617a 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -70,7 +70,7 @@ def test_nested1_change(): def write_and_load_module(code, num_extra_lines): - with tempfile.NamedTemporaryFile(mode='w+', suffix='.py') as f: + with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as f: f.write(('# extra line\n' * num_extra_lines) + code) f.flush() spec = importlib.util.spec_from_file_location("module.name", f.name) diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 63401f28e42b..d0ecd771384f 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -34,11 +34,15 @@ def kernel_sub(a, b, o, N: tl.constexpr): def test_compile_in_subproc() -> None: + import os major, minor = torch.cuda.get_device_capability(0) cc = major * 10 + minor config = triton.compiler.AttrsDescriptor(tuple(range(4)), (), (), ()) - multiprocessing.set_start_method('fork') + if os.name == "nt": + multiprocessing.set_start_method('spawn') + else: + multiprocessing.set_start_method('fork') proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) proc.start() proc.join() @@ -64,7 +68,7 @@ def test_compile_in_forked_subproc() -> None: capability = major * 10 + minor config = triton.compiler.AttrsDescriptor(tuple(range(1)), (), (), ()) - assert multiprocessing.get_start_method() == 'fork' + assert multiprocessing.get_start_method() in ['fork', 'spawn'] proc = multiprocessing.Process(target=compile_fn_dot, args=(config, capability)) proc.start() proc.join()