Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions python/test/unit/language/test_libdevice.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,45 @@
import pytest
import torch

import triton
import triton.language as tl

from triton.language.extra import libdevice
from triton.language.extra.libdevice import fast_dividef as my_fast_dividef


@pytest.mark.parametrize("dtype_str", ["float32", "float64"])
@pytest.mark.parametrize(
"libdevice_fn, torch_special_fn",
[
("j0", "bessel_j0"),
("j1", "bessel_j1"),
("y0", "bessel_y0"),
("y1", "bessel_y1"),
("cyl_bessel_i0", "i0"),
("cyl_bessel_i1", "i1"),
],
)
def test_bessel(dtype_str, libdevice_fn, torch_special_fn, device):
SIZE = 128
dtype = getattr(torch, dtype_str)

x = torch.randn((SIZE, ), dtype=dtype, device=device)
y_exp = torch.empty((SIZE, ), dtype=dtype, device=device)
y_ref = getattr(torch.special, torch_special_fn)(x)

@triton.jit
def kernel(in_p, out_p, fn: tl.constexpr, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
x = tl.load(in_p + off)
res = getattr(libdevice, fn)(x)
tl.store(out_p + off, res)

kernel[(1, )](x, y_exp, fn=libdevice_fn, SIZE=SIZE, num_warps=4, num_ctas=1)

torch.testing.assert_close(y_ref, y_exp, equal_nan=True)


def test_libdevice_rename(device):
# mark the import as used by this test
_ = my_fast_dividef
Expand Down
Loading