35 changes: 0 additions & 35 deletions python/test/unit/runtime/test_cache.py
@@ -1,7 +1,5 @@
import multiprocessing
import os
import shutil
from collections import namedtuple

import pytest
import torch
@@ -198,39 +196,6 @@ def kernel_add_device(a, b, o, N: tl.constexpr):
assert inline_ttir != noinline_ttir


instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])


def compile_fn(config, cc):
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)
triton.compile(
fn=kernel_sub,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
device=0,
constants={3: 32},
configs=[config],
warm_cache_only=True,
cc=cc,
)


def test_compile_in_subproc() -> None:
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(4)), ())

multiprocessing.set_start_method('spawn')
proc = multiprocessing.Process(
target=compile_fn,
args=(config, cc))
proc.start()
proc.join()
assert proc.exitcode == 0


def test_memory_leak() -> None:
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
83 changes: 83 additions & 0 deletions python/test/unit/runtime/test_subproc.py
@@ -0,0 +1,83 @@
import multiprocessing
import os
import shutil
from collections import namedtuple

import torch

import triton
import triton.language as tl

tmpdir = ".tmp"


def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)


instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])


def compile_fn(config, cc):
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)
triton.compile(
fn=kernel_sub,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
device=0,
constants={3: 32},
configs=[config],
warm_cache_only=True,
cc=cc,
)


def test_compile_in_subproc() -> None:
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(4)), ())

multiprocessing.set_start_method('fork')
proc = multiprocessing.Process(
target=compile_fn,
args=(config, cc))
proc.start()
proc.join()
assert proc.exitcode == 0


def compile_fn_dot(config, cc):
@triton.jit
def kernel_dot(Z):
offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
z = tl.load(Z + offs)
z = tl.dot(z, z)
tl.store(Z + offs, z)

triton.compile(
fn=kernel_dot,
signature={0: "*fp32"},
device=0,
configs=[config],
warm_cache_only=True,
cc=cc,
)


def test_compile_in_forked_subproc() -> None:
reset_tmp_dir()
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(1)), ())

assert multiprocessing.get_start_method() == 'fork'
proc = multiprocessing.Process(
target=compile_fn_dot,
args=(config, cc))
proc.start()
proc.join()
assert proc.exitcode == 0
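
Aside (not part of the diff above): a minimal standalone sketch of the pattern the new tests in test_subproc.py exercise — running work in a fork-started child process and asserting that it exits cleanly, so any exception raised in the child (e.g. during triton.compile) surfaces as a nonzero exit code. The child_task name is a hypothetical placeholder; in the tests above the child invokes triton.compile with the signatures shown in the diff.

```python
import multiprocessing


def child_task():
    # Stand-in for the work done in the child process (the tests above call
    # triton.compile(...) here). Any unhandled exception causes the child to
    # terminate with a nonzero exit code.
    pass


if __name__ == "__main__":
    # 'fork' is only available on POSIX systems; 'spawn' is the portable
    # alternative used by the test that was removed from test_cache.py.
    ctx = multiprocessing.get_context("fork")
    proc = ctx.Process(target=child_task)
    proc.start()
    proc.join()
    assert proc.exitcode == 0
```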
13 changes: 0 additions & 13 deletions python/triton/language/semantic.py
@@ -3,7 +3,6 @@
from functools import wraps
from typing import List, Optional, Sequence, Tuple, TypeVar

import triton
from . import core as tl
from triton._C.libtriton.triton import ir

@@ -1181,18 +1180,6 @@ def dot(lhs: tl.tensor,
allow_tf32: bool,
out_dtype: tl.dtype,
builder: ir.builder) -> tl.tensor:
try:
import torch
except ImportError:
raise ImportError("Triton requires PyTorch to be installed")
if torch.version.hip is None:
device = triton.runtime.jit.get_current_device()
capability = triton.runtime.jit.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
if capability < 70:
assert (
not rhs.dtype.is_fp16() and not rhs.dtype.is_fp8()
), "Float8 and Float16 types are not supported for compute capability < 70 (use Float32 or above)"
assert lhs.type.is_block() and rhs.type.is_block()
assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!"
assert len(lhs.shape) == 2 and len(rhs.shape) == 2