From 363c82de1becc38d7e29d3e1d79923f51f0b664c Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Thu, 23 Oct 2025 00:47:13 +0100 Subject: [PATCH 1/4] feat: add math nextafter --- docs/source/user/cudapysupported.rst | 1 + numba_cuda/numba/cuda/cudamath.py | 1 + numba_cuda/numba/cuda/mathimpl.py | 2 ++ .../numba/cuda/tests/cudapy/test_fastmath.py | 11 ++++++++++- numba_cuda/numba/cuda/tests/cudapy/test_math.py | 14 ++++++++++++++ 5 files changed, 28 insertions(+), 1 deletion(-) diff --git a/docs/source/user/cudapysupported.rst b/docs/source/user/cudapysupported.rst index a62205cf6..ee0fa6743 100644 --- a/docs/source/user/cudapysupported.rst +++ b/docs/source/user/cudapysupported.rst @@ -224,6 +224,7 @@ The following functions from the :mod:`math` module are supported: * :func:`math.log2` * :func:`math.log10` * :func:`math.log1p` +* :func:`math.nextafter` * :func:`math.sqrt` * :func:`math.remainder` * :func:`math.pow` diff --git a/numba_cuda/numba/cuda/cudamath.py b/numba_cuda/numba/cuda/cudamath.py index 1cc5a5ac8..0733459bc 100644 --- a/numba_cuda/numba/cuda/cudamath.py +++ b/numba_cuda/numba/cuda/cudamath.py @@ -82,6 +82,7 @@ class Math_hypot(ConcreteTemplate): @infer_global(math.copysign) @infer_global(math.fmod) +@infer_global(math.nextafter) class Math_binary(ConcreteTemplate): cases = [ signature(types.float32, types.float32, types.float32), diff --git a/numba_cuda/numba/cuda/mathimpl.py b/numba_cuda/numba/cuda/mathimpl.py index 5f3e634b8..05c1a5503 100644 --- a/numba_cuda/numba/cuda/mathimpl.py +++ b/numba_cuda/numba/cuda/mathimpl.py @@ -64,9 +64,11 @@ binarys += [("fmod", "fmodf", math.fmod)] binarys += [("hypot", "hypotf", math.hypot)] binarys += [("remainder", "remainderf", math.remainder)] +binarys += [("nextafter", "nextafterf", math.nextafter)] binarys_fastmath = {} binarys_fastmath["powf"] = "fast_powf" +binarys_fastmath["nextafterf"] = "fast_nextafterf" @lower(math.isinf, types.Integer) diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py 
b/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py index c1c58dffa..7d8411079 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from numba import cuda, float32 from numba.cuda.compiler import compile_ptx_for_current_device, compile_ptx -from math import cos, sin, tan, exp, log, log10, log2, pow, tanh +from math import cos, sin, tan, exp, log, log10, log2, pow, tanh, nextafter from operator import truediv import numpy as np from numba.cuda.testing import CUDATestCase, skip_on_cudasim, skip_unless_cc_75 @@ -179,6 +179,15 @@ def test_powf(self): ), ) + def test_nextafterf(self): + self._test_fast_math_binary( + nextafter, + FastMathCriterion( + fast_expected=["lg2.approx.ftz.f32 "], # FIX + prec_unexpected=["lg2.approx.ftz.f32 "], # FIX + ), + ) + def test_divf(self): self._test_fast_math_binary( truediv, diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_math.py b/numba_cuda/numba/cuda/tests/cudapy/test_math.py index 57f0e3d97..82f4a7aed 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_math.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_math.py @@ -138,6 +138,11 @@ def math_remainder(A, B, C): C[i] = math.remainder(A[i], B[i]) +def math_nextafter(A, B, C): + i = cuda.grid(1) + C[i] = math.nextafter(A[i], B[i]) + + def math_sqrt(A, B): i = cuda.grid(1) B[i] = math.sqrt(A[i]) @@ -594,6 +599,15 @@ def test_0_0(r, x, y): test_0_0[1, 1](r, 0, 0) self.assertTrue(np.isnan(r[0])) + # --------------------------------------------------------------------------- + # test_math_nextafter + + def test_math_nextafter(self): + self.binary_template_float32(math_nextafter, np.nextafter, start=1e-11) + self.binary_template_float64(math_nextafter, np.nextafter, start=1e-11) + self.binary_template_int64(math_nextafter, np.nextafter, start=1) + self.binary_template_uint64(math_nextafter, np.nextafter, start=1) + # 
--------------------------------------------------------------------------- # test_math_sqrt From 8122537349f4ee8d4e702c441e5e25faed471860 Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Sun, 26 Oct 2025 19:38:52 +0000 Subject: [PATCH 2/4] fix: test --- numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py b/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py index 7d8411079..5888ebe49 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py @@ -183,8 +183,8 @@ def test_nextafterf(self): self._test_fast_math_binary( nextafter, FastMathCriterion( - fast_expected=["lg2.approx.ftz.f32 "], # FIX - prec_unexpected=["lg2.approx.ftz.f32 "], # FIX + fast_expected=[".ftz.f32 "], + prec_unexpected=[".ftz.f32 "], ), ) From deb38eaf2ecc3beee506872a51503cad19891f9b Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Sun, 26 Oct 2025 19:42:00 +0000 Subject: [PATCH 3/4] chore: remove fast_nextafter --- numba_cuda/numba/cuda/mathimpl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/numba_cuda/numba/cuda/mathimpl.py b/numba_cuda/numba/cuda/mathimpl.py index 05c1a5503..c89a06e44 100644 --- a/numba_cuda/numba/cuda/mathimpl.py +++ b/numba_cuda/numba/cuda/mathimpl.py @@ -68,7 +68,6 @@ binarys_fastmath = {} binarys_fastmath["powf"] = "fast_powf" -binarys_fastmath["nextafterf"] = "fast_nextafterf" @lower(math.isinf, types.Integer) From f540109ca1bd441105306c648c4f417f2637d757 Mon Sep 17 00:00:00 2001 From: Kaeun Kim Date: Mon, 3 Nov 2025 23:19:02 +0000 Subject: [PATCH 4/4] chore: add explanation regarding steps kwarg --- docs/source/user/cudapysupported.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user/cudapysupported.rst b/docs/source/user/cudapysupported.rst index ee0fa6743..4c590f554 100644 --- a/docs/source/user/cudapysupported.rst +++ 
b/docs/source/user/cudapysupported.rst @@ -224,7 +224,7 @@ The following functions from the :mod:`math` module are supported: * :func:`math.log2` * :func:`math.log10` * :func:`math.log1p` -* :func:`math.nextafter` +* :func:`math.nextafter` (Excluding the ``steps`` keyword argument) * :func:`math.sqrt` * :func:`math.remainder` * :func:`math.pow`