diff --git a/docs/source/user/cudapysupported.rst b/docs/source/user/cudapysupported.rst index dca1539f3..cf70b7686 100644 --- a/docs/source/user/cudapysupported.rst +++ b/docs/source/user/cudapysupported.rst @@ -225,6 +225,7 @@ The following functions from the :mod:`math` module are supported: * :func:`math.log2` * :func:`math.log10` * :func:`math.log1p` +* :func:`math.nextafter` (Excluding the ``steps`` keyword argument) * :func:`math.sqrt` * :func:`math.remainder` * :func:`math.pow` diff --git a/numba_cuda/numba/cuda/cudamath.py b/numba_cuda/numba/cuda/cudamath.py index 6af2fb8b9..4fad82e11 100644 --- a/numba_cuda/numba/cuda/cudamath.py +++ b/numba_cuda/numba/cuda/cudamath.py @@ -87,6 +87,7 @@ class Math_hypot(ConcreteTemplate): @infer_global(math.copysign) @infer_global(math.fmod) +@infer_global(math.nextafter) class Math_binary(ConcreteTemplate): cases = [ signature(types.float32, types.float32, types.float32), diff --git a/numba_cuda/numba/cuda/mathimpl.py b/numba_cuda/numba/cuda/mathimpl.py index c20e040fe..0b5803e98 100644 --- a/numba_cuda/numba/cuda/mathimpl.py +++ b/numba_cuda/numba/cuda/mathimpl.py @@ -67,6 +67,7 @@ binarys += [("fmod", "fmodf", math.fmod)] binarys += [("hypot", "hypotf", math.hypot)] binarys += [("remainder", "remainderf", math.remainder)] +binarys += [("nextafter", "nextafterf", math.nextafter)] binarys_fastmath = {} binarys_fastmath["powf"] = "fast_powf" diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py b/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py index e73ee690b..9b8aa9d6d 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py @@ -7,7 +7,7 @@ from numba import cuda from numba.cuda import float32 from numba.cuda.compiler import compile_ptx_for_current_device, compile_ptx -from math import cos, sin, tan, exp, log, log10, log2, pow, tanh +from math import cos, sin, tan, exp, log, log10, log2, pow, tanh, nextafter from operator import truediv 
import numpy as np from numba.cuda.testing import CUDATestCase, skip_on_cudasim, skip_unless_cc_75 @@ -194,6 +194,15 @@ def test_powf(self): ), ) + def test_nextafterf(self): + self._test_fast_math_binary( + nextafter, + FastMathCriterion( + fast_expected=[".ftz.f32 "], + prec_unexpected=[".ftz.f32 "], + ), + ) + def test_divf(self): self._test_fast_math_binary( truediv, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_math.py b/numba_cuda/numba/cuda/tests/cudapy/test_math.py index c2771d6b3..ac7afea01 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_math.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_math.py @@ -146,6 +146,11 @@ def math_remainder(A, B, C): C[i] = math.remainder(A[i], B[i]) +def math_nextafter(A, B, C): + i = cuda.grid(1) + C[i] = math.nextafter(A[i], B[i]) + + def math_sqrt(A, B): i = cuda.grid(1) B[i] = math.sqrt(A[i]) @@ -614,6 +619,15 @@ def test_0_0(r, x, y): test_0_0[1, 1](r, 0, 0) self.assertTrue(np.isnan(r[0])) + # --------------------------------------------------------------------------- + # test_math_nextafter + + def test_math_nextafter(self): + self.binary_template_float32(math_nextafter, np.nextafter, start=1e-11) + self.binary_template_float64(math_nextafter, np.nextafter, start=1e-11) + self.binary_template_int64(math_nextafter, np.nextafter, start=1) + self.binary_template_uint64(math_nextafter, np.nextafter, start=1) + # --------------------------------------------------------------------------- # test_math_sqrt