[Pytest] Support CPU device
Enable test suites for the CPU device.
- language
  - python/test/unit/language/test_random.py
  - python/test/unit/language/test_standard.py
- runtime
  - python/test/unit/runtime/test_bindings.py
  - python/test/unit/runtime/test_cache.py
  - python/test/unit/runtime/test_driver.py
  - python/test/unit/runtime/test_jit.py
  - python/test/unit/runtime/test_launch.py

Signed-off-by: Dmitrii Makarenko <[email protected]>
Devjiu committed Aug 12, 2024
1 parent 2c64492 commit b730fee
Showing 6 changed files with 70 additions and 37 deletions.
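
The tests in this commit take their target from a pytest `device` fixture instead of hard-coding "cuda". The fixture itself is defined outside this diff; what follows is only a minimal sketch of how such a fixture could be wired up in conftest.py, where the option name and default are assumptions, not part of the commit.

# conftest.py (hypothetical sketch, not part of this commit)
import pytest

def pytest_addoption(parser):
    # Assumed CLI option so one suite can target either backend.
    parser.addoption("--device", action="store", default="cuda")

@pytest.fixture
def device(request):
    return request.config.getoption("--device")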
8 changes: 7 additions & 1 deletion .github/workflows/build-test.yml
@@ -86,7 +86,13 @@ jobs:
python/test/unit/language/test_compile_errors.py \
python/test/unit/language/test_decorator.py \
python/test/unit/language/test_pipeliner.py \
python/test/unit/test_random.py \
python/test/unit/language/test_random.py \
python/test/unit/language/test_standard.py \
python/test/unit/runtime/test_bindings.py \
python/test/unit/runtime/test_cache.py \
python/test/unit/runtime/test_driver.py \
python/test/unit/runtime/test_jit.py \
python/test/unit/runtime/test_launch.py \
python/test/unit/cpu/test_libdevice.py \
python/test/unit/cpu/test_libmvec.py \
python/test/unit/cpu/test_opt.py
10 changes: 7 additions & 3 deletions python/test/unit/language/test_standard.py
@@ -3,7 +3,7 @@
import torch
import triton.language as tl

from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random
from test_core import _test_binary, int_dtypes, uint_dtypes, float_dtypes, numpy_random, is_cpu

# ---------------
# test maximum/minimum ops
@@ -17,6 +17,8 @@
def test_maximum_minium(dtype, op, device):
expr = f'tl.{op}(x, y)'
numpy_expr = f'np.{op}(x, y)'
if is_cpu() and dtype == "bfloat16":
    pytest.skip('bfloat16 is not natively supported and raises undefined symbol: __truncsfbf2')
_test_binary(dtype, dtype, expr, numpy_expr, device=device)


@@ -26,7 +28,8 @@ def test_maximum_minium(dtype, op, device):


@pytest.mark.interpreter
@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
@pytest.mark.parametrize(
"M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]])
@pytest.mark.parametrize("descending", [False, True])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
def test_sort(M, N, descending, dtype_str, device):
@@ -54,7 +57,8 @@ def sort_kernel(X, Z, N: tl.constexpr, M: tl.constexpr, descending: tl.constexpr


@pytest.mark.interpreter
@pytest.mark.parametrize("M, N", [[1, 512], [8, 64], [256, 16], [512, 8]])
@pytest.mark.parametrize(
"M, N", [[1, 512], [8, 64], [256, 16], [512, 8]] if not is_cpu() else [[1, 128], [8, 64], [64, 16], [128, 8]])
@pytest.mark.parametrize("dtype_str", ['int32', 'float16', 'float32', 'bfloat16'])
def test_flip(M, N, dtype_str, device):

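The `is_cpu()` helper is imported from test_core but not defined in this diff. A hedged sketch of what it plausibly checks, reusing the driver API that the jit.py hunk below relies on:

# Hypothetical sketch of the is_cpu() helper imported from test_core.
from triton.runtime import driver

def is_cpu():
    # The active driver reports its compilation target; target.backend
    # is the same field jit.py uses to build the cache key below.
    target = driver.active.get_current_target()
    return target is not None and target.backend == "cpu"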
8 changes: 4 additions & 4 deletions python/test/unit/runtime/test_bindings.py
@@ -28,7 +28,7 @@ def add_kernel(
tl.store(out_ptr + offsets, output, mask=mask)


def test_module_walk():
def test_module_walk(device):
"""
Test the MLIR bindings exposed for the out-of-tree walk.
"""
@@ -53,10 +53,10 @@ def walk_fn(op):

kernel = add_kernel
args = [
torch.empty((32, 32), device="cuda"), # in_ptr0
torch.empty((32, 32), device="cuda"), # in_ptr1
torch.empty((32, 32), device=device), # in_ptr0
torch.empty((32, 32), device=device), # in_ptr1
1024, # n_elements
torch.empty((32, 32), device="cuda"), # out_ptr
torch.empty((32, 32), device=device), # out_ptr
16, # BLOCK_SIZE
]
src = triton.compiler.compiler.ASTSource(
69 changes: 44 additions & 25 deletions python/test/unit/runtime/test_cache.py
@@ -157,7 +157,7 @@ def reset_tmp_dir():
shutil.rmtree(tmpdir, ignore_errors=True)


def test_reuse():
def test_reuse(device):
counter = 0

def inc_counter(*args, **kwargs):
@@ -166,14 +166,14 @@ def inc_counter(*args, **kwargs):

JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
x = torch.empty(1, dtype=torch.int32, device=device)
for i in range(10):
kernel[(1, )](x, 1, BLOCK=1024)
assert counter == 1


@pytest.mark.parametrize('mode', ['enable', 'disable'])
def test_specialize(mode):
def test_specialize(mode, device):
counter = 0

def inc_counter(*args, **kwargs):
@@ -182,23 +182,26 @@ def inc_counter(*args, **kwargs):

JITFunction.cache_hook = inc_counter
reset_tmp_dir()
x = torch.empty(1, dtype=torch.int32, device='cuda')
x = torch.empty(1, dtype=torch.int32, device=device)
function = {'enable': kernel, 'disable': kernel_nospec}[mode]
target = {'enable': 3, 'disable': 1}[mode]
for i in [1, 2, 4, 8, 16, 32]:
function[(1, )](x, i, BLOCK=512)
assert counter == target


def test_annotation():
def test_annotation(device):

@triton.jit
def kernel(X, i: tl.int32):
tl.store(X, i)

x = torch.empty(1, dtype=torch.int32, device='cuda')
x = torch.empty(1, dtype=torch.int32, device=device)

device = torch.cuda.current_device()
if device == "cuda":
device = torch.cuda.current_device()
elif device == "cpu":
device = torch.cpu.current_device() + ":0"
kernel[(1, )](x, 1)
kernel[(1, )](x, 8)
kernel[(1, )](x, 16)
@@ -209,14 +212,14 @@ def kernel(X, i: tl.int32):
GLOBAL_DEFAULT_ARG = 1


def test_kernel_default_arg():
def test_kernel_default_arg(device):
global GLOBAL_DEFAULT_ARG

@triton.jit
def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
tl.store(X, i)

x = torch.empty(1, dtype=torch.int32, device='cuda')
x = torch.empty(1, dtype=torch.int32, device=device)
kernel[(1, )](x)
assert x == torch.ones_like(x)

@@ -226,21 +229,24 @@ def kernel(X, i: tl.constexpr = GLOBAL_DEFAULT_ARG):
kernel[(1, )](x)
assert x == torch.ones_like(x)

device = torch.cuda.current_device()
if device == "cuda":
device = torch.cuda.current_device()
elif device == "cpu":
device = torch.cpu.current_device() + ":0"
assert len(kernel.cache[device]) == 1


GLOBAL_VAR: tl.constexpr = 1


def test_kernel_global_var_change():
def test_kernel_global_var_change(device):
global GLOBAL_VAR

@triton.jit
def kernel(X):
tl.store(X, GLOBAL_VAR)

x = torch.empty(1, dtype=torch.int32, device='cuda')
x = torch.empty(1, dtype=torch.int32, device=device)
kernel[(1, )](x)
assert x == torch.ones_like(x)

@@ -385,13 +391,13 @@ def kernel():
assert not kernel.used_global_vals


def test_constexpr_not_callable() -> None:
def test_constexpr_not_callable(device) -> None:

@triton.jit
def kernel(X, c: tl.constexpr):
tl.store(X, 2)

x = torch.empty(1, dtype=torch.int32, device='cuda')
x = torch.empty(1, dtype=torch.int32, device=device)
error = False
try:
kernel[(1, )](x, c="str")
@@ -406,20 +412,23 @@ def kernel(X, c: tl.constexpr):
assert error is True


def test_jit_warmup_cache() -> None:
def test_jit_warmup_cache(device) -> None:

@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

args = [
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device="cuda"),
torch.randn(32, dtype=torch.float32, device=device),
torch.randn(32, dtype=torch.float32, device=device),
torch.randn(32, dtype=torch.float32, device=device),
32,
]
device = torch.cuda.current_device()
if device == "cuda":
device = torch.cuda.current_device()
elif device == "cpu":
device = torch.cpu.current_device() + ":0"
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
@@ -429,15 +438,19 @@ def kernel_add(a, b, o, N: tl.constexpr):
assert len(kernel_add.cache[device]) == 1


def test_jit_debug() -> None:
def test_jit_debug(device) -> None:

@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))

device = torch.cuda.current_device()
if device == "cuda":
device = torch.cuda.current_device()
elif device == "cpu":
device = torch.cpu.current_device() + ":0"
pytest.skip('Debug is not yet supported on CPU')
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add.cache[device]) == 1
@@ -457,13 +470,16 @@ def add_fn(a, b, o, N: tl.constexpr):
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))


def test_jit_noinline() -> None:
def test_jit_noinline(device) -> None:

@triton.jit
def kernel_add_device(a, b, o, N: tl.constexpr):
add_fn(a, b, o, N)

device = torch.cuda.current_device()
if device == "cuda":
device = torch.cuda.current_device()
elif device == "cpu":
device = torch.cpu.current_device() + ":0"
assert len(kernel_add_device.cache[device]) == 0
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, ))
assert len(kernel_add_device.cache[device]) == 1
@@ -493,7 +509,7 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp0, xmask)


def test_preload() -> None:
def test_preload(device) -> None:

@triton.jit
def kernel_add(a, b, o, N: tl.constexpr, type: tl.constexpr):
@@ -507,7 +523,10 @@ def kernel_sub(a, b, o, N: tl.constexpr, type: tl.constexpr):
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx))

device = torch.cuda.current_device()
if device == "cuda":
device = torch.cuda.current_device()
elif device == "cpu":
device = torch.cpu.current_device() + ":0"

# get the serialized specialization data
specialization_data = None
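Every migrated test above repeats the same mapping from the `device` fixture to Triton's per-device cache key: the integer device index on CUDA, and the string "cpu:0" on CPU (torch.cpu.current_device() returns "cpu"). Collected into one illustrative helper, not part of the commit:

import torch

def cache_key_for(device):
    # Same logic the tests inline after each kernel definition.
    if device == "cuda":
        return torch.cuda.current_device()
    if device == "cpu":
        return torch.cpu.current_device() + ":0"
    raise ValueError(f"unsupported device: {device}")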
8 changes: 5 additions & 3 deletions python/test/unit/runtime/test_launch.py
@@ -43,7 +43,9 @@ def kernel(x):
assert used_hook


def test_memory_leak() -> None:
def test_memory_leak(device) -> None:
if device is None:
    device = 'cuda'

@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):
@@ -57,8 +59,8 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr):

tracemalloc.start()
try:
inp = torch.randn(10, device='cuda')
out = torch.randn(10, device='cuda')
inp = torch.randn(10, device=device)
out = torch.randn(10, device=device)
kernel[(10, )](inp, out, 10, XBLOCK=16)
gc.collect()
begin, _ = tracemalloc.get_traced_memory()
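The leak check traces host-side allocations only, so it works unchanged on both backends. The same pattern, generalized into an illustrative helper (the tolerance is an assumption):

import gc
import tracemalloc

def assert_no_host_leak(launch, tolerance_bytes=4096):
    # Launch twice and require that traced host memory stops growing
    # after the first, cache-populating, run.
    tracemalloc.start()
    try:
        launch()
        gc.collect()
        begin, _ = tracemalloc.get_traced_memory()
        launch()
        gc.collect()
        end, _ = tracemalloc.get_traced_memory()
        assert end - begin < tolerance_bytes
    finally:
        tracemalloc.stop()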
4 changes: 3 additions & 1 deletion python/triton/runtime/jit.py
@@ -751,6 +751,8 @@ def preload(self, specialization_data):
import json
import triton.language as tl
device = driver.active.get_current_device()
target = driver.active.get_current_target()
device_key = f"{target.backend}:{device}"
deserialized_obj = json.loads(specialization_data)
if deserialized_obj['name'] != self.fn.__name__:
raise RuntimeError(
Expand All @@ -767,7 +769,7 @@ def preload(self, specialization_data):
}
key = deserialized_obj['key']
kernel = compile(src, None, options)
self.cache[device][key] = kernel
self.cache[device_key][key] = kernel
return kernel

# we do not parse `src` in the constructor because
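With this hunk, preload files the restored kernel under a backend-qualified key such as "cuda:0" or "cpu:0" instead of the bare device index, matching the keys the updated tests compute. A hedged usage sketch in the shape of test_preload; the hook's kwargs layout is an assumption, since the relevant lines are collapsed above:

# Hedged sketch: capture serialized specialization data via the JIT
# cache hook, then hand it back to preload to skip recompilation.
from triton.runtime.jit import JITFunction

specialization_data = None

def capture(*args, **kwargs):
    # Assumed hook payload; not shown in the collapsed diff above.
    global specialization_data
    specialization_data = kwargs["compile"]["specialization_data"]

JITFunction.cache_hook = capture
# ... launch the kernel once so the hook fires ...
JITFunction.cache_hook = None
# kernel.preload(specialization_data)  # stored under f"{target.backend}:{device}"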
