diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py
index cf4de3ce063d..2bca3f28379d 100644
--- a/python/test/unit/runtime/test_cache.py
+++ b/python/test/unit/runtime/test_cache.py
@@ -1,7 +1,5 @@
-import multiprocessing
 import os
 import shutil
-from collections import namedtuple
 
 import pytest
 import torch
@@ -198,39 +196,6 @@ def kernel_add_device(a, b, o, N: tl.constexpr):
     assert inline_ttir != noinline_ttir
 
 
-instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])
-
-
-def compile_fn(config, cc):
-    @triton.jit
-    def kernel_sub(a, b, o, N: tl.constexpr):
-        idx = tl.arange(0, N)
-        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)
-    triton.compile(
-        fn=kernel_sub,
-        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
-        device=0,
-        constants={3: 32},
-        configs=[config],
-        warm_cache_only=True,
-        cc=cc,
-    )
-
-
-def test_compile_in_subproc() -> None:
-    major, minor = torch.cuda.get_device_capability(0)
-    cc = major * 10 + minor
-    config = instance_descriptor(tuple(range(4)), ())
-
-    multiprocessing.set_start_method('spawn')
-    proc = multiprocessing.Process(
-        target=compile_fn,
-        args=(config, cc))
-    proc.start()
-    proc.join()
-    assert proc.exitcode == 0
-
-
 def test_memory_leak() -> None:
     @triton.jit
     def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py
new file mode 100644
index 000000000000..0e0d33c6fd21
--- /dev/null
+++ b/python/test/unit/runtime/test_subproc.py
@@ -0,0 +1,83 @@
+import multiprocessing
+import os
+import shutil
+from collections import namedtuple
+
+import torch
+
+import triton
+import triton.language as tl
+
+tmpdir = ".tmp"
+
+
+def reset_tmp_dir():
+    os.environ["TRITON_CACHE_DIR"] = tmpdir
+    if os.path.exists(tmpdir):
+        shutil.rmtree(tmpdir)
+
+
+instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])
+
+
+def compile_fn(config, cc):
+    @triton.jit
+    def kernel_sub(a, b, o, N: tl.constexpr):
+        idx = tl.arange(0, N)
+        tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)
+    triton.compile(
+        fn=kernel_sub,
+        signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
+        device=0,
+        constants={3: 32},
+        configs=[config],
+        warm_cache_only=True,
+        cc=cc,
+    )
+
+
+def test_compile_in_subproc() -> None:
+    major, minor = torch.cuda.get_device_capability(0)
+    cc = major * 10 + minor
+    config = instance_descriptor(tuple(range(4)), ())
+
+    multiprocessing.set_start_method('fork')
+    proc = multiprocessing.Process(
+        target=compile_fn,
+        args=(config, cc))
+    proc.start()
+    proc.join()
+    assert proc.exitcode == 0
+
+
+def compile_fn_dot(config, cc):
+    @triton.jit
+    def kernel_dot(Z):
+        offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
+        z = tl.load(Z + offs)
+        z = tl.dot(z, z)
+        tl.store(Z + offs, z)
+
+    triton.compile(
+        fn=kernel_dot,
+        signature={0: "*fp32"},
+        device=0,
+        configs=[config],
+        warm_cache_only=True,
+        cc=cc,
+    )
+
+
+def test_compile_in_forked_subproc() -> None:
+    reset_tmp_dir()
+    major, minor = torch.cuda.get_device_capability(0)
+    cc = major * 10 + minor
+    config = instance_descriptor(tuple(range(1)), ())
+
+    assert multiprocessing.get_start_method() == 'fork'
+    proc = multiprocessing.Process(
+        target=compile_fn_dot,
+        args=(config, cc))
+    proc.start()
+    proc.join()
+    assert proc.exitcode == 0
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index 678ee8cc780e..b4a55c0fa9a8 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -3,7 +3,6 @@
 from functools import wraps
 from typing import List, Optional, Sequence, Tuple, TypeVar
 
-import triton
 from . import core as tl
 from triton._C.libtriton.triton import ir
 
@@ -1181,18 +1180,6 @@ def dot(lhs: tl.tensor,
         allow_tf32: bool,
         out_dtype: tl.dtype,
         builder: ir.builder) -> tl.tensor:
-    try:
-        import torch
-    except ImportError:
-        raise ImportError("Triton requires PyTorch to be installed")
-    if torch.version.hip is None:
-        device = triton.runtime.jit.get_current_device()
-        capability = triton.runtime.jit.get_device_capability(device)
-        capability = capability[0] * 10 + capability[1]
-        if capability < 70:
-            assert (
-                not rhs.dtype.is_fp16() and not rhs.dtype.is_fp8()
-            ), "Float8 and Float16 types are not supported for compute capability < 70 (use Float32 or above)"
     assert lhs.type.is_block() and rhs.type.is_block()
     assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!"
     assert len(lhs.shape) == 2 and len(rhs.shape) == 2