diff --git a/numba_cuda/numba/cuda/np/arrayobj.py b/numba_cuda/numba/cuda/np/arrayobj.py index 66c69920c..758851265 100644 --- a/numba_cuda/numba/cuda/np/arrayobj.py +++ b/numba_cuda/numba/cuda/np/arrayobj.py @@ -5504,37 +5504,31 @@ def _array_copy(context, builder, sig, args): dest_data = ret.data assert rettype.layout in "CF" - if arytype.layout == rettype.layout: - # Fast path: memcpy - cgutils.raw_memcpy( - builder, dest_data, src_data, ary.nitems, ary.itemsize, align=1 - ) - else: - src_strides = cgutils.unpack_tuple(builder, ary.strides) - dest_strides = cgutils.unpack_tuple(builder, ret.strides) - intp_t = context.get_value_type(types.intp) + src_strides = cgutils.unpack_tuple(builder, ary.strides) + dest_strides = cgutils.unpack_tuple(builder, ret.strides) + intp_t = context.get_value_type(types.intp) - with cgutils.loop_nest(builder, shapes, intp_t) as indices: - src_ptr = cgutils.get_item_pointer2( - context, - builder, - src_data, - shapes, - src_strides, - arytype.layout, - indices, - ) - dest_ptr = cgutils.get_item_pointer2( - context, - builder, - dest_data, - shapes, - dest_strides, - rettype.layout, - indices, - ) - builder.store(builder.load(src_ptr), dest_ptr) + with cgutils.loop_nest(builder, shapes, intp_t) as indices: + src_ptr = cgutils.get_item_pointer2( + context, + builder, + src_data, + shapes, + src_strides, + arytype.layout, + indices, + ) + dest_ptr = cgutils.get_item_pointer2( + context, + builder, + dest_data, + shapes, + dest_strides, + rettype.layout, + indices, + ) + builder.store(builder.load(src_ptr), dest_ptr) return impl_ret_new_ref(context, builder, sig.return_type, ret._getvalue()) diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index 1f2ac98c9..3aa706089 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -275,6 +275,34 @@ def mangler(self, name, argtypes, *, abi_tags=(), uid=None): name, argtypes, abi_tags=abi_tags, uid=uid ) + def make_constant_list(self, 
builder, listty, lst): + import numpy as np + + constvals = [ + self.get_constant(listty.dtype, i) for i in iter(np.array(lst)) + ] + instance = self.build_list(builder, listty, constvals) + # create constant address space version of the list + lmod = builder.module + + constlistty = instance.type + constlist = ir.Constant(constlistty, instance) + addrspace = nvvm.ADDRSPACE_CONSTANT + gv = cgutils.add_global_variable( + lmod, constlist.type, "_cudapy_clist", addrspace=addrspace + ) + gv.linkage = "internal" + gv.global_constant = True + gv.initializer = constlist + + # Convert to generic address-space + ptrty = ir.PointerType(constlistty) + genptr = builder.addrspacecast(gv, ptrty, "generic") + lst = cgutils.create_struct_proxy(listty)( + self, builder, value=builder.load(genptr) + ) + return lst._getvalue() + def make_constant_array(self, builder, aryty, arr): """ Unlike the parent version. This returns a a pointer in the constant diff --git a/numba_cuda/numba/cuda/testing.py b/numba_cuda/numba/cuda/testing.py index 7906af27b..42b366c57 100644 --- a/numba_cuda/numba/cuda/testing.py +++ b/numba_cuda/numba/cuda/testing.py @@ -185,6 +185,17 @@ def assertFileCheckMatches( ) +class NRTEnablingCUDATestCase(CUDATestCase): + def setUp(self): + self.old_nrt_setting = config.CUDA_ENABLE_NRT + config.CUDA_ENABLE_NRT = True + super().setUp() + + def tearDown(self): + config.CUDA_ENABLE_NRT = self.old_nrt_setting + super().tearDown() + + def skip_on_cudasim(reason): """Skip this test if running on the CUDA simulator""" assert isinstance(reason, str) diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py b/numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py index f30d42de7..b0897c6e4 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py @@ -3,9 +3,8 @@ import numpy as np from numba import cuda -from numba.cuda.testing import CUDATestCase +from numba.cuda.testing import 
NRTEnablingCUDATestCase import unittest -from numba.cuda import config def reinterpret_array_type(byte_arr, start, stop, output): @@ -14,16 +13,7 @@ def reinterpret_array_type(byte_arr, start, stop, output): output[0] = val -class TestCudaArrayMethods(CUDATestCase): - def setUp(self): - self.old_nrt_setting = config.CUDA_ENABLE_NRT - config.CUDA_ENABLE_NRT = True - super(TestCudaArrayMethods, self).setUp() - - def tearDown(self): - config.CUDA_ENABLE_NRT = self.old_nrt_setting - super(TestCudaArrayMethods, self).tearDown() - +class TestCudaArrayMethods(NRTEnablingCUDATestCase): def test_reinterpret_array_type(self): """ Reinterpret byte array as int32 in the GPU. diff --git a/numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py b/numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py index dddcbfd4b..e3c352ad9 100644 --- a/numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py +++ b/numba_cuda/numba/cuda/tests/cudapy/test_array_reductions.py @@ -2,15 +2,23 @@ # SPDX-License-Identifier: BSD-2-Clause import numpy as np -from numba.cuda.tests.support import TestCase, MemoryLeakMixin +from numba.tests.support import MemoryLeakMixin +from numba.cuda.testing import NRTEnablingCUDATestCase from numba import cuda + +from itertools import combinations_with_replacement from numba.cuda.testing import skip_on_cudasim, skip_on_nvjitlink_13_1_sm_120 from numba.cuda.misc.special import literal_unroll from numba.cuda import config +import unittest + + +def array_median_global(arr): + return np.median(arr) @skip_on_cudasim("doesn't work in the simulator") -class TestArrayReductions(MemoryLeakMixin, TestCase): +class TestArrayReductions(MemoryLeakMixin, NRTEnablingCUDATestCase): """ Test array reduction methods and functions such as .sum(), .max(), etc. 
""" @@ -373,3 +381,267 @@ def kernel(out): out = cuda.to_device(np.zeros(len(arrays), dtype=np.float64)) kernel[1, 1](out) self.assertPreciseEqual(expected, out.copy_to_host()) + + def test_median_basic(self): + def variations(a): + # Sorted, reversed, random, many duplicates + yield a + a = a[::-1].copy() + yield a + np.random.shuffle(a) + yield a + a[a % 4 >= 1] = 3.5 + yield a + + self.check_median_basic(array_median_global, variations) + + def check_median_basic(self, pyfunc, array_variations): + # cfunc = jit(nopython=True)(pyfunc) + + def check(arr): + @cuda.jit + def kernel(out): + out[0] = np.median(arr) + + expected = pyfunc(arr) + out = cuda.to_device(np.zeros(1, dtype=np.float64)) + kernel[1, 1](out) + + got = out.copy_to_host()[0] + + self.assertPreciseEqual(expected, got) + + # Empty array case + check(np.array([])) + + # Odd sizes + def check_odd(a): + check(a) + a = a.reshape((9, 7)) + check(a) + check(a.T) + + for a in array_variations(np.arange(63) + 10.5): + check_odd(a) + + # Even sizes + def check_even(a): + check(a) + a = a.reshape((4, 16)) + check(a) + check(a.T) + + for a in array_variations(np.arange(64) + 10.5): + check_even(a) + + def check_percentile_and_quantile(self, pyfunc, q_upper_bound): + def check_array_q(a, q, abs_tol=1e-12): + @cuda.jit + def kernel(out): + result = pyfunc(a, q) + for i in range(len(out)): + out[i] = result[i] + + out = cuda.to_device(np.zeros(len(q), dtype=np.float64)) + kernel[1, 1](out) + + expected = pyfunc(a, q) + got = out.copy_to_host() + + finite = np.isfinite(expected) + if np.all(finite): + self.assertPreciseEqual(got, expected, abs_tol=abs_tol) + else: + self.assertPreciseEqual( + got[finite], expected[finite], abs_tol=abs_tol + ) + + def check_scalar_q(a, q, abs_tol=1e-12): + @cuda.jit + def kernel(out): + out[0] = pyfunc(a, q) + + out = cuda.to_device(np.zeros(1, dtype=np.float64)) + kernel[1, 1](out) + + expected = pyfunc(a, q) + got = out.copy_to_host()[0] + + if np.isfinite(expected): + 
self.assertPreciseEqual(got, expected, abs_tol=abs_tol) + + a = self.random.randn(27).reshape(3, 3, 3) + q = np.linspace(0, q_upper_bound, 14)[::-1].copy() + + check_array_q(a, q) + check_scalar_q(a, 0) + check_scalar_q(a, q_upper_bound / 2) + check_scalar_q(a, q_upper_bound) + + not_finite = [np.nan, -np.inf, np.inf] + a.flat[:10] = self.random.choice(not_finite, 10) + self.random.shuffle(a) + self.random.shuffle(q) + check_array_q(a, q) + + a = a.flatten().tolist() + q = q.flatten().tolist() + + # TODO - list types + # check_array_q(a, q) + # check(tuple(a), tuple(q)) + + a = self.random.choice([1, 2, 3, 4], 10) + q = np.linspace(0, q_upper_bound, 5) + check_array_q(a, q) + + def test_percentile_basic(self): + pyfunc = np.percentile + self.check_percentile_and_quantile(pyfunc, q_upper_bound=100) + self.check_percentile_and_quantile_edge_cases(pyfunc, q_upper_bound=100) + + @unittest.expectedFailure + def test_percentile_exceptions(self): + pyfunc = np.percentile + self.check_percentile_and_quantile_exceptions(pyfunc) + + def check_percentile_and_quantile_edge_cases( + self, pyfunc, q_upper_bound=100 + ): + # intended to be a faithful reproduction of the upstream numba test + # packing all the test cases into a single kernel for perf + def _array_combinations(elements): + for i in range(1, 10): + for comb in combinations_with_replacement(elements, i): + yield np.array(comb) + + q = (0, 0.1 * q_upper_bound, 0.2 * q_upper_bound, q_upper_bound) + element_pool = (1, -1, np.nan, np.inf, -np.inf) + test_cases = list(_array_combinations(element_pool)) + + max_len = max(len(a) for a in test_cases) + n_cases = len(test_cases) + + # create a block containing all the test cases + # will independently record and pass the lengths + a_batch = np.full((n_cases, max_len), np.nan, dtype=np.float64) + lengths = np.zeros(n_cases, dtype=np.int32) + for i, a in enumerate(test_cases): + a_batch[i, : len(a)] = a + lengths[i] = len(a) + + @cuda.jit + def kernel(a_batch, lengths, q_arr, 
out): + gid = cuda.grid(1) + if gid < a_batch.shape[0]: + length = lengths[gid] + a_valid = a_batch[gid, :length] + result = np.percentile(a_valid, q_arr) + for j in range(len(result)): + out[gid, j] = result[j] + + q_arr = np.array(q, dtype=np.float64) + out = cuda.to_device(np.zeros((n_cases, len(q)), dtype=np.float64)) + + kernel.forall(len(test_cases))( + cuda.to_device(a_batch), + cuda.to_device(lengths), + cuda.to_device(q_arr), + out, + ) + + got = out.copy_to_host() + for i, a in enumerate(test_cases): + expected = np.percentile(a, q) + finite = np.isfinite(expected) + + if np.all(finite): + self.assertPreciseEqual(got[i], expected, abs_tol=1e-14) + else: + self.assertPreciseEqual( + got[i][finite], expected[finite], abs_tol=1e-14 + ) + + def check_percentile_and_quantile_exceptions(self, pyfunc): + def check_scalar_q_err(a, q, abs_tol=1e-12): + @cuda.jit + def kernel(out): + out[0] = np.percentile(a, q) + + out = cuda.to_device(np.zeros(1, dtype=np.float64)) + with self.assertRaises(ValueError) as raises: + kernel[1, 1](out) + self.assertEqual( + "Percentiles must be in the range [0, 100]", + str(raises.exception), + ) + + # Exceptions leak references + self.disable_leak_check() + a = np.arange(5) + check_scalar_q_err(a, -5) # q less than 0 + check_scalar_q_err(a, 105) + check_scalar_q_err(a, np.nan) + + # complex typing failure + @cuda.jit + def kernel(out): + np.percentile(a, q) + + a = np.arange(5) * 1j + q = 0.1 + + out = cuda.to_device(np.zeros(1, dtype=np.float64)) + with self.assertTypingError(): + kernel[1, 1](out) + + @unittest.expectedFailure + def check_quantile_exceptions(self, pyfunc): + def check_scalar_q_err(a, q, abs_tol=1e-12): + @cuda.jit + def kernel(out): + out[0] = np.percentile(a, q) + + out = cuda.to_device(np.zeros(1, dtype=np.float64)) + with self.assertRaises(ValueError) as raises: + kernel[1, 1](out) + self.assertEqual( + "Quantiles must be in the range [0, 1]", + str(raises.exception), + ) + + # Exceptions leak references + 
self.disable_leak_check() + a = np.arange(5) + check_scalar_q_err(a, -0.5) # q less than 0 + check_scalar_q_err(a, 1.05) + check_scalar_q_err(a, np.nan) + + # complex typing failure + @cuda.jit + def kernel(out): + np.quantile(a, q) + + a = np.arange(5) * 1j + q = 0.1 + + out = cuda.to_device(np.zeros(1, dtype=np.float64)) + with self.assertTypingError(): + kernel[1, 1](out) + + def test_quantile_basic(self): + pyfunc = np.quantile + self.check_percentile_and_quantile(pyfunc, q_upper_bound=1) + self.check_percentile_and_quantile_edge_cases(pyfunc, q_upper_bound=1) + + def test_nanpercentile_basic(self): + pyfunc = np.nanpercentile + self.check_percentile_and_quantile(pyfunc, q_upper_bound=100) + self.check_percentile_and_quantile_edge_cases(pyfunc, q_upper_bound=100) + self.check_percentile_and_quantile_exceptions(pyfunc) + + def test_nanquantile_basic(self): + pyfunc = np.nanquantile + self.check_percentile_and_quantile(pyfunc, q_upper_bound=1) + self.check_percentile_and_quantile_edge_cases(pyfunc, q_upper_bound=1) + self.check_quantile_exceptions(pyfunc) diff --git a/numba_cuda/numba/cuda/tests/test_array_methods.py b/numba_cuda/numba/cuda/tests/test_array_methods.py new file mode 100644 index 000000000..407a44792 --- /dev/null +++ b/numba_cuda/numba/cuda/tests/test_array_methods.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause +import numpy as np + +from numba.tests.support import TestCase, MemoryLeakMixin +from numba import cuda +from numba.cuda import config + + +class TestArrayMethods(MemoryLeakMixin, TestCase): + """ + Test array methods such as ndarray.copy() inside CUDA kernels. 
+ """ + + def setUp(self): + super(TestArrayMethods, self).setUp() + np.random.seed(42) + self.old_nrt_setting = config.CUDA_ENABLE_NRT + config.CUDA_ENABLE_NRT = True + + def tearDown(self): + config.CUDA_ENABLE_NRT = self.old_nrt_setting + super(TestArrayMethods, self).tearDown() + + def test_array_copy(self): + ary = np.array([1.0, 2.0, 3.0]) + out = cuda.to_device(np.zeros(3)) + + @cuda.jit + def kernel(out): + gid = cuda.grid(1) + if gid < 1: + cpy = ary.copy() + for i in range(len(out)): + out[i] = cpy[i] + + kernel[1, 1](out) + + result = out.copy_to_host() + np.testing.assert_array_equal(result, ary) diff --git a/numba_cuda/numba/cuda/ufuncs.py b/numba_cuda/numba/cuda/ufuncs.py index dc86a455b..9c1728225 100644 --- a/numba_cuda/numba/cuda/ufuncs.py +++ b/numba_cuda/numba/cuda/ufuncs.py @@ -743,4 +743,49 @@ def np_real_atanh_impl(context, builder, sig, args): "DD->D": npyfuncs.np_complex_div_impl, } + db[np.isfinite] = { + "f->?": npyfuncs.np_real_isfinite_impl, + "d->?": npyfuncs.np_real_isfinite_impl, + "F->?": npyfuncs.np_complex_isfinite_impl, + "D->?": npyfuncs.np_complex_isfinite_impl, + # int8 + "b->?": npyfuncs.np_int_isfinite_impl, + "B->?": npyfuncs.np_int_isfinite_impl, + # int16 + "h->?": npyfuncs.np_int_isfinite_impl, + "H->?": npyfuncs.np_int_isfinite_impl, + # int32 + "i->?": npyfuncs.np_int_isfinite_impl, + "I->?": npyfuncs.np_int_isfinite_impl, + # int64 + "l->?": npyfuncs.np_int_isfinite_impl, + "L->?": npyfuncs.np_int_isfinite_impl, + # intp + "q->?": npyfuncs.np_int_isfinite_impl, + "Q->?": npyfuncs.np_int_isfinite_impl, + # boolean + "?->?": npyfuncs.np_int_isfinite_impl, + # datetime & timedelta + "M->?": npyfuncs.np_datetime_isfinite_impl, + "m->?": npyfuncs.np_datetime_isfinite_impl, + } + + db[np.multiply] = { + "??->?": numbers.int_and_impl, + "bb->b": numbers.int_mul_impl, + "BB->B": numbers.int_mul_impl, + "hh->h": numbers.int_mul_impl, + "HH->H": numbers.int_mul_impl, + "ii->i": numbers.int_mul_impl, + "II->I": 
numbers.int_mul_impl, + "ll->l": numbers.int_mul_impl, + "LL->L": numbers.int_mul_impl, + "qq->q": numbers.int_mul_impl, + "QQ->Q": numbers.int_mul_impl, + "ff->f": numbers.real_mul_impl, + "dd->d": numbers.real_mul_impl, + "FF->F": numbers.complex_mul_impl, + "DD->D": numbers.complex_mul_impl, + } + return db