From 69b0a02553f7f69f538fdc650e394a899be66598 Mon Sep 17 00:00:00 2001 From: Atmn Patel Date: Tue, 5 Aug 2025 16:38:43 -0700 Subject: [PATCH 01/10] Vendor in cpython math impls --- numba_cuda/numba/cuda/cpython/cmathimpl.py | 556 +++++++ numba_cuda/numba/cuda/cpython/mathimpl.py | 497 ++++++ numba_cuda/numba/cuda/cpython/numbers.py | 1471 +++++++++++++++++ numba_cuda/numba/cuda/target.py | 4 +- .../numba/cuda/tests/nocuda/test_import.py | 3 + numba_cuda/numba/cuda/ufuncs.py | 2 +- 6 files changed, 2530 insertions(+), 3 deletions(-) create mode 100644 numba_cuda/numba/cuda/cpython/cmathimpl.py create mode 100644 numba_cuda/numba/cuda/cpython/mathimpl.py create mode 100644 numba_cuda/numba/cuda/cpython/numbers.py diff --git a/numba_cuda/numba/cuda/cpython/cmathimpl.py b/numba_cuda/numba/cuda/cpython/cmathimpl.py new file mode 100644 index 000000000..7af62b08e --- /dev/null +++ b/numba_cuda/numba/cuda/cpython/cmathimpl.py @@ -0,0 +1,556 @@ +""" +Implement the cmath module functions. +""" + +import cmath +import math + +from numba.core.imputils import Registry, impl_ret_untracked +from numba.core import types +from numba.core.typing import signature +from numba.cuda.cpython import mathimpl +from numba.core.extending import overload + +registry = Registry("cmathimpl") +lower = registry.lower + + +def is_nan(builder, z): + return builder.fcmp_unordered("uno", z.real, z.imag) + + +def is_inf(builder, z): + return builder.or_( + mathimpl.is_inf(builder, z.real), mathimpl.is_inf(builder, z.imag) + ) + + +def is_finite(builder, z): + return builder.and_( + mathimpl.is_finite(builder, z.real), mathimpl.is_finite(builder, z.imag) + ) + + +@lower(cmath.isnan, types.Complex) +def isnan_float_impl(context, builder, sig, args): + [typ] = sig.args + [value] = args + z = context.make_complex(builder, typ, value=value) + res = is_nan(builder, z) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(cmath.isinf, types.Complex) +def isinf_float_impl(context, builder, sig, args): + [typ] = sig.args + [value] = args + z = context.make_complex(builder, typ, value=value) + res = is_inf(builder, z) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(cmath.isfinite, types.Complex) +def isfinite_float_impl(context, builder, sig, args): + [typ] = sig.args + [value] = args + z = context.make_complex(builder, typ, value=value) + res = is_finite(builder, z) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@overload(cmath.rect) +def impl_cmath_rect(r, phi): + if all([isinstance(typ, types.Float) for typ in [r, phi]]): + + def impl(r, phi): + if not math.isfinite(phi): + if not r: + # cmath.rect(0, phi={inf, nan}) = 0 + return abs(r) + if math.isinf(r): + # cmath.rect(inf, phi={inf, nan}) = inf + j phi + return complex(r, phi) + real = math.cos(phi) + imag = math.sin(phi) + if real == 0.0 and math.isinf(r): + # 0 * inf would return NaN, we want to keep 0 but xor the sign + real /= r + else: + real *= r + if imag == 0.0 and math.isinf(r): + # ditto + imag /= r + else: + imag *= r + return complex(real, imag) + + return impl + + +def intrinsic_complex_unary(inner_func): + def wrapper(context, builder, sig, args): + [typ] = sig.args + [value] = args + z = context.make_complex(builder, typ, value=value) + x = z.real + y = z.imag + # Same as above: math.isfinite() is unavailable on 2.x so we precompute + # its value and pass it to the pure Python implementation. 
+ x_is_finite = mathimpl.is_finite(builder, x) + y_is_finite = mathimpl.is_finite(builder, y) + inner_sig = signature( + sig.return_type, *(typ.underlying_float,) * 2 + (types.boolean,) * 2 + ) + res = context.compile_internal( + builder, inner_func, inner_sig, (x, y, x_is_finite, y_is_finite) + ) + return impl_ret_untracked(context, builder, sig, res) + + return wrapper + + +NAN = float("nan") +INF = float("inf") + + +@lower(cmath.exp, types.Complex) +@intrinsic_complex_unary +def exp_impl(x, y, x_is_finite, y_is_finite): + """cmath.exp(x + y j)""" + if x_is_finite: + if y_is_finite: + c = math.cos(y) + s = math.sin(y) + r = math.exp(x) + return complex(r * c, r * s) + else: + return complex(NAN, NAN) + elif math.isnan(x): + if y: + return complex(x, x) # nan + j nan + else: + return complex(x, y) # nan + 0j + elif x > 0.0: + # x == +inf + if y_is_finite: + real = math.cos(y) + imag = math.sin(y) + # Avoid NaNs if math.cos(y) or math.sin(y) == 0 + # (e.g. cmath.exp(inf + 0j) == inf + 0j) + if real != 0: + real *= x + if imag != 0: + imag *= x + return complex(real, imag) + else: + return complex(x, NAN) + else: + # x == -inf + if y_is_finite: + r = math.exp(x) + c = math.cos(y) + s = math.sin(y) + return complex(r * c, r * s) + else: + r = 0 + return complex(r, r) + + +@lower(cmath.log, types.Complex) +@intrinsic_complex_unary +def log_impl(x, y, x_is_finite, y_is_finite): + """cmath.log(x + y j)""" + a = math.log(math.hypot(x, y)) + b = math.atan2(y, x) + return complex(a, b) + + +@lower(cmath.log, types.Complex, types.Complex) +def log_base_impl(context, builder, sig, args): + """cmath.log(z, base)""" + [z, base] = args + + def log_base(z, base): + return cmath.log(z) / cmath.log(base) + + res = context.compile_internal(builder, log_base, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +@overload(cmath.log10) +def impl_cmath_log10(z): + if not isinstance(z, types.Complex): + return + + LN_10 = 2.302585092994045684 + + def log10_impl(z): + """cmath.log10(z)""" + z = cmath.log(z) + # This formula gives better results on +/-inf than cmath.log(z, 10) + # See http://bugs.python.org/issue22544 + return complex(z.real / LN_10, z.imag / LN_10) + + return log10_impl + + +@overload(cmath.phase) +def phase_impl(x): + """cmath.phase(x + y j)""" + + if not isinstance(x, types.Complex): + return + + def impl(x): + return math.atan2(x.imag, x.real) + + return impl + + +@overload(cmath.polar) +def polar_impl(x): + if not isinstance(x, types.Complex): + return + + def impl(x): + r, i = x.real, x.imag + return math.hypot(r, i), math.atan2(i, r) + + return impl + + +@lower(cmath.sqrt, types.Complex) +def sqrt_impl(context, builder, sig, args): + # We risk spurious overflow for components >= FLT_MAX / (1 + sqrt(2)). + + SQRT2 = 1.414213562373095048801688724209698079e0 + ONE_PLUS_SQRT2 = 1.0 + SQRT2 + theargflt = sig.args[0].underlying_float + # Get a type specific maximum value so scaling for overflow is based on that + MAX = mathimpl.DBL_MAX if theargflt.bitwidth == 64 else mathimpl.FLT_MAX + # THRES will be double precision, should not impact typing as it's just + # used for comparison, there *may* be a few values near THRES which + # deviate from e.g. NumPy due to rounding that occurs in the computation + # of this value in the case of a 32bit argument. 
+ THRES = MAX / ONE_PLUS_SQRT2 + + def sqrt_impl(z): + """cmath.sqrt(z)""" + # This is NumPy's algorithm, see npy_csqrt() in npy_math_complex.c.src + a = z.real + b = z.imag + if a == 0.0 and b == 0.0: + return complex(abs(b), b) + if math.isinf(b): + return complex(abs(b), b) + if math.isnan(a): + return complex(a, a) + if math.isinf(a): + if a < 0.0: + return complex(abs(b - b), math.copysign(a, b)) + else: + return complex(a, math.copysign(b - b, b)) + + # The remaining special case (b is NaN) is handled just fine by + # the normal code path below. + + # Scale to avoid overflow + if abs(a) >= THRES or abs(b) >= THRES: + a *= 0.25 + b *= 0.25 + scale = True + else: + scale = False + # Algorithm 312, CACM vol 10, Oct 1967 + if a >= 0: + t = math.sqrt((a + math.hypot(a, b)) * 0.5) + real = t + imag = b / (2 * t) + else: + t = math.sqrt((-a + math.hypot(a, b)) * 0.5) + real = abs(b) / (2 * t) + imag = math.copysign(t, b) + # Rescale + if scale: + return complex(real * 2, imag) + else: + return complex(real, imag) + + res = context.compile_internal(builder, sqrt_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +@lower(cmath.cos, types.Complex) +def cos_impl(context, builder, sig, args): + def cos_impl(z): + """cmath.cos(z) = cmath.cosh(z j)""" + return cmath.cosh(complex(-z.imag, z.real)) + + res = context.compile_internal(builder, cos_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +@overload(cmath.cosh) +def impl_cmath_cosh(z): + if not isinstance(z, types.Complex): + return + + def cosh_impl(z): + """cmath.cosh(z)""" + x = z.real + y = z.imag + if math.isinf(x): + if math.isnan(y): + # x = +inf, y = NaN => cmath.cosh(x + y j) = inf + Nan * j + real = abs(x) + imag = y + elif y == 0.0: + # x = +inf, y = 0 => cmath.cosh(x + y j) = inf + 0j + real = abs(x) + imag = y + else: + real = math.copysign(x, math.cos(y)) + imag = math.copysign(x, math.sin(y)) + if x < 0.0: + # x = -inf => negate imaginary part of result + imag = -imag + return complex(real, imag) + return complex(math.cos(y) * math.cosh(x), math.sin(y) * math.sinh(x)) + + return cosh_impl + + +@lower(cmath.sin, types.Complex) +def sin_impl(context, builder, sig, args): + def sin_impl(z): + """cmath.sin(z) = -j * cmath.sinh(z j)""" + r = cmath.sinh(complex(-z.imag, z.real)) + return complex(r.imag, -r.real) + + res = context.compile_internal(builder, sin_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +@overload(cmath.sinh) +def impl_cmath_sinh(z): + if not isinstance(z, types.Complex): + return + + def sinh_impl(z): + """cmath.sinh(z)""" + x = z.real + y = z.imag + if math.isinf(x): + if math.isnan(y): + # x = +/-inf, y = NaN => cmath.sinh(x + y j) = x + NaN * j + real = x + imag = y + else: + real = math.cos(y) + imag = math.sin(y) + if real != 0.0: + real *= x + if imag != 0.0: + imag *= abs(x) + return complex(real, imag) + return complex(math.cos(y) * math.sinh(x), math.sin(y) * math.cosh(x)) + + return sinh_impl + + +@lower(cmath.tan, types.Complex) +def tan_impl(context, builder, sig, args): + def tan_impl(z): + """cmath.tan(z) = -j * cmath.tanh(z j)""" + r = cmath.tanh(complex(-z.imag, z.real)) + return complex(r.imag, -r.real) + + res = context.compile_internal(builder, tan_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +@overload(cmath.tanh) +def impl_cmath_tanh(z): + if not isinstance(z, types.Complex): + return + + def tanh_impl(z): + """cmath.tanh(z)""" + x = z.real + y = z.imag + if math.isinf(x): + real = 
math.copysign(1.0, x) + if math.isinf(y): + imag = 0.0 + else: + imag = math.copysign(0.0, math.sin(2.0 * y)) + return complex(real, imag) + # This is CPython's algorithm (see c_tanh() in cmathmodule.c). + # XXX how to force float constants into single precision? + tx = math.tanh(x) + ty = math.tan(y) + cx = 1.0 / math.cosh(x) + txty = tx * ty + denom = 1.0 + txty * txty + return complex(tx * (1.0 + ty * ty) / denom, ((ty / denom) * cx) * cx) + + return tanh_impl + + +@lower(cmath.acos, types.Complex) +def acos_impl(context, builder, sig, args): + LN_4 = math.log(4) + THRES = mathimpl.FLT_MAX / 4 + + def acos_impl(z): + """cmath.acos(z)""" + # CPython's algorithm (see c_acos() in cmathmodule.c) + if abs(z.real) > THRES or abs(z.imag) > THRES: + # Avoid unnecessary overflow for large arguments + # (also handles infinities gracefully) + real = math.atan2(abs(z.imag), z.real) + imag = math.copysign( + math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4, -z.imag + ) + return complex(real, imag) + else: + s1 = cmath.sqrt(complex(1.0 - z.real, -z.imag)) + s2 = cmath.sqrt(complex(1.0 + z.real, z.imag)) + real = 2.0 * math.atan2(s1.real, s2.real) + imag = math.asinh(s2.real * s1.imag - s2.imag * s1.real) + return complex(real, imag) + + res = context.compile_internal(builder, acos_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +@overload(cmath.acosh) +def impl_cmath_acosh(z): + if not isinstance(z, types.Complex): + return + + LN_4 = math.log(4) + THRES = mathimpl.FLT_MAX / 4 + + def acosh_impl(z): + """cmath.acosh(z)""" + # CPython's algorithm (see c_acosh() in cmathmodule.c) + if abs(z.real) > THRES or abs(z.imag) > THRES: + # Avoid unnecessary overflow for large arguments + # (also handles infinities gracefully) + real = math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4 + imag = math.atan2(z.imag, z.real) + return complex(real, imag) + else: + s1 = cmath.sqrt(complex(z.real - 1.0, z.imag)) + s2 = cmath.sqrt(complex(z.real + 1.0, z.imag)) + real = math.asinh(s1.real * s2.real + s1.imag * s2.imag) + imag = 2.0 * math.atan2(s1.imag, s2.real) + return complex(real, imag) + # Condensed formula (NumPy) + # return cmath.log(z + cmath.sqrt(z + 1.) 
* cmath.sqrt(z - 1.)) + + return acosh_impl + + +@lower(cmath.asinh, types.Complex) +def asinh_impl(context, builder, sig, args): + LN_4 = math.log(4) + THRES = mathimpl.FLT_MAX / 4 + + def asinh_impl(z): + """cmath.asinh(z)""" + # CPython's algorithm (see c_asinh() in cmathmodule.c) + if abs(z.real) > THRES or abs(z.imag) > THRES: + real = math.copysign( + math.log(math.hypot(z.real * 0.5, z.imag * 0.5)) + LN_4, z.real + ) + imag = math.atan2(z.imag, abs(z.real)) + return complex(real, imag) + else: + s1 = cmath.sqrt(complex(1.0 + z.imag, -z.real)) + s2 = cmath.sqrt(complex(1.0 - z.imag, z.real)) + real = math.asinh(s1.real * s2.imag - s2.real * s1.imag) + imag = math.atan2(z.imag, s1.real * s2.real - s1.imag * s2.imag) + return complex(real, imag) + + res = context.compile_internal(builder, asinh_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +@lower(cmath.asin, types.Complex) +def asin_impl(context, builder, sig, args): + def asin_impl(z): + """cmath.asin(z) = -j * cmath.asinh(z j)""" + r = cmath.asinh(complex(-z.imag, z.real)) + return complex(r.imag, -r.real) + + res = context.compile_internal(builder, asin_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +@lower(cmath.atan, types.Complex) +def atan_impl(context, builder, sig, args): + def atan_impl(z): + """cmath.atan(z) = -j * cmath.atanh(z j)""" + r = cmath.atanh(complex(-z.imag, z.real)) + if math.isinf(z.real) and math.isnan(z.imag): + # XXX this is odd but necessary + return complex(r.imag, r.real) + else: + return complex(r.imag, -r.real) + + res = context.compile_internal(builder, atan_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) + + +@lower(cmath.atanh, types.Complex) +def atanh_impl(context, builder, sig, args): + THRES_LARGE = math.sqrt(mathimpl.FLT_MAX / 4) + THRES_SMALL = math.sqrt(mathimpl.FLT_MIN) + PI_12 = math.pi / 2 + + def atanh_impl(z): + """cmath.atanh(z)""" + # CPython's algorithm (see c_atanh() in cmathmodule.c) + if z.real < 0.0: + # Reduce to case where z.real >= 0., using atanh(z) = -atanh(-z). + negate = True + z = -z + else: + negate = False + + ay = abs(z.imag) + if math.isnan(z.real) or z.real > THRES_LARGE or ay > THRES_LARGE: + if math.isinf(z.imag): + real = math.copysign(0.0, z.real) + elif math.isinf(z.real): + real = 0.0 + else: + # may be safe from overflow, depending on hypot's implementation... + h = math.hypot(z.real * 0.5, z.imag * 0.5) + real = z.real / 4.0 / h / h + imag = -math.copysign(PI_12, -z.imag) + elif z.real == 1.0 and ay < THRES_SMALL: + # C99 standard says: atanh(1+/-0.) should be inf +/- 0j + if ay == 0.0: + real = INF + imag = z.imag + else: + real = -math.log(math.sqrt(ay) / math.sqrt(math.hypot(ay, 2.0))) + imag = math.copysign(math.atan2(2.0, -ay) / 2, z.imag) + else: + sqay = ay * ay + zr1 = 1 - z.real + real = math.log1p(4.0 * z.real / (zr1 * zr1 + sqay)) * 0.25 + imag = -math.atan2(-2.0 * z.imag, zr1 * (1 + z.real) - sqay) * 0.5 + + if math.isnan(z.imag): + imag = NAN + if negate: + return complex(-real, -imag) + else: + return complex(real, imag) + + res = context.compile_internal(builder, atanh_impl, sig, args) + return impl_ret_untracked(context, builder, sig, res) diff --git a/numba_cuda/numba/cuda/cpython/mathimpl.py b/numba_cuda/numba/cuda/cpython/mathimpl.py new file mode 100644 index 000000000..b996f816d --- /dev/null +++ b/numba_cuda/numba/cuda/cpython/mathimpl.py @@ -0,0 +1,497 @@ +""" +Provide math calls that uses intrinsics or libc math functions. 
+""" + +import math +import operator +import sys +import numpy as np + +import llvmlite.ir +from llvmlite.ir import Constant + +from numba.core.imputils import Registry, impl_ret_untracked +from numba.core import types, config +from numba.core.extending import overload +from numba.core.typing import signature +from numba.cpython.unsafe.numbers import trailing_zeros +from numba.cuda import cgutils + + +registry = Registry("mathimpl") +lower = registry.lower + + +# Helpers, shared with cmathimpl. +_NP_FLT_FINFO = np.finfo(np.dtype("float32")) +FLT_MAX = _NP_FLT_FINFO.max +FLT_MIN = _NP_FLT_FINFO.tiny + +_NP_DBL_FINFO = np.finfo(np.dtype("float64")) +DBL_MAX = _NP_DBL_FINFO.max +DBL_MIN = _NP_DBL_FINFO.tiny + +FLOAT_ABS_MASK = 0x7FFFFFFF +FLOAT_SIGN_MASK = 0x80000000 +DOUBLE_ABS_MASK = 0x7FFFFFFFFFFFFFFF +DOUBLE_SIGN_MASK = 0x8000000000000000 + + +def is_nan(builder, val): + """ + Return a condition testing whether *val* is a NaN. + """ + return builder.fcmp_unordered("uno", val, val) + + +def is_inf(builder, val): + """ + Return a condition testing whether *val* is an infinite. + """ + pos_inf = Constant(val.type, float("+inf")) + neg_inf = Constant(val.type, float("-inf")) + isposinf = builder.fcmp_ordered("==", val, pos_inf) + isneginf = builder.fcmp_ordered("==", val, neg_inf) + return builder.or_(isposinf, isneginf) + + +def is_finite(builder, val): + """ + Return a condition testing whether *val* is a finite. + """ + # is_finite(x) <=> x - x != NaN + val_minus_val = builder.fsub(val, val) + return builder.fcmp_ordered("ord", val_minus_val, val_minus_val) + + +def f64_as_int64(builder, val): + """ + Bitcast a double into a 64-bit integer. + """ + assert val.type == llvmlite.ir.DoubleType() + return builder.bitcast(val, llvmlite.ir.IntType(64)) + + +def int64_as_f64(builder, val): + """ + Bitcast a 64-bit integer into a double. + """ + assert val.type == llvmlite.ir.IntType(64) + return builder.bitcast(val, llvmlite.ir.DoubleType()) + + +def f32_as_int32(builder, val): + """ + Bitcast a float into a 32-bit integer. + """ + assert val.type == llvmlite.ir.FloatType() + return builder.bitcast(val, llvmlite.ir.IntType(32)) + + +def int32_as_f32(builder, val): + """ + Bitcast a 32-bit integer into a float. + """ + assert val.type == llvmlite.ir.IntType(32) + return builder.bitcast(val, llvmlite.ir.FloatType()) + + +def negate_real(builder, val): + """ + Negate real number *val*, with proper handling of zeros. + """ + # The negative zero forces LLVM to handle signed zeros properly. + return builder.fsub(Constant(val.type, -0.0), val) + + +def call_fp_intrinsic(builder, name, args): + """ + Call a LLVM intrinsic floating-point operation. + """ + mod = builder.module + intr = mod.declare_intrinsic(name, [a.type for a in args]) + return builder.call(intr, args) + + +def _unary_int_input_wrapper_impl(wrapped_impl): + """ + Return an implementation factory to convert the single integral input + argument to a float64, then defer to the *wrapped_impl*. 
+ """ + + def implementer(context, builder, sig, args): + (val,) = args + input_type = sig.args[0] + fpval = context.cast(builder, val, input_type, types.float64) + inner_sig = signature(types.float64, types.float64) + res = wrapped_impl(context, builder, inner_sig, (fpval,)) + return context.cast(builder, res, types.float64, sig.return_type) + + return implementer + + +def unary_math_int_impl(fn, float_impl): + impl = _unary_int_input_wrapper_impl(float_impl) + lower(fn, types.Integer)(impl) + + +def unary_math_intr(fn, intrcode): + """ + Implement the math function *fn* using the LLVM intrinsic *intrcode*. + """ + + @lower(fn, types.Float) + def float_impl(context, builder, sig, args): + res = call_fp_intrinsic(builder, intrcode, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + unary_math_int_impl(fn, float_impl) + return float_impl + + +def unary_math_extern(fn, f32extern, f64extern, int_restype=False): + """ + Register implementations of Python function *fn* using the + external function named *f32extern* and *f64extern* (for float32 + and float64 inputs, respectively). + If *int_restype* is true, then the function's return value should be + integral, otherwise floating-point. + """ + + def float_impl(context, builder, sig, args): + """ + Implement *fn* for a types.Float input. + """ + [val] = args + input_type = sig.args[0] + lty = context.get_value_type(input_type) + func_name = { + types.float32: f32extern, + types.float64: f64extern, + }[input_type] + fnty = llvmlite.ir.FunctionType(lty, [lty]) + fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name) + res = builder.call(fn, (val,)) + res = context.cast(builder, res, input_type, sig.return_type) + return impl_ret_untracked(context, builder, sig.return_type, res) + + lower(fn, types.Float)(float_impl) + + # Implement wrapper for integer inputs + unary_math_int_impl(fn, float_impl) + + return float_impl + + +unary_math_intr(math.fabs, "llvm.fabs") +exp_impl = unary_math_intr(math.exp, "llvm.exp") +log_impl = unary_math_intr(math.log, "llvm.log") +log10_impl = unary_math_intr(math.log10, "llvm.log10") +log2_impl = unary_math_intr(math.log2, "llvm.log2") +sin_impl = unary_math_intr(math.sin, "llvm.sin") +cos_impl = unary_math_intr(math.cos, "llvm.cos") + +log1p_impl = unary_math_extern(math.log1p, "log1pf", "log1p") +expm1_impl = unary_math_extern(math.expm1, "expm1f", "expm1") +erf_impl = unary_math_extern(math.erf, "erff", "erf") +erfc_impl = unary_math_extern(math.erfc, "erfcf", "erfc") + +tan_impl = unary_math_extern(math.tan, "tanf", "tan") +asin_impl = unary_math_extern(math.asin, "asinf", "asin") +acos_impl = unary_math_extern(math.acos, "acosf", "acos") +atan_impl = unary_math_extern(math.atan, "atanf", "atan") + +asinh_impl = unary_math_extern(math.asinh, "asinhf", "asinh") +acosh_impl = unary_math_extern(math.acosh, "acoshf", "acosh") +atanh_impl = unary_math_extern(math.atanh, "atanhf", "atanh") +sinh_impl = unary_math_extern(math.sinh, "sinhf", "sinh") +cosh_impl = unary_math_extern(math.cosh, "coshf", "cosh") +tanh_impl = unary_math_extern(math.tanh, "tanhf", "tanh") + +log2_impl = unary_math_extern(math.log2, "log2f", "log2") +ceil_impl = unary_math_extern(math.ceil, "ceilf", "ceil", True) +floor_impl = unary_math_extern(math.floor, "floorf", "floor", True) + +gamma_impl = unary_math_extern( + math.gamma, "numba_gammaf", "numba_gamma" +) # work-around +sqrt_impl = unary_math_extern(math.sqrt, "sqrtf", "sqrt") +trunc_impl = unary_math_extern(math.trunc, "truncf", "trunc", True) 
+lgamma_impl = unary_math_extern(math.lgamma, "lgammaf", "lgamma") + + +@lower(math.isnan, types.Float) +def isnan_float_impl(context, builder, sig, args): + [val] = args + res = is_nan(builder, val) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(math.isnan, types.Integer) +def isnan_int_impl(context, builder, sig, args): + res = cgutils.false_bit + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(math.isinf, types.Float) +def isinf_float_impl(context, builder, sig, args): + [val] = args + res = is_inf(builder, val) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(math.isinf, types.Integer) +def isinf_int_impl(context, builder, sig, args): + res = cgutils.false_bit + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(math.isfinite, types.Float) +def isfinite_float_impl(context, builder, sig, args): + [val] = args + res = is_finite(builder, val) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(math.isfinite, types.Integer) +def isfinite_int_impl(context, builder, sig, args): + res = cgutils.true_bit + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(math.copysign, types.Float, types.Float) +def copysign_float_impl(context, builder, sig, args): + lty = args[0].type + mod = builder.module + fn = cgutils.get_or_insert_function( + mod, + llvmlite.ir.FunctionType(lty, (lty, lty)), + "llvm.copysign.%s" % lty.intrinsic_name, + ) + res = builder.call(fn, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ----------------------------------------------------------------------------- + + +@lower(math.frexp, types.Float) +def frexp_impl(context, builder, sig, args): + (val,) = args + fltty = context.get_data_type(sig.args[0]) + intty = context.get_data_type(sig.return_type[1]) + expptr = cgutils.alloca_once(builder, intty, name="exp") + fnty = llvmlite.ir.FunctionType( + fltty, (fltty, llvmlite.ir.PointerType(intty)) + ) + fname = { + "float": "numba_frexpf", + "double": "numba_frexp", + }[str(fltty)] + fn = cgutils.get_or_insert_function(builder.module, fnty, fname) + res = builder.call(fn, (val, expptr)) + res = cgutils.make_anonymous_struct(builder, (res, builder.load(expptr))) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(math.ldexp, types.Float, types.intc) +def ldexp_impl(context, builder, sig, args): + val, exp = args + fltty, intty = map(context.get_data_type, sig.args) + fnty = llvmlite.ir.FunctionType(fltty, (fltty, intty)) + fname = { + "float": "numba_ldexpf", + "double": "numba_ldexp", + }[str(fltty)] + fn = cgutils.insert_pure_function(builder.module, fnty, name=fname) + res = builder.call(fn, (val, exp)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ----------------------------------------------------------------------------- + + +@lower(math.atan2, types.int64, types.int64) +def atan2_s64_impl(context, builder, sig, args): + [y, x] = args + y = builder.sitofp(y, llvmlite.ir.DoubleType()) + x = builder.sitofp(x, llvmlite.ir.DoubleType()) + fsig = signature(types.float64, types.float64, types.float64) + return atan2_float_impl(context, builder, fsig, (y, x)) + + +@lower(math.atan2, types.uint64, types.uint64) +def atan2_u64_impl(context, builder, sig, args): + [y, x] = args + y = builder.uitofp(y, llvmlite.ir.DoubleType()) + x = builder.uitofp(x, llvmlite.ir.DoubleType()) + fsig = signature(types.float64, types.float64, 
types.float64) + return atan2_float_impl(context, builder, fsig, (y, x)) + + +@lower(math.atan2, types.Float, types.Float) +def atan2_float_impl(context, builder, sig, args): + assert len(args) == 2 + ty = sig.args[0] + lty = context.get_value_type(ty) + func_name = {types.float32: "atan2f", types.float64: "atan2"}[ty] + fnty = llvmlite.ir.FunctionType(lty, (lty, lty)) + fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name) + res = builder.call(fn, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ----------------------------------------------------------------------------- + + +@lower(math.hypot, types.int64, types.int64) +def hypot_s64_impl(context, builder, sig, args): + [x, y] = args + y = builder.sitofp(y, llvmlite.ir.DoubleType()) + x = builder.sitofp(x, llvmlite.ir.DoubleType()) + fsig = signature(types.float64, types.float64, types.float64) + res = hypot_float_impl(context, builder, fsig, (x, y)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(math.hypot, types.uint64, types.uint64) +def hypot_u64_impl(context, builder, sig, args): + [x, y] = args + y = builder.sitofp(y, llvmlite.ir.DoubleType()) + x = builder.sitofp(x, llvmlite.ir.DoubleType()) + fsig = signature(types.float64, types.float64, types.float64) + res = hypot_float_impl(context, builder, fsig, (x, y)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower(math.hypot, types.Float, types.Float) +def hypot_float_impl(context, builder, sig, args): + xty, yty = sig.args + assert xty == yty == sig.return_type + x, y = args + + # Windows has alternate names for hypot/hypotf, see + # https://msdn.microsoft.com/fr-fr/library/a9yb3dbt%28v=vs.80%29.aspx + fname = { + types.float32: "_hypotf" if sys.platform == "win32" else "hypotf", + types.float64: "_hypot" if sys.platform == "win32" else "hypot", + }[xty] + plat_hypot = types.ExternalFunction(fname, sig) + + if sys.platform == "win32" and config.MACHINE_BITS == 32: + inf = xty(float("inf")) + + def hypot_impl(x, y): + if math.isinf(x) or math.isinf(y): + return inf + return plat_hypot(x, y) + else: + + def hypot_impl(x, y): + return plat_hypot(x, y) + + res = context.compile_internal(builder, hypot_impl, sig, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ----------------------------------------------------------------------------- + + +@lower(math.radians, types.Float) +def radians_float_impl(context, builder, sig, args): + [x] = args + coef = context.get_constant(sig.return_type, math.pi / 180) + res = builder.fmul(x, coef) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +unary_math_int_impl(math.radians, radians_float_impl) + +# ----------------------------------------------------------------------------- + + +@lower(math.degrees, types.Float) +def degrees_float_impl(context, builder, sig, args): + [x] = args + coef = context.get_constant(sig.return_type, 180 / math.pi) + res = builder.fmul(x, coef) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +unary_math_int_impl(math.degrees, degrees_float_impl) + +# ----------------------------------------------------------------------------- + + +@lower(math.pow, types.Float, types.Float) +@lower(math.pow, types.Float, types.Integer) +def pow_impl(context, builder, sig, args): + impl = context.get_function(operator.pow, sig) + return impl(builder, args) + + +# ----------------------------------------------------------------------------- + + 
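As an illustrative sketch (not part of the vendored sources), and assuming these registries end up installed on the CUDA target as the accompanying target.py change intends, a kernel such as the one below would lower math.hypot and math.atan2 through the float implementations above; the kernel name and launch configuration are placeholders.

    import math
    import numpy as np
    from numba import cuda

    @cuda.jit
    def polar_coords(radius, angle, x, y):
        i = cuda.grid(1)
        if i < x.size:
            # Both calls resolve through the @lower(math.hypot, ...) and
            # @lower(math.atan2, ...) registrations in this module.
            radius[i] = math.hypot(x[i], y[i])
            angle[i] = math.atan2(y[i], x[i])

    x = np.array([3.0, 0.0, -1.0])
    y = np.array([4.0, 2.0, -1.0])
    radius = np.zeros_like(x)
    angle = np.zeros_like(x)
    polar_coords[1, 32](radius, angle, x, y)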
+@lower(math.nextafter, types.Float, types.Float) +def nextafter_impl(context, builder, sig, args): + assert len(args) == 2 + ty = sig.args[0] + lty = context.get_value_type(ty) + func_name = {types.float32: "nextafterf", types.float64: "nextafter"}[ty] + fnty = llvmlite.ir.FunctionType(lty, (lty, lty)) + fn = cgutils.insert_pure_function(builder.module, fnty, name=func_name) + res = builder.call(fn, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +# ----------------------------------------------------------------------------- + + +def _unsigned(T): + """Convert integer to unsigned integer of equivalent width.""" + pass + + +@overload(_unsigned) +def _unsigned_impl(T): + if T in types.unsigned_domain: + return lambda T: T + elif T in types.signed_domain: + newT = getattr(types, "uint{}".format(T.bitwidth)) + return lambda T: newT(T) + + +def gcd_impl(context, builder, sig, args): + xty, yty = sig.args + assert xty == yty == sig.return_type + x, y = args + + def gcd(a, b): + """ + Stein's algorithm, heavily cribbed from Julia implementation. + """ + T = type(a) + if a == 0: + return abs(b) + if b == 0: + return abs(a) + za = trailing_zeros(a) + zb = trailing_zeros(b) + k = min(za, zb) + # Uses np.*_shift instead of operators due to return types + u = _unsigned(abs(np.right_shift(a, za))) + v = _unsigned(abs(np.right_shift(b, zb))) + while u != v: + if u > v: + u, v = v, u + v -= u + v = np.right_shift(v, trailing_zeros(v)) + r = np.left_shift(T(u), k) + return r + + res = context.compile_internal(builder, gcd, sig, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +lower(math.gcd, types.Integer, types.Integer)(gcd_impl) diff --git a/numba_cuda/numba/cuda/cpython/numbers.py b/numba_cuda/numba/cuda/cpython/numbers.py new file mode 100644 index 000000000..4d4621e48 --- /dev/null +++ b/numba_cuda/numba/cuda/cpython/numbers.py @@ -0,0 +1,1471 @@ +import math +import numbers + +import numpy as np +import operator + +from llvmlite import ir +from llvmlite.ir import Constant + +from numba.core.imputils import ( + lower_builtin, + lower_getattr, + lower_cast, + lower_constant, + impl_ret_untracked, +) +from numba.core import typing, types, errors +from numba.core.extending import overload_method +from numba.cpython.unsafe.numbers import viewer +from numba.cuda import cgutils + + +def _int_arith_flags(rettype): + """ + Return the modifier flags for integer arithmetic. + """ + if rettype.signed: + # Ignore the effects of signed overflow. This is important for + # optimization of some indexing operations. For example + # array[i+1] could see `i+1` trigger a signed overflow and + # give a negative number. With Python's indexing, a negative + # index is treated differently: its resolution has a runtime cost. + # Telling LLVM to ignore signed overflows allows it to optimize + # away the check for a negative `i+1` if it knows `i` is positive. 
+ return ["nsw"] + else: + return [] + + +def int_add_impl(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + a = context.cast(builder, va, ta, sig.return_type) + b = context.cast(builder, vb, tb, sig.return_type) + res = builder.add(a, b, flags=_int_arith_flags(sig.return_type)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_sub_impl(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + a = context.cast(builder, va, ta, sig.return_type) + b = context.cast(builder, vb, tb, sig.return_type) + res = builder.sub(a, b, flags=_int_arith_flags(sig.return_type)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_mul_impl(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + a = context.cast(builder, va, ta, sig.return_type) + b = context.cast(builder, vb, tb, sig.return_type) + res = builder.mul(a, b, flags=_int_arith_flags(sig.return_type)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_divmod_signed(context, builder, ty, x, y): + """ + Reference Objects/intobject.c + xdivy = x / y; + xmody = (long)(x - (unsigned long)xdivy * y); + /* If the signs of x and y differ, and the remainder is non-0, + * C89 doesn't define whether xdivy is now the floor or the + * ceiling of the infinitely precise quotient. We want the floor, + * and we have it iff the remainder's sign matches y's. + */ + if (xmody && ((y ^ xmody) < 0) /* i.e. and signs differ */) { + xmody += y; + --xdivy; + assert(xmody && ((y ^ xmody) >= 0)); + } + *p_xdivy = xdivy; + *p_xmody = xmody; + """ + assert x.type == y.type + + ZERO = y.type(0) + ONE = y.type(1) + + # NOTE: On x86 at least, dividing the lowest representable integer + # (e.g. 0x80000000 for int32) by -1 causes a SIFGPE (division overflow), + # causing the process to crash. + # We return 0, 0 instead (more or less like Numpy). + + resdiv = cgutils.alloca_once_value(builder, ZERO) + resmod = cgutils.alloca_once_value(builder, ZERO) + + is_overflow = builder.and_( + builder.icmp_signed("==", x, x.type(ty.minval)), + builder.icmp_signed("==", y, y.type(-1)), + ) + + with builder.if_then(builder.not_(is_overflow), likely=True): + # Note LLVM will optimize this to a single divmod instruction, + # if available on the target CPU (e.g. x86). + xdivy = builder.sdiv(x, y) + xmody = builder.srem(x, y) + + y_xor_xmody_ltz = builder.icmp_signed("<", builder.xor(y, xmody), ZERO) + xmody_istrue = builder.icmp_signed("!=", xmody, ZERO) + cond = builder.and_(xmody_istrue, y_xor_xmody_ltz) + + with builder.if_else(cond) as (if_different_signs, if_same_signs): + with if_same_signs: + builder.store(xdivy, resdiv) + builder.store(xmody, resmod) + + with if_different_signs: + builder.store(builder.sub(xdivy, ONE), resdiv) + builder.store(builder.add(xmody, y), resmod) + + return builder.load(resdiv), builder.load(resmod) + + +def int_divmod(context, builder, ty, x, y): + """ + Integer divmod(x, y). The caller must ensure that y != 0. 
+ """ + if ty.signed: + return int_divmod_signed(context, builder, ty, x, y) + else: + return builder.udiv(x, y), builder.urem(x, y) + + +def _int_divmod_impl(context, builder, sig, args, zerodiv_message): + va, vb = args + ta, tb = sig.args + + ty = sig.return_type + if isinstance(ty, types.UniTuple): + ty = ty.dtype + a = context.cast(builder, va, ta, ty) + b = context.cast(builder, vb, tb, ty) + quot = cgutils.alloca_once(builder, a.type, name="quot") + rem = cgutils.alloca_once(builder, a.type, name="rem") + + with builder.if_else(cgutils.is_scalar_zero(builder, b), likely=False) as ( + if_zero, + if_non_zero, + ): + with if_zero: + if not context.error_model.fp_zero_division( + builder, (zerodiv_message,) + ): + # No exception raised => return 0 + # XXX We should also set the FPU exception status, but + # there's no easy way to do that from LLVM. + builder.store(b, quot) + builder.store(b, rem) + with if_non_zero: + q, r = int_divmod(context, builder, ty, a, b) + builder.store(q, quot) + builder.store(r, rem) + + return quot, rem + + +@lower_builtin(divmod, types.Integer, types.Integer) +def int_divmod_impl(context, builder, sig, args): + quot, rem = _int_divmod_impl( + context, builder, sig, args, "integer divmod by zero" + ) + + return cgutils.pack_array(builder, (builder.load(quot), builder.load(rem))) + + +@lower_builtin(operator.floordiv, types.Integer, types.Integer) +@lower_builtin(operator.ifloordiv, types.Integer, types.Integer) +def int_floordiv_impl(context, builder, sig, args): + quot, rem = _int_divmod_impl( + context, builder, sig, args, "integer division by zero" + ) + return builder.load(quot) + + +@lower_builtin(operator.truediv, types.Integer, types.Integer) +@lower_builtin(operator.itruediv, types.Integer, types.Integer) +def int_truediv_impl(context, builder, sig, args): + [va, vb] = args + [ta, tb] = sig.args + a = context.cast(builder, va, ta, sig.return_type) + b = context.cast(builder, vb, tb, sig.return_type) + with cgutils.if_zero(builder, b): + context.error_model.fp_zero_division(builder, ("division by zero",)) + res = builder.fdiv(a, b) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower_builtin(operator.mod, types.Integer, types.Integer) +@lower_builtin(operator.imod, types.Integer, types.Integer) +def int_rem_impl(context, builder, sig, args): + quot, rem = _int_divmod_impl( + context, builder, sig, args, "integer modulo by zero" + ) + return builder.load(rem) + + +def _get_power_zerodiv_return(context, return_type): + if ( + isinstance(return_type, types.Integer) + and not context.error_model.raise_on_fp_zero_division + ): + # If not raising, return 0x8000... 
when computing 0 ** + return -1 << (return_type.bitwidth - 1) + else: + return False + + +def int_power_impl(context, builder, sig, args): + """ + a ^ b, where a is an integer or real, and b an integer + """ + is_integer = isinstance(sig.args[0], types.Integer) + tp = sig.return_type + zerodiv_return = _get_power_zerodiv_return(context, tp) + + def int_power(a, b): + # Ensure computations are done with a large enough width + r = tp(1) + a = tp(a) + if b < 0: + invert = True + exp = -b + if exp < 0: + raise OverflowError + if is_integer: + if a == 0: + if zerodiv_return: + return zerodiv_return + else: + raise ZeroDivisionError( + "0 cannot be raised to a negative power" + ) + if a != 1 and a != -1: + return 0 + else: + invert = False + exp = b + if exp > 0x10000: + # Optimization cutoff: fallback on the generic algorithm + return math.pow(a, float(b)) + while exp != 0: + if exp & 1: + r *= a + exp >>= 1 + a *= a + + return 1.0 / r if invert else r + + res = context.compile_internal(builder, int_power, sig, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower_builtin(operator.pow, types.Integer, types.IntegerLiteral) +@lower_builtin(operator.ipow, types.Integer, types.IntegerLiteral) +@lower_builtin(operator.pow, types.Float, types.IntegerLiteral) +@lower_builtin(operator.ipow, types.Float, types.IntegerLiteral) +def static_power_impl(context, builder, sig, args): + """ + a ^ b, where a is an integer or real, and b a constant integer + """ + exp = sig.args[1].value + if not isinstance(exp, numbers.Integral): + raise NotImplementedError + if abs(exp) > 0x10000: + # Optimization cutoff: fallback on the generic algorithm above + raise NotImplementedError + invert = exp < 0 + exp = abs(exp) + + tp = sig.return_type + is_integer = isinstance(tp, types.Integer) + zerodiv_return = _get_power_zerodiv_return(context, tp) + + val = context.cast(builder, args[0], sig.args[0], tp) + lty = val.type + + def mul(a, b): + if is_integer: + return builder.mul(a, b) + else: + return builder.fmul(a, b) + + # Unroll the exponentiation loop + res = lty(1) + while exp != 0: + if exp & 1: + res = mul(res, val) + exp >>= 1 + val = mul(val, val) + + if invert: + # If the exponent was negative, fix the result by inverting it + if is_integer: + # Integer inversion + def invert_impl(a): + if a == 0: + if zerodiv_return: + return zerodiv_return + else: + raise ZeroDivisionError( + "0 cannot be raised to a negative power" + ) + if a != 1 and a != -1: + return 0 + else: + return a + + else: + # Real inversion + def invert_impl(a): + return 1.0 / a + + res = context.compile_internal( + builder, invert_impl, typing.signature(tp, tp), (res,) + ) + + return res + + +def int_slt_impl(context, builder, sig, args): + res = builder.icmp_signed("<", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_sle_impl(context, builder, sig, args): + res = builder.icmp_signed("<=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_sgt_impl(context, builder, sig, args): + res = builder.icmp_signed(">", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_sge_impl(context, builder, sig, args): + res = builder.icmp_signed(">=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_ult_impl(context, builder, sig, args): + res = builder.icmp_unsigned("<", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_ule_impl(context, builder, sig, 
args): + res = builder.icmp_unsigned("<=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_ugt_impl(context, builder, sig, args): + res = builder.icmp_unsigned(">", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_uge_impl(context, builder, sig, args): + res = builder.icmp_unsigned(">=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_eq_impl(context, builder, sig, args): + res = builder.icmp_unsigned("==", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_ne_impl(context, builder, sig, args): + res = builder.icmp_unsigned("!=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_signed_unsigned_cmp(op): + def impl(context, builder, sig, args): + (left, right) = args + # This code is translated from the NumPy source. + # What we're going to do is divide the range of a signed value at zero. + # If the signed value is less than zero, then we can treat zero as the + # unsigned value since the unsigned value is necessarily zero or larger + # and any signed comparison between a negative value and zero/infinity + # will yield the same result. If the signed value is greater than or + # equal to zero, then we can safely cast it to an unsigned value and do + # the expected unsigned-unsigned comparison operation. + # Original: https://github.com/numpy/numpy/pull/23713 + cmp_zero = builder.icmp_signed("<", left, Constant(left.type, 0)) + lt_zero = builder.icmp_signed(op, left, Constant(left.type, 0)) + ge_zero = builder.icmp_unsigned(op, left, right) + res = builder.select(cmp_zero, lt_zero, ge_zero) + return impl_ret_untracked(context, builder, sig.return_type, res) + + return impl + + +def int_unsigned_signed_cmp(op): + def impl(context, builder, sig, args): + (left, right) = args + # See the function `int_signed_unsigned_cmp` for implementation notes. 
+ cmp_zero = builder.icmp_signed("<", right, Constant(right.type, 0)) + lt_zero = builder.icmp_signed(op, Constant(right.type, 0), right) + ge_zero = builder.icmp_unsigned(op, left, right) + res = builder.select(cmp_zero, lt_zero, ge_zero) + return impl_ret_untracked(context, builder, sig.return_type, res) + + return impl + + +def int_abs_impl(context, builder, sig, args): + [x] = args + ZERO = Constant(x.type, None) + ltz = builder.icmp_signed("<", x, ZERO) + negated = builder.neg(x) + res = builder.select(ltz, negated, x) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def uint_abs_impl(context, builder, sig, args): + [x] = args + return impl_ret_untracked(context, builder, sig.return_type, x) + + +def int_shl_impl(context, builder, sig, args): + [valty, amtty] = sig.args + [val, amt] = args + val = context.cast(builder, val, valty, sig.return_type) + amt = context.cast(builder, amt, amtty, sig.return_type) + res = builder.shl(val, amt) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_shr_impl(context, builder, sig, args): + [valty, amtty] = sig.args + [val, amt] = args + val = context.cast(builder, val, valty, sig.return_type) + amt = context.cast(builder, amt, amtty, sig.return_type) + if sig.return_type.signed: + res = builder.ashr(val, amt) + else: + res = builder.lshr(val, amt) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_and_impl(context, builder, sig, args): + [at, bt] = sig.args + [av, bv] = args + cav = context.cast(builder, av, at, sig.return_type) + cbc = context.cast(builder, bv, bt, sig.return_type) + res = builder.and_(cav, cbc) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_or_impl(context, builder, sig, args): + [at, bt] = sig.args + [av, bv] = args + cav = context.cast(builder, av, at, sig.return_type) + cbc = context.cast(builder, bv, bt, sig.return_type) + res = builder.or_(cav, cbc) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_xor_impl(context, builder, sig, args): + [at, bt] = sig.args + [av, bv] = args + cav = context.cast(builder, av, at, sig.return_type) + cbc = context.cast(builder, bv, bt, sig.return_type) + res = builder.xor(cav, cbc) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_negate_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + # Negate before upcasting, for unsigned numbers + res = builder.neg(val) + res = context.cast(builder, res, typ, sig.return_type) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_positive_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + res = context.cast(builder, val, typ, sig.return_type) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_invert_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + # Invert before upcasting, for unsigned numbers + res = builder.xor(val, Constant(val.type, int("1" * val.type.width, 2))) + res = context.cast(builder, res, typ, sig.return_type) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def int_sign_impl(context, builder, sig, args): + """ + np.sign(int) + """ + [x] = args + POS = Constant(x.type, 1) + NEG = Constant(x.type, -1) + ZERO = Constant(x.type, 0) + + cmp_zero = builder.icmp_unsigned("==", x, ZERO) + cmp_pos = builder.icmp_signed(">", x, ZERO) + + presult = cgutils.alloca_once(builder, x.type) + + bb_zero = builder.append_basic_block(".zero") + 
bb_postest = builder.append_basic_block(".postest") + bb_pos = builder.append_basic_block(".pos") + bb_neg = builder.append_basic_block(".neg") + bb_exit = builder.append_basic_block(".exit") + + builder.cbranch(cmp_zero, bb_zero, bb_postest) + + with builder.goto_block(bb_zero): + builder.store(ZERO, presult) + builder.branch(bb_exit) + + with builder.goto_block(bb_postest): + builder.cbranch(cmp_pos, bb_pos, bb_neg) + + with builder.goto_block(bb_pos): + builder.store(POS, presult) + builder.branch(bb_exit) + + with builder.goto_block(bb_neg): + builder.store(NEG, presult) + builder.branch(bb_exit) + + builder.position_at_end(bb_exit) + res = builder.load(presult) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def bool_negate_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + res = context.cast(builder, val, typ, sig.return_type) + res = builder.neg(res) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def bool_unary_positive_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + res = context.cast(builder, val, typ, sig.return_type) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +lower_builtin(operator.eq, types.boolean, types.boolean)(int_eq_impl) +lower_builtin(operator.ne, types.boolean, types.boolean)(int_ne_impl) +lower_builtin(operator.lt, types.boolean, types.boolean)(int_ult_impl) +lower_builtin(operator.le, types.boolean, types.boolean)(int_ule_impl) +lower_builtin(operator.gt, types.boolean, types.boolean)(int_ugt_impl) +lower_builtin(operator.ge, types.boolean, types.boolean)(int_uge_impl) +lower_builtin(operator.neg, types.boolean)(bool_negate_impl) +lower_builtin(operator.pos, types.boolean)(bool_unary_positive_impl) + + +def _implement_integer_operators(): + ty = types.Integer + + lower_builtin(operator.add, ty, ty)(int_add_impl) + lower_builtin(operator.iadd, ty, ty)(int_add_impl) + lower_builtin(operator.sub, ty, ty)(int_sub_impl) + lower_builtin(operator.isub, ty, ty)(int_sub_impl) + lower_builtin(operator.mul, ty, ty)(int_mul_impl) + lower_builtin(operator.imul, ty, ty)(int_mul_impl) + lower_builtin(operator.eq, ty, ty)(int_eq_impl) + lower_builtin(operator.ne, ty, ty)(int_ne_impl) + + lower_builtin(operator.lshift, ty, ty)(int_shl_impl) + lower_builtin(operator.ilshift, ty, ty)(int_shl_impl) + lower_builtin(operator.rshift, ty, ty)(int_shr_impl) + lower_builtin(operator.irshift, ty, ty)(int_shr_impl) + + lower_builtin(operator.neg, ty)(int_negate_impl) + lower_builtin(operator.pos, ty)(int_positive_impl) + + lower_builtin(operator.pow, ty, ty)(int_power_impl) + lower_builtin(operator.ipow, ty, ty)(int_power_impl) + lower_builtin(pow, ty, ty)(int_power_impl) + + for ty in types.unsigned_domain: + lower_builtin(operator.lt, ty, ty)(int_ult_impl) + lower_builtin(operator.le, ty, ty)(int_ule_impl) + lower_builtin(operator.gt, ty, ty)(int_ugt_impl) + lower_builtin(operator.ge, ty, ty)(int_uge_impl) + lower_builtin(operator.pow, types.Float, ty)(int_power_impl) + lower_builtin(operator.ipow, types.Float, ty)(int_power_impl) + lower_builtin(pow, types.Float, ty)(int_power_impl) + lower_builtin(abs, ty)(uint_abs_impl) + + lower_builtin(operator.lt, types.IntegerLiteral, types.IntegerLiteral)( + int_slt_impl + ) + lower_builtin(operator.gt, types.IntegerLiteral, types.IntegerLiteral)( + int_slt_impl + ) + lower_builtin(operator.le, types.IntegerLiteral, types.IntegerLiteral)( + int_slt_impl + ) + lower_builtin(operator.ge, types.IntegerLiteral, types.IntegerLiteral)( 
+ int_slt_impl + ) + for ty in types.signed_domain: + lower_builtin(operator.lt, ty, ty)(int_slt_impl) + lower_builtin(operator.le, ty, ty)(int_sle_impl) + lower_builtin(operator.gt, ty, ty)(int_sgt_impl) + lower_builtin(operator.ge, ty, ty)(int_sge_impl) + lower_builtin(operator.pow, types.Float, ty)(int_power_impl) + lower_builtin(operator.ipow, types.Float, ty)(int_power_impl) + lower_builtin(pow, types.Float, ty)(int_power_impl) + lower_builtin(abs, ty)(int_abs_impl) + + +def _implement_bitwise_operators(): + for ty in (types.Boolean, types.Integer): + lower_builtin(operator.and_, ty, ty)(int_and_impl) + lower_builtin(operator.iand, ty, ty)(int_and_impl) + lower_builtin(operator.or_, ty, ty)(int_or_impl) + lower_builtin(operator.ior, ty, ty)(int_or_impl) + lower_builtin(operator.xor, ty, ty)(int_xor_impl) + lower_builtin(operator.ixor, ty, ty)(int_xor_impl) + + lower_builtin(operator.invert, ty)(int_invert_impl) + + +_implement_integer_operators() + +_implement_bitwise_operators() + + +def real_add_impl(context, builder, sig, args): + res = builder.fadd(*args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_sub_impl(context, builder, sig, args): + res = builder.fsub(*args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_mul_impl(context, builder, sig, args): + res = builder.fmul(*args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_div_impl(context, builder, sig, args): + with cgutils.if_zero(builder, args[1]): + context.error_model.fp_zero_division(builder, ("division by zero",)) + res = builder.fdiv(*args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_divmod(context, builder, x, y): + assert x.type == y.type + floatty = x.type + + module = builder.module + fname = context.mangler(".numba.python.rem", [x.type]) + fnty = ir.FunctionType(floatty, (floatty, floatty, ir.PointerType(floatty))) + fn = cgutils.get_or_insert_function(module, fnty, fname) + + if fn.is_declaration: + fn.linkage = "linkonce_odr" + fnbuilder = ir.IRBuilder(fn.append_basic_block("entry")) + fx, fy, pmod = fn.args + div, mod = real_divmod_func_body(context, fnbuilder, fx, fy) + fnbuilder.store(mod, pmod) + fnbuilder.ret(div) + + pmod = cgutils.alloca_once(builder, floatty) + quotient = builder.call(fn, (x, y, pmod)) + return quotient, builder.load(pmod) + + +def real_divmod_func_body(context, builder, vx, wx): + # Reference Objects/floatobject.c + # + # float_divmod(PyObject *v, PyObject *w) + # { + # double vx, wx; + # double div, mod, floordiv; + # CONVERT_TO_DOUBLE(v, vx); + # CONVERT_TO_DOUBLE(w, wx); + # mod = fmod(vx, wx); + # /* fmod is typically exact, so vx-mod is *mathematically* an + # exact multiple of wx. But this is fp arithmetic, and fp + # vx - mod is an approximation; the result is that div may + # not be an exact integral value after the division, although + # it will always be very close to one. 
+ # */ + # div = (vx - mod) / wx; + # if (mod) { + # /* ensure the remainder has the same sign as the denominator */ + # if ((wx < 0) != (mod < 0)) { + # mod += wx; + # div -= 1.0; + # } + # } + # else { + # /* the remainder is zero, and in the presence of signed zeroes + # fmod returns different results across platforms; ensure + # it has the same sign as the denominator; we'd like to do + # "mod = wx * 0.0", but that may get optimized away */ + # mod *= mod; /* hide "mod = +0" from optimizer */ + # if (wx < 0.0) + # mod = -mod; + # } + # /* snap quotient to nearest integral value */ + # if (div) { + # floordiv = floor(div); + # if (div - floordiv > 0.5) + # floordiv += 1.0; + # } + # else { + # /* div is zero - get the same sign as the true quotient */ + # div *= div; /* hide "div = +0" from optimizers */ + # floordiv = div * vx / wx; /* zero w/ sign of vx/wx */ + # } + # return Py_BuildValue("(dd)", floordiv, mod); + # } + pmod = cgutils.alloca_once(builder, vx.type) + pdiv = cgutils.alloca_once(builder, vx.type) + pfloordiv = cgutils.alloca_once(builder, vx.type) + + mod = builder.frem(vx, wx) + div = builder.fdiv(builder.fsub(vx, mod), wx) + + builder.store(mod, pmod) + builder.store(div, pdiv) + + # Note the use of negative zero for proper negating with `ZERO - x` + ZERO = vx.type(0.0) + NZERO = vx.type(-0.0) + ONE = vx.type(1.0) + mod_istrue = builder.fcmp_unordered("!=", mod, ZERO) + wx_ltz = builder.fcmp_ordered("<", wx, ZERO) + mod_ltz = builder.fcmp_ordered("<", mod, ZERO) + + with builder.if_else(mod_istrue, likely=True) as ( + if_nonzero_mod, + if_zero_mod, + ): + with if_nonzero_mod: + # `mod` is non-zero or NaN + # Ensure the remainder has the same sign as the denominator + wx_ltz_ne_mod_ltz = builder.icmp_unsigned("!=", wx_ltz, mod_ltz) + + with builder.if_then(wx_ltz_ne_mod_ltz): + builder.store(builder.fsub(div, ONE), pdiv) + builder.store(builder.fadd(mod, wx), pmod) + + with if_zero_mod: + # `mod` is zero, select the proper sign depending on + # the denominator's sign + mod = builder.select(wx_ltz, NZERO, ZERO) + builder.store(mod, pmod) + + del mod, div + + div = builder.load(pdiv) + div_istrue = builder.fcmp_ordered("!=", div, ZERO) + + with builder.if_then(div_istrue): + realtypemap = {"float": types.float32, "double": types.float64} + realtype = realtypemap[str(wx.type)] + floorfn = context.get_function( + math.floor, typing.signature(realtype, realtype) + ) + floordiv = floorfn(builder, [div]) + floordivdiff = builder.fsub(div, floordiv) + floordivincr = builder.fadd(floordiv, ONE) + HALF = Constant(wx.type, 0.5) + pred = builder.fcmp_ordered(">", floordivdiff, HALF) + floordiv = builder.select(pred, floordivincr, floordiv) + builder.store(floordiv, pfloordiv) + + with cgutils.ifnot(builder, div_istrue): + div = builder.fmul(div, div) + builder.store(div, pdiv) + floordiv = builder.fdiv(builder.fmul(div, vx), wx) + builder.store(floordiv, pfloordiv) + + return builder.load(pfloordiv), builder.load(pmod) + + +@lower_builtin(divmod, types.Float, types.Float) +def real_divmod_impl(context, builder, sig, args, loc=None): + x, y = args + quot = cgutils.alloca_once(builder, x.type, name="quot") + rem = cgutils.alloca_once(builder, x.type, name="rem") + + with builder.if_else(cgutils.is_scalar_zero(builder, y), likely=False) as ( + if_zero, + if_non_zero, + ): + with if_zero: + if not context.error_model.fp_zero_division( + builder, ("modulo by zero",), loc + ): + # No exception raised => compute the nan result, + # and set the FP exception word for Numpy warnings. 
+ q = builder.fdiv(x, y) + r = builder.frem(x, y) + builder.store(q, quot) + builder.store(r, rem) + with if_non_zero: + q, r = real_divmod(context, builder, x, y) + builder.store(q, quot) + builder.store(r, rem) + + return cgutils.pack_array(builder, (builder.load(quot), builder.load(rem))) + + +def real_mod_impl(context, builder, sig, args, loc=None): + x, y = args + res = cgutils.alloca_once(builder, x.type) + with builder.if_else(cgutils.is_scalar_zero(builder, y), likely=False) as ( + if_zero, + if_non_zero, + ): + with if_zero: + if not context.error_model.fp_zero_division( + builder, ("modulo by zero",), loc + ): + # No exception raised => compute the nan result, + # and set the FP exception word for Numpy warnings. + rem = builder.frem(x, y) + builder.store(rem, res) + with if_non_zero: + _, rem = real_divmod(context, builder, x, y) + builder.store(rem, res) + return impl_ret_untracked( + context, builder, sig.return_type, builder.load(res) + ) + + +def real_floordiv_impl(context, builder, sig, args, loc=None): + x, y = args + res = cgutils.alloca_once(builder, x.type) + with builder.if_else(cgutils.is_scalar_zero(builder, y), likely=False) as ( + if_zero, + if_non_zero, + ): + with if_zero: + if not context.error_model.fp_zero_division( + builder, ("division by zero",), loc + ): + # No exception raised => compute the +/-inf or nan result, + # and set the FP exception word for Numpy warnings. + quot = builder.fdiv(x, y) + builder.store(quot, res) + with if_non_zero: + quot, _ = real_divmod(context, builder, x, y) + builder.store(quot, res) + return impl_ret_untracked( + context, builder, sig.return_type, builder.load(res) + ) + + +def real_power_impl(context, builder, sig, args): + x, y = args + module = builder.module + if context.implement_powi_as_math_call: + imp = context.get_function(math.pow, sig) + res = imp(builder, args) + else: + fn = module.declare_intrinsic("llvm.pow", [y.type]) + res = builder.call(fn, (x, y)) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_lt_impl(context, builder, sig, args): + res = builder.fcmp_ordered("<", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_le_impl(context, builder, sig, args): + res = builder.fcmp_ordered("<=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_gt_impl(context, builder, sig, args): + res = builder.fcmp_ordered(">", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_ge_impl(context, builder, sig, args): + res = builder.fcmp_ordered(">=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_eq_impl(context, builder, sig, args): + res = builder.fcmp_ordered("==", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_ne_impl(context, builder, sig, args): + res = builder.fcmp_unordered("!=", *args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_abs_impl(context, builder, sig, args): + [ty] = sig.args + sig = typing.signature(ty, ty) + impl = context.get_function(math.fabs, sig) + return impl(builder, args) + + +def real_negate_impl(context, builder, sig, args): + from numba.cuda.cpython import mathimpl + + res = mathimpl.negate_real(builder, args[0]) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_positive_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + res = context.cast(builder, val, typ, sig.return_type) + 
return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_sign_impl(context, builder, sig, args): + """ + np.sign(float) + """ + [x] = args + POS = Constant(x.type, 1) + NEG = Constant(x.type, -1) + ZERO = Constant(x.type, 0) + + presult = cgutils.alloca_once(builder, x.type) + + is_pos = builder.fcmp_ordered(">", x, ZERO) + is_neg = builder.fcmp_ordered("<", x, ZERO) + + with builder.if_else(is_pos) as (gt_zero, not_gt_zero): + with gt_zero: + builder.store(POS, presult) + with not_gt_zero: + with builder.if_else(is_neg) as (lt_zero, not_lt_zero): + with lt_zero: + builder.store(NEG, presult) + with not_lt_zero: + # For both NaN and 0, the result of sign() is simply + # the input value. + builder.store(x, presult) + + res = builder.load(presult) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +ty = types.Float + +lower_builtin(operator.add, ty, ty)(real_add_impl) +lower_builtin(operator.iadd, ty, ty)(real_add_impl) +lower_builtin(operator.sub, ty, ty)(real_sub_impl) +lower_builtin(operator.isub, ty, ty)(real_sub_impl) +lower_builtin(operator.mul, ty, ty)(real_mul_impl) +lower_builtin(operator.imul, ty, ty)(real_mul_impl) +lower_builtin(operator.floordiv, ty, ty)(real_floordiv_impl) +lower_builtin(operator.ifloordiv, ty, ty)(real_floordiv_impl) +lower_builtin(operator.truediv, ty, ty)(real_div_impl) +lower_builtin(operator.itruediv, ty, ty)(real_div_impl) +lower_builtin(operator.mod, ty, ty)(real_mod_impl) +lower_builtin(operator.imod, ty, ty)(real_mod_impl) +lower_builtin(operator.pow, ty, ty)(real_power_impl) +lower_builtin(operator.ipow, ty, ty)(real_power_impl) +lower_builtin(pow, ty, ty)(real_power_impl) + +lower_builtin(operator.eq, ty, ty)(real_eq_impl) +lower_builtin(operator.ne, ty, ty)(real_ne_impl) +lower_builtin(operator.lt, ty, ty)(real_lt_impl) +lower_builtin(operator.le, ty, ty)(real_le_impl) +lower_builtin(operator.gt, ty, ty)(real_gt_impl) +lower_builtin(operator.ge, ty, ty)(real_ge_impl) + +lower_builtin(abs, ty)(real_abs_impl) + +lower_builtin(operator.neg, ty)(real_negate_impl) +lower_builtin(operator.pos, ty)(real_positive_impl) + +del ty + + +@lower_getattr(types.Complex, "real") +def complex_real_impl(context, builder, typ, value): + cplx = context.make_complex(builder, typ, value=value) + res = cplx.real + return impl_ret_untracked(context, builder, typ, res) + + +@lower_getattr(types.Complex, "imag") +def complex_imag_impl(context, builder, typ, value): + cplx = context.make_complex(builder, typ, value=value) + res = cplx.imag + return impl_ret_untracked(context, builder, typ, res) + + +@lower_builtin("complex.conjugate", types.Complex) +def complex_conjugate_impl(context, builder, sig, args): + from numba.cuda.cpython import mathimpl + + z = context.make_complex(builder, sig.args[0], args[0]) + z.imag = mathimpl.negate_real(builder, z.imag) + res = z._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def real_real_impl(context, builder, typ, value): + return impl_ret_untracked(context, builder, typ, value) + + +def real_imag_impl(context, builder, typ, value): + res = cgutils.get_null_value(value.type) + return impl_ret_untracked(context, builder, typ, res) + + +def real_conjugate_impl(context, builder, sig, args): + return impl_ret_untracked(context, builder, sig.return_type, args[0]) + + +for cls in (types.Float, types.Integer): + lower_getattr(cls, "real")(real_real_impl) + lower_getattr(cls, "imag")(real_imag_impl) + lower_builtin("complex.conjugate", cls)(real_conjugate_impl) + + 
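# For reference, the Float/Integer attribute lowerings registered just above
# mirror plain Python number semantics, roughly (an illustrative sketch only,
# not part of the lowering machinery):
#
#     (3.5).real        # -> 3.5   real_real_impl passes the value through
#     (3.5).imag        # -> 0.0   real_imag_impl returns a zero of the same type
#     (3.5).conjugate() # -> 3.5   real_conjugate_impl passes the value through
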
+@lower_builtin(operator.pow, types.Complex, types.Complex) +@lower_builtin(operator.ipow, types.Complex, types.Complex) +@lower_builtin(pow, types.Complex, types.Complex) +def complex_power_impl(context, builder, sig, args): + [ca, cb] = args + ty = sig.args[0] + fty = ty.underlying_float + a = context.make_helper(builder, ty, value=ca) + b = context.make_helper(builder, ty, value=cb) + c = context.make_helper(builder, ty) + module = builder.module + pa = a._getpointer() + pb = b._getpointer() + pc = c._getpointer() + + # Optimize for square because cpow loses a lot of precision + TWO = context.get_constant(fty, 2) + ZERO = context.get_constant(fty, 0) + + b_real_is_two = builder.fcmp_ordered("==", b.real, TWO) + b_imag_is_zero = builder.fcmp_ordered("==", b.imag, ZERO) + b_is_two = builder.and_(b_real_is_two, b_imag_is_zero) + + with builder.if_else(b_is_two) as (then, otherwise): + with then: + # Lower as multiplication + res = complex_mul_impl(context, builder, sig, (ca, ca)) + cres = context.make_helper(builder, ty, value=res) + c.real = cres.real + c.imag = cres.imag + + with otherwise: + # Lower with call to external function + func_name = { + types.complex64: "numba_cpowf", + types.complex128: "numba_cpow", + }[ty] + fnty = ir.FunctionType(ir.VoidType(), [pa.type] * 3) + cpow = cgutils.get_or_insert_function(module, fnty, func_name) + builder.call(cpow, (pa, pb, pc)) + + res = builder.load(pc) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_add_impl(context, builder, sig, args): + [cx, cy] = args + ty = sig.args[0] + x = context.make_complex(builder, ty, value=cx) + y = context.make_complex(builder, ty, value=cy) + z = context.make_complex(builder, ty) + a = x.real + b = x.imag + c = y.real + d = y.imag + z.real = builder.fadd(a, c) + z.imag = builder.fadd(b, d) + res = z._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_sub_impl(context, builder, sig, args): + [cx, cy] = args + ty = sig.args[0] + x = context.make_complex(builder, ty, value=cx) + y = context.make_complex(builder, ty, value=cy) + z = context.make_complex(builder, ty) + a = x.real + b = x.imag + c = y.real + d = y.imag + z.real = builder.fsub(a, c) + z.imag = builder.fsub(b, d) + res = z._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_mul_impl(context, builder, sig, args): + """ + (a+bi)(c+di)=(ac-bd)+i(ad+bc) + """ + [cx, cy] = args + ty = sig.args[0] + x = context.make_complex(builder, ty, value=cx) + y = context.make_complex(builder, ty, value=cy) + z = context.make_complex(builder, ty) + a = x.real + b = x.imag + c = y.real + d = y.imag + ac = builder.fmul(a, c) + bd = builder.fmul(b, d) + ad = builder.fmul(a, d) + bc = builder.fmul(b, c) + z.real = builder.fsub(ac, bd) + z.imag = builder.fadd(ad, bc) + res = z._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +NAN = float("nan") + + +def complex_div_impl(context, builder, sig, args): + def complex_div(a, b): + # This is CPython's algorithm (in _Py_c_quot()). 
+ areal = a.real + aimag = a.imag + breal = b.real + bimag = b.imag + if not breal and not bimag: + raise ZeroDivisionError("complex division by zero") + if abs(breal) >= abs(bimag): + # Divide tops and bottom by b.real + if not breal: + return complex(NAN, NAN) + ratio = bimag / breal + denom = breal + bimag * ratio + return complex( + (areal + aimag * ratio) / denom, (aimag - areal * ratio) / denom + ) + else: + # Divide tops and bottom by b.imag + if not bimag: + return complex(NAN, NAN) + ratio = breal / bimag + denom = breal * ratio + bimag + return complex( + (a.real * ratio + a.imag) / denom, + (a.imag * ratio - a.real) / denom, + ) + + res = context.compile_internal(builder, complex_div, sig, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_negate_impl(context, builder, sig, args): + from numba.cpython import mathimpl + + [typ] = sig.args + [val] = args + cmplx = context.make_complex(builder, typ, value=val) + res = context.make_complex(builder, typ) + res.real = mathimpl.negate_real(builder, cmplx.real) + res.imag = mathimpl.negate_real(builder, cmplx.imag) + res = res._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_positive_impl(context, builder, sig, args): + [val] = args + return impl_ret_untracked(context, builder, sig.return_type, val) + + +def complex_eq_impl(context, builder, sig, args): + [cx, cy] = args + typ = sig.args[0] + x = context.make_complex(builder, typ, value=cx) + y = context.make_complex(builder, typ, value=cy) + + reals_are_eq = builder.fcmp_ordered("==", x.real, y.real) + imags_are_eq = builder.fcmp_ordered("==", x.imag, y.imag) + res = builder.and_(reals_are_eq, imags_are_eq) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_ne_impl(context, builder, sig, args): + [cx, cy] = args + typ = sig.args[0] + x = context.make_complex(builder, typ, value=cx) + y = context.make_complex(builder, typ, value=cy) + + reals_are_ne = builder.fcmp_unordered("!=", x.real, y.real) + imags_are_ne = builder.fcmp_unordered("!=", x.imag, y.imag) + res = builder.or_(reals_are_ne, imags_are_ne) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +def complex_abs_impl(context, builder, sig, args): + """ + abs(z) := hypot(z.real, z.imag) + """ + + def complex_abs(z): + return math.hypot(z.real, z.imag) + + res = context.compile_internal(builder, complex_abs, sig, args) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +ty = types.Complex + +lower_builtin(operator.add, ty, ty)(complex_add_impl) +lower_builtin(operator.iadd, ty, ty)(complex_add_impl) +lower_builtin(operator.sub, ty, ty)(complex_sub_impl) +lower_builtin(operator.isub, ty, ty)(complex_sub_impl) +lower_builtin(operator.mul, ty, ty)(complex_mul_impl) +lower_builtin(operator.imul, ty, ty)(complex_mul_impl) +lower_builtin(operator.truediv, ty, ty)(complex_div_impl) +lower_builtin(operator.itruediv, ty, ty)(complex_div_impl) +lower_builtin(operator.neg, ty)(complex_negate_impl) +lower_builtin(operator.pos, ty)(complex_positive_impl) +# Complex modulo is deprecated in python3 + +lower_builtin(operator.eq, ty, ty)(complex_eq_impl) +lower_builtin(operator.ne, ty, ty)(complex_ne_impl) + +lower_builtin(abs, ty)(complex_abs_impl) + +del ty + + +@lower_builtin("number.item", types.Boolean) +@lower_builtin("number.item", types.Number) +def number_item_impl(context, builder, sig, args): + """ + The no-op .item() method on booleans and numbers. 
+ """ + return args[0] + + +# ------------------------------------------------------------------------------ + + +def number_not_impl(context, builder, sig, args): + [typ] = sig.args + [val] = args + istrue = context.cast(builder, val, typ, sig.return_type) + res = builder.not_(istrue) + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower_builtin(bool, types.Boolean) +def bool_as_bool(context, builder, sig, args): + [val] = args + return val + + +@lower_builtin(bool, types.Integer) +def int_as_bool(context, builder, sig, args): + [val] = args + return builder.icmp_unsigned("!=", val, Constant(val.type, 0)) + + +@lower_builtin(bool, types.Float) +def float_as_bool(context, builder, sig, args): + [val] = args + return builder.fcmp_unordered("!=", val, Constant(val.type, 0.0)) + + +@lower_builtin(bool, types.Complex) +def complex_as_bool(context, builder, sig, args): + [typ] = sig.args + [val] = args + cmplx = context.make_complex(builder, typ, val) + real, imag = cmplx.real, cmplx.imag + zero = Constant(real.type, 0.0) + real_istrue = builder.fcmp_unordered("!=", real, zero) + imag_istrue = builder.fcmp_unordered("!=", imag, zero) + return builder.or_(real_istrue, imag_istrue) + + +for ty in (types.Integer, types.Float, types.Complex): + lower_builtin(operator.not_, ty)(number_not_impl) + +lower_builtin(operator.not_, types.boolean)(number_not_impl) + + +# ------------------------------------------------------------------------------ +# Hashing numbers, see hashing.py + +# ------------------------------------------------------------------------------- +# Implicit casts between numerics + + +@lower_cast(types.IntegerLiteral, types.Integer) +@lower_cast(types.IntegerLiteral, types.Float) +@lower_cast(types.IntegerLiteral, types.Complex) +def literal_int_to_number(context, builder, fromty, toty, val): + lit = context.get_constant_generic( + builder, + fromty.literal_type, + fromty.literal_value, + ) + return context.cast(builder, lit, fromty.literal_type, toty) + + +@lower_cast(types.Integer, types.Integer) +def integer_to_integer(context, builder, fromty, toty, val): + if toty.bitwidth == fromty.bitwidth: + # Just a change of signedness + return val + elif toty.bitwidth < fromty.bitwidth: + # Downcast + return builder.trunc(val, context.get_value_type(toty)) + elif fromty.signed: + # Signed upcast + return builder.sext(val, context.get_value_type(toty)) + else: + # Unsigned upcast + return builder.zext(val, context.get_value_type(toty)) + + +@lower_cast(types.Integer, types.voidptr) +def integer_to_voidptr(context, builder, fromty, toty, val): + return builder.inttoptr(val, context.get_value_type(toty)) + + +@lower_cast(types.Float, types.Float) +def float_to_float(context, builder, fromty, toty, val): + lty = context.get_value_type(toty) + if fromty.bitwidth < toty.bitwidth: + return builder.fpext(val, lty) + else: + return builder.fptrunc(val, lty) + + +@lower_cast(types.Integer, types.Float) +def integer_to_float(context, builder, fromty, toty, val): + lty = context.get_value_type(toty) + if fromty.signed: + return builder.sitofp(val, lty) + else: + return builder.uitofp(val, lty) + + +@lower_cast(types.Float, types.Integer) +def float_to_integer(context, builder, fromty, toty, val): + lty = context.get_value_type(toty) + if toty.signed: + return builder.fptosi(val, lty) + else: + return builder.fptoui(val, lty) + + +@lower_cast(types.Float, types.Complex) +@lower_cast(types.Integer, types.Complex) +def non_complex_to_complex(context, builder, fromty, toty, val): + 
real = context.cast(builder, val, fromty, toty.underlying_float) + imag = context.get_constant(toty.underlying_float, 0) + + cmplx = context.make_complex(builder, toty) + cmplx.real = real + cmplx.imag = imag + return cmplx._getvalue() + + +@lower_cast(types.Complex, types.Complex) +def complex_to_complex(context, builder, fromty, toty, val): + srcty = fromty.underlying_float + dstty = toty.underlying_float + + src = context.make_complex(builder, fromty, value=val) + dst = context.make_complex(builder, toty) + dst.real = context.cast(builder, src.real, srcty, dstty) + dst.imag = context.cast(builder, src.imag, srcty, dstty) + return dst._getvalue() + + +@lower_cast(types.Any, types.Boolean) +def any_to_boolean(context, builder, fromty, toty, val): + return context.is_true(builder, fromty, val) + + +@lower_cast(types.Boolean, types.Number) +def boolean_to_any(context, builder, fromty, toty, val): + # Casting from boolean to anything first casts to int32 + asint = builder.zext(val, ir.IntType(32)) + return context.cast(builder, asint, types.int32, toty) + + +@lower_cast(types.IntegerLiteral, types.Boolean) +@lower_cast(types.BooleanLiteral, types.Boolean) +def literal_int_to_boolean(context, builder, fromty, toty, val): + lit = context.get_constant_generic( + builder, + fromty.literal_type, + fromty.literal_value, + ) + return context.is_true(builder, fromty.literal_type, lit) + + +# ------------------------------------------------------------------------------- +# Constants + + +@lower_constant(types.Complex) +def constant_complex(context, builder, ty, pyval): + fty = ty.underlying_float + real = context.get_constant_generic(builder, fty, pyval.real) + imag = context.get_constant_generic(builder, fty, pyval.imag) + return Constant.literal_struct((real, imag)) + + +@lower_constant(types.Integer) +@lower_constant(types.Float) +@lower_constant(types.Boolean) +def constant_integer(context, builder, ty, pyval): + # See https://github.com/numba/numba/issues/6979 + # llvmlite ir.IntType specialises the formatting of the constant for a + # cpython bool. A NumPy np.bool_ is not a cpython bool so force it to be one + # so that the constant renders correctly! + if isinstance(pyval, np.bool_): + pyval = bool(pyval) + lty = context.get_value_type(ty) + return lty(pyval) + + +# ------------------------------------------------------------------------------- +# View + + +def scalar_view(scalar, viewty): + """Typing for the np scalar 'view' method.""" + if isinstance(scalar, (types.Float, types.Integer)) and isinstance( + viewty, types.abstract.DTypeSpec + ): + if scalar.bitwidth != viewty.dtype.bitwidth: + raise errors.TypingError( + "Changing the dtype of a 0d array is only supported if the " + "itemsize is unchanged" + ) + + def impl(scalar, viewty): + return viewer(scalar, viewty) + + return impl + + +overload_method(types.Float, "view")(scalar_view) +overload_method(types.Integer, "view")(scalar_view) diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index 1ee2c5be6..461f182c2 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -148,12 +148,12 @@ def init(self): self._target_data = None def load_additional_registries(self): - # side effect of import needed for numba.cpython.*, the builtins + # side effect of import needed for numba.cpython.*, numba.cuda.cpython.*, the builtins # registry is updated at import time. 
from numba.cpython import numbers, tupleobj, slicing # noqa: F401 from numba.cpython import rangeobj, iterators, enumimpl # noqa: F401 from numba.cpython import unicode, charseq # noqa: F401 - from numba.cpython import cmathimpl + from numba.cuda.cpython import cmathimpl from numba.misc import cffiimpl from numba.np import arrayobj # noqa: F401 from numba.np import npdatetime # noqa: F401 diff --git a/numba_cuda/numba/cuda/tests/nocuda/test_import.py b/numba_cuda/numba/cuda/tests/nocuda/test_import.py index be1f55a5d..d6724c3cc 100644 --- a/numba_cuda/numba/cuda/tests/nocuda/test_import.py +++ b/numba_cuda/numba/cuda/tests/nocuda/test_import.py @@ -26,6 +26,9 @@ def test_no_impl_import(self): "numba.cpython.mathimpl", "numba.cpython.printimpl", "numba.cpython.randomimpl", + "numba.cuda.cpython.numbers", + "numba.cuda.cpython.cmathimpl", + "numba.cuda.cpython.mathimpl", "numba.core.optional", "numba.misc.gdb_hook", "numba.misc.literal", diff --git a/numba_cuda/numba/cuda/ufuncs.py b/numba_cuda/numba/cuda/ufuncs.py index 6b28a4a3c..d70edf6e2 100644 --- a/numba_cuda/numba/cuda/ufuncs.py +++ b/numba_cuda/numba/cuda/ufuncs.py @@ -26,7 +26,7 @@ def get_ufunc_info(ufunc_key): @lru_cache def ufunc_db(): # Imports here are at function scope to avoid circular imports - from numba.cpython import cmathimpl, mathimpl, numbers + from numba.cuda.cpython import cmathimpl, mathimpl, numbers from numba.np import npyfuncs from numba.cuda.np.numpy_support import numpy_version From 84cbc9d12f62b35fcfe5c9f3fa991fe1cfd420e3 Mon Sep 17 00:00:00 2001 From: Atmn Patel Date: Tue, 5 Aug 2025 18:45:12 -0700 Subject: [PATCH 02/10] remove newly duplicated registries + leave comments for future reintroduction We're using the `cudaimpl` registry for all of these for now. --- numba_cuda/numba/cuda/cpython/cmathimpl.py | 18 ++++++++++++++---- numba_cuda/numba/cuda/cpython/mathimpl.py | 22 +++++++++++++++------- numba_cuda/numba/cuda/target.py | 2 +- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/numba_cuda/numba/cuda/cpython/cmathimpl.py b/numba_cuda/numba/cuda/cpython/cmathimpl.py index 7af62b08e..c1541c0dd 100644 --- a/numba_cuda/numba/cuda/cpython/cmathimpl.py +++ b/numba_cuda/numba/cuda/cpython/cmathimpl.py @@ -5,14 +5,24 @@ import cmath import math -from numba.core.imputils import Registry, impl_ret_untracked +from numba.core.imputils import impl_ret_untracked from numba.core import types from numba.core.typing import signature from numba.cuda.cpython import mathimpl from numba.core.extending import overload - -registry = Registry("cmathimpl") -lower = registry.lower +from numba.cuda.cudaimpl import lower + +# --------------------------------------------------------------------- +# XXX: In upstream Numba, this file would create a cmathimpl registry +# if it was installed in the target (as it is for the CUDA target). +# The cmathimpl registry has been removed from this file (it was +# initialized as `registry = Registry('cmathimpl')`) as it would duplicate +# the cmathimpl registry in upstream Numba, which would be likely to lead +# to confusion / mixing things up between two cmathimpl registries. The +# comment that accompanied this behaviour is left here, even though the +# code that would pick the cmathimpl registry has been removed, for the +# benefit of future understanding. 
+# def is_nan(builder, z): diff --git a/numba_cuda/numba/cuda/cpython/mathimpl.py b/numba_cuda/numba/cuda/cpython/mathimpl.py index b996f816d..35e9fc6ca 100644 --- a/numba_cuda/numba/cuda/cpython/mathimpl.py +++ b/numba_cuda/numba/cuda/cpython/mathimpl.py @@ -9,18 +9,26 @@ import llvmlite.ir from llvmlite.ir import Constant - -from numba.core.imputils import Registry, impl_ret_untracked +from numba.core.imputils import impl_ret_untracked from numba.core import types, config from numba.core.extending import overload from numba.core.typing import signature from numba.cpython.unsafe.numbers import trailing_zeros from numba.cuda import cgutils - - -registry = Registry("mathimpl") -lower = registry.lower - +from numba.cuda.cudaimpl import lower + + +# --------------------------------------------------------------------- +# XXX: In upstream Numba, this file would create a mathimpl registry +# if it was installed in the target (as it is for the CUDA target). +# The mathimpl registry has been removed from this file (it was +# initialized as `registry = Registry('mathimpl')`) as it would duplicate +# the mathimpl registry in upstream Numba, which would be likely to lead +# to confusion / mixing things up between two mathimpl registries. The +# comment that accompanied this behaviour is left here, even though the +# code that would pick the mathimpl registry has been removed, for the +# benefit of future understanding. +# # Helpers, shared with cmathimpl. _NP_FLT_FINFO = np.finfo(np.dtype("float32")) diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index 461f182c2..3e22b6a9d 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -153,7 +153,7 @@ def load_additional_registries(self): from numba.cpython import numbers, tupleobj, slicing # noqa: F401 from numba.cpython import rangeobj, iterators, enumimpl # noqa: F401 from numba.cpython import unicode, charseq # noqa: F401 - from numba.cuda.cpython import cmathimpl + from numba.cpython import cmathimpl from numba.misc import cffiimpl from numba.np import arrayobj # noqa: F401 from numba.np import npdatetime # noqa: F401 From b72fc43e08c6941926de67fd5a034d237d4e0b80 Mon Sep 17 00:00:00 2001 From: Atmn Patel Date: Thu, 21 Aug 2025 08:32:02 -0700 Subject: [PATCH 03/10] Adds SPDX identifiers to newly added files --- numba_cuda/numba/cuda/cpython/cmathimpl.py | 3 +++ numba_cuda/numba/cuda/cpython/mathimpl.py | 3 +++ numba_cuda/numba/cuda/cpython/numbers.py | 3 +++ 3 files changed, 9 insertions(+) diff --git a/numba_cuda/numba/cuda/cpython/cmathimpl.py b/numba_cuda/numba/cuda/cpython/cmathimpl.py index c1541c0dd..e53f12296 100644 --- a/numba_cuda/numba/cuda/cpython/cmathimpl.py +++ b/numba_cuda/numba/cuda/cpython/cmathimpl.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + """ Implement the cmath module functions. """ diff --git a/numba_cuda/numba/cuda/cpython/mathimpl.py b/numba_cuda/numba/cuda/cpython/mathimpl.py index 35e9fc6ca..9de977081 100644 --- a/numba_cuda/numba/cuda/cpython/mathimpl.py +++ b/numba_cuda/numba/cuda/cpython/mathimpl.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + """ Provide math calls that uses intrinsics or libc math functions. 
""" diff --git a/numba_cuda/numba/cuda/cpython/numbers.py b/numba_cuda/numba/cuda/cpython/numbers.py index 4d4621e48..16b8f13c8 100644 --- a/numba_cuda/numba/cuda/cpython/numbers.py +++ b/numba_cuda/numba/cuda/cpython/numbers.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + import math import numbers From c0a2a92d0f8314e883bc168d4d2325b3e1cb3a5f Mon Sep 17 00:00:00 2001 From: Atmn Patel Date: Thu, 21 Aug 2025 08:32:20 -0700 Subject: [PATCH 04/10] Updates cmathimpl and mathimpl overloads to be CUDA-target specific --- numba_cuda/numba/cuda/cpython/cmathimpl.py | 16 ++++++++-------- numba_cuda/numba/cuda/cpython/mathimpl.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/numba_cuda/numba/cuda/cpython/cmathimpl.py b/numba_cuda/numba/cuda/cpython/cmathimpl.py index e53f12296..3adb36adf 100644 --- a/numba_cuda/numba/cuda/cpython/cmathimpl.py +++ b/numba_cuda/numba/cuda/cpython/cmathimpl.py @@ -71,7 +71,7 @@ def isfinite_float_impl(context, builder, sig, args): return impl_ret_untracked(context, builder, sig.return_type, res) -@overload(cmath.rect) +@overload(cmath.rect, target="cuda") def impl_cmath_rect(r, phi): if all([isinstance(typ, types.Float) for typ in [r, phi]]): @@ -190,7 +190,7 @@ def log_base(z, base): return impl_ret_untracked(context, builder, sig, res) -@overload(cmath.log10) +@overload(cmath.log10, target="cuda") def impl_cmath_log10(z): if not isinstance(z, types.Complex): return @@ -207,7 +207,7 @@ def log10_impl(z): return log10_impl -@overload(cmath.phase) +@overload(cmath.phase, target="cuda") def phase_impl(x): """cmath.phase(x + y j)""" @@ -220,7 +220,7 @@ def impl(x): return impl -@overload(cmath.polar) +@overload(cmath.polar, target="cuda") def polar_impl(x): if not isinstance(x, types.Complex): return @@ -303,7 +303,7 @@ def cos_impl(z): return impl_ret_untracked(context, builder, sig, res) -@overload(cmath.cosh) +@overload(cmath.cosh, target="cuda") def impl_cmath_cosh(z): if not isinstance(z, types.Complex): return @@ -344,7 +344,7 @@ def sin_impl(z): return impl_ret_untracked(context, builder, sig, res) -@overload(cmath.sinh) +@overload(cmath.sinh, target="cuda") def impl_cmath_sinh(z): if not isinstance(z, types.Complex): return @@ -382,7 +382,7 @@ def tan_impl(z): return impl_ret_untracked(context, builder, sig, res) -@overload(cmath.tanh) +@overload(cmath.tanh, target="cuda") def impl_cmath_tanh(z): if not isinstance(z, types.Complex): return @@ -437,7 +437,7 @@ def acos_impl(z): return impl_ret_untracked(context, builder, sig, res) -@overload(cmath.acosh) +@overload(cmath.acosh, target="cuda") def impl_cmath_acosh(z): if not isinstance(z, types.Complex): return diff --git a/numba_cuda/numba/cuda/cpython/mathimpl.py b/numba_cuda/numba/cuda/cpython/mathimpl.py index 9de977081..d8441a733 100644 --- a/numba_cuda/numba/cuda/cpython/mathimpl.py +++ b/numba_cuda/numba/cuda/cpython/mathimpl.py @@ -464,7 +464,7 @@ def _unsigned(T): pass -@overload(_unsigned) +@overload(_unsigned, target="cuda") def _unsigned_impl(T): if T in types.unsigned_domain: return lambda T: T From 921006287bd589817f85d3aa2947836ba3412fd4 Mon Sep 17 00:00:00 2001 From: Atmn Patel Date: Tue, 26 Aug 2025 14:17:18 -0700 Subject: [PATCH 05/10] adds registries for cmathimpl, mathimpl, and ensures we use the cudaimpl registry for numbers --- numba_cuda/numba/cuda/cpython/cmathimpl.py | 20 ++++++-------------- numba_cuda/numba/cuda/cpython/mathimpl.py | 21 
++++++--------------- numba_cuda/numba/cuda/cpython/numbers.py | 10 ++++------ numba_cuda/numba/cuda/cudaimpl.py | 11 +++++++++++ numba_cuda/numba/cuda/target.py | 5 +++-- 5 files changed, 30 insertions(+), 37 deletions(-) diff --git a/numba_cuda/numba/cuda/cpython/cmathimpl.py b/numba_cuda/numba/cuda/cpython/cmathimpl.py index 3adb36adf..a4ca8e132 100644 --- a/numba_cuda/numba/cuda/cpython/cmathimpl.py +++ b/numba_cuda/numba/cuda/cpython/cmathimpl.py @@ -8,24 +8,16 @@ import cmath import math -from numba.core.imputils import impl_ret_untracked +from numba.core.imputils import Registry +from numba.cuda.cudaimpl import impl_ret_untracked from numba.core import types from numba.core.typing import signature from numba.cuda.cpython import mathimpl from numba.core.extending import overload -from numba.cuda.cudaimpl import lower - -# --------------------------------------------------------------------- -# XXX: In upstream Numba, this file would create a cmathimpl registry -# if it was installed in the target (as it is for the CUDA target). -# The cmathimpl registry has been removed from this file (it was -# initialized as `registry = Registry('cmathimpl')`) as it would duplicate -# the cmathimpl registry in upstream Numba, which would be likely to lead -# to confusion / mixing things up between two cmathimpl registries. The -# comment that accompanied this behaviour is left here, even though the -# code that would pick the cmathimpl registry has been removed, for the -# benefit of future understanding. -# + + +registry = Registry("cmathimpl") +lower = registry.lower def is_nan(builder, z): diff --git a/numba_cuda/numba/cuda/cpython/mathimpl.py b/numba_cuda/numba/cuda/cpython/mathimpl.py index d8441a733..f5654c5a1 100644 --- a/numba_cuda/numba/cuda/cpython/mathimpl.py +++ b/numba_cuda/numba/cuda/cpython/mathimpl.py @@ -12,26 +12,17 @@ import llvmlite.ir from llvmlite.ir import Constant -from numba.core.imputils import impl_ret_untracked + +from numba.core.imputils import Registry +from numba.cuda.cudaimpl import impl_ret_untracked from numba.core import types, config from numba.core.extending import overload from numba.core.typing import signature from numba.cpython.unsafe.numbers import trailing_zeros from numba.cuda import cgutils -from numba.cuda.cudaimpl import lower - - -# --------------------------------------------------------------------- -# XXX: In upstream Numba, this file would create a mathimpl registry -# if it was installed in the target (as it is for the CUDA target). -# The mathimpl registry has been removed from this file (it was -# initialized as `registry = Registry('mathimpl')`) as it would duplicate -# the mathimpl registry in upstream Numba, which would be likely to lead -# to confusion / mixing things up between two mathimpl registries. The -# comment that accompanied this behaviour is left here, even though the -# code that would pick the mathimpl registry has been removed, for the -# benefit of future understanding. -# + +registry = Registry("mathimpl") +lower = registry.lower # Helpers, shared with cmathimpl. 
_NP_FLT_FINFO = np.finfo(np.dtype("float32")) diff --git a/numba_cuda/numba/cuda/cpython/numbers.py b/numba_cuda/numba/cuda/cpython/numbers.py index 16b8f13c8..b03df3b0e 100644 --- a/numba_cuda/numba/cuda/cpython/numbers.py +++ b/numba_cuda/numba/cuda/cpython/numbers.py @@ -10,12 +10,10 @@ from llvmlite import ir from llvmlite.ir import Constant -from numba.core.imputils import ( - lower_builtin, - lower_getattr, - lower_cast, - lower_constant, - impl_ret_untracked, +from numba.cuda.cudaimpl import lower_cast, lower_constant, impl_ret_untracked +from numba.cuda.cudaimpl import ( + lower as lower_builtin, + lower_attr as lower_getattr, ) from numba.core import typing, types, errors from numba.core.extending import overload_method diff --git a/numba_cuda/numba/cuda/cudaimpl.py b/numba_cuda/numba/cuda/cudaimpl.py index 23c5d9045..e15ddda91 100644 --- a/numba_cuda/numba/cuda/cudaimpl.py +++ b/numba_cuda/numba/cuda/cudaimpl.py @@ -25,6 +25,10 @@ lower = registry.lower lower_attr = registry.lower_getattr lower_constant = registry.lower_constant +lower_getattr_generic = registry.lower_getattr_generic +lower_setattr = registry.lower_setattr +lower_setattr_generic = registry.lower_setattr_generic +lower_cast = registry.lower_cast def initialize_dim3(builder, prefix): @@ -1003,3 +1007,10 @@ def cuda_dispatcher_const(context, builder, ty, pyval): # NumPy register_ufuncs(ufunc_db.get_ufuncs(), lower) + + +def impl_ret_untracked(ctx, builder, retty, ret): + """ + The return type is not a NRT object. + """ + return ret diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index 3e22b6a9d..7f7ece4d1 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -150,10 +150,11 @@ def init(self): def load_additional_registries(self): # side effect of import needed for numba.cpython.*, numba.cuda.cpython.*, the builtins # registry is updated at import time. 
- from numba.cpython import numbers, tupleobj, slicing # noqa: F401 + from numba.cpython import tupleobj, slicing # noqa: F401 + from numba.cuda.cpython import numbers # noqa: F401 from numba.cpython import rangeobj, iterators, enumimpl # noqa: F401 from numba.cpython import unicode, charseq # noqa: F401 - from numba.cpython import cmathimpl + from numba.cuda.cpython import cmathimpl from numba.misc import cffiimpl from numba.np import arrayobj # noqa: F401 from numba.np import npdatetime # noqa: F401 From d36acf4520999c9cd5088170bff6a5889271132d Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Wed, 27 Aug 2025 15:35:15 +0100 Subject: [PATCH 06/10] Partial fixes, committing to switch branch --- numba_cuda/numba/cuda/core/typed_passes.py | 4 ++-- numba_cuda/numba/cuda/cpython/cmathimpl.py | 3 +-- numba_cuda/numba/cuda/cpython/mathimpl.py | 3 +-- numba_cuda/numba/cuda/cpython/numbers.py | 12 +++++++----- numba_cuda/numba/cuda/target.py | 6 ++++-- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/numba_cuda/numba/cuda/core/typed_passes.py b/numba_cuda/numba/cuda/core/typed_passes.py index b56497f5a..64b446646 100644 --- a/numba_cuda/numba/cuda/core/typed_passes.py +++ b/numba_cuda/numba/cuda/core/typed_passes.py @@ -325,7 +325,7 @@ def run_pass(self, state): calltypes = state.calltypes flags = state.flags metadata = state.metadata - pre_stats = llvm.passmanagers.dump_refprune_stats() + pre_stats = llvm.newpassmanagers.dump_refprune_stats() msg = "Function %s failed at nopython mode lowering" % ( state.func_id.func_name, @@ -388,7 +388,7 @@ def run_pass(self, state): ) # capture pruning stats - post_stats = llvm.passmanagers.dump_refprune_stats() + post_stats = llvm.newpassmanagers.dump_refprune_stats() metadata["prune_stats"] = post_stats - pre_stats # Save the LLVM pass timings diff --git a/numba_cuda/numba/cuda/cpython/cmathimpl.py b/numba_cuda/numba/cuda/cpython/cmathimpl.py index a4ca8e132..9c6338f19 100644 --- a/numba_cuda/numba/cuda/cpython/cmathimpl.py +++ b/numba_cuda/numba/cuda/cpython/cmathimpl.py @@ -8,8 +8,7 @@ import cmath import math -from numba.core.imputils import Registry -from numba.cuda.cudaimpl import impl_ret_untracked +from numba.core.imputils import impl_ret_untracked, Registry from numba.core import types from numba.core.typing import signature from numba.cuda.cpython import mathimpl diff --git a/numba_cuda/numba/cuda/cpython/mathimpl.py b/numba_cuda/numba/cuda/cpython/mathimpl.py index f5654c5a1..f5101e4a3 100644 --- a/numba_cuda/numba/cuda/cpython/mathimpl.py +++ b/numba_cuda/numba/cuda/cpython/mathimpl.py @@ -13,8 +13,7 @@ import llvmlite.ir from llvmlite.ir import Constant -from numba.core.imputils import Registry -from numba.cuda.cudaimpl import impl_ret_untracked +from numba.core.imputils import impl_ret_untracked, Registry from numba.core import types, config from numba.core.extending import overload from numba.core.typing import signature diff --git a/numba_cuda/numba/cuda/cpython/numbers.py b/numba_cuda/numba/cuda/cpython/numbers.py index b03df3b0e..2f1d41a07 100644 --- a/numba_cuda/numba/cuda/cpython/numbers.py +++ b/numba_cuda/numba/cuda/cpython/numbers.py @@ -10,16 +10,18 @@ from llvmlite import ir from llvmlite.ir import Constant -from numba.cuda.cudaimpl import lower_cast, lower_constant, impl_ret_untracked -from numba.cuda.cudaimpl import ( - lower as lower_builtin, - lower_attr as lower_getattr, -) +from numba.core.imputils import impl_ret_untracked, Registry from numba.core import typing, types, errors from numba.core.extending import 
overload_method from numba.cpython.unsafe.numbers import viewer from numba.cuda import cgutils +registry = Registry("numbers") +lower_builtin = registry.lower +lower_cast = registry.lower_cast +lower_constant = registry.lower_constant +lower_getattr = registry.lower_attr + def _int_arith_flags(rettype): """ diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index 7f7ece4d1..0417e5334 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -154,7 +154,7 @@ def load_additional_registries(self): from numba.cuda.cpython import numbers # noqa: F401 from numba.cpython import rangeobj, iterators, enumimpl # noqa: F401 from numba.cpython import unicode, charseq # noqa: F401 - from numba.cuda.cpython import cmathimpl + from numba.cuda.cpython import cmathimpl, mathimpl from numba.misc import cffiimpl from numba.np import arrayobj # noqa: F401 from numba.np import npdatetime # noqa: F401 @@ -163,7 +163,7 @@ def load_additional_registries(self): fp16, printimpl, libdeviceimpl, - mathimpl, + mathimpl as cuda_mathimpl, vector_types, bf16, ) @@ -177,6 +177,8 @@ def load_additional_registries(self): self.install_registry(libdeviceimpl.registry) self.install_registry(cmathimpl.registry) self.install_registry(mathimpl.registry) + self.install_registry(numbers.registry) + self.install_registry(cuda_mathimpl.registry) self.install_registry(vector_types.impl_registry) self.install_registry(fp16.target_registry) self.install_registry(bf16.target_registry) From 7a826affeaae03da16ec12fe37b725a2113f557a Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Wed, 27 Aug 2025 17:11:45 +0100 Subject: [PATCH 07/10] Undo temp fix --- numba_cuda/numba/cuda/core/typed_passes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numba_cuda/numba/cuda/core/typed_passes.py b/numba_cuda/numba/cuda/core/typed_passes.py index 64b446646..b56497f5a 100644 --- a/numba_cuda/numba/cuda/core/typed_passes.py +++ b/numba_cuda/numba/cuda/core/typed_passes.py @@ -325,7 +325,7 @@ def run_pass(self, state): calltypes = state.calltypes flags = state.flags metadata = state.metadata - pre_stats = llvm.newpassmanagers.dump_refprune_stats() + pre_stats = llvm.passmanagers.dump_refprune_stats() msg = "Function %s failed at nopython mode lowering" % ( state.func_id.func_name, @@ -388,7 +388,7 @@ def run_pass(self, state): ) # capture pruning stats - post_stats = llvm.newpassmanagers.dump_refprune_stats() + post_stats = llvm.passmanagers.dump_refprune_stats() metadata["prune_stats"] = post_stats - pre_stats # Save the LLVM pass timings From 7a877680f65aa8c514f88f6b7ee3e5ef62ad2b91 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Wed, 27 Aug 2025 17:11:57 +0100 Subject: [PATCH 08/10] Fix erroneous attribute name --- numba_cuda/numba/cuda/cpython/numbers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numba_cuda/numba/cuda/cpython/numbers.py b/numba_cuda/numba/cuda/cpython/numbers.py index 2f1d41a07..367d79dda 100644 --- a/numba_cuda/numba/cuda/cpython/numbers.py +++ b/numba_cuda/numba/cuda/cpython/numbers.py @@ -20,7 +20,7 @@ lower_builtin = registry.lower lower_cast = registry.lower_cast lower_constant = registry.lower_constant -lower_getattr = registry.lower_attr +lower_getattr = registry.lower_getattr def _int_arith_flags(rettype): From 82db58d74896d7c2b54ac8bead92dd122b8f31e2 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Thu, 28 Aug 2025 12:46:03 +0100 Subject: [PATCH 09/10] Add import of optional in `load_additional_registries()` This is 
imported as a side effect in the upstream `numbers.py`, causing registrations to the builtin registry for implementations involving the `Optional` type. The side-effecting import could be added back to numbers.py, but I think the right place to be keeping imports that have registrations as side effects is in `load_additional_registries()` so that they're all in one place and we can work through handling each case. In each vendored module, we should create a new registry and explicitly add it to the CUDA target context, rather than jamming many items arbitrarily into the "builtin registry". --- numba_cuda/numba/cuda/target.py | 1 + 1 file changed, 1 insertion(+) diff --git a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index 0417e5334..38077f12f 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -155,6 +155,7 @@ def load_additional_registries(self): from numba.cpython import rangeobj, iterators, enumimpl # noqa: F401 from numba.cpython import unicode, charseq # noqa: F401 from numba.cuda.cpython import cmathimpl, mathimpl + from numba.core import optional # noqa: F401 from numba.misc import cffiimpl from numba.np import arrayobj # noqa: F401 from numba.np import npdatetime # noqa: F401 From 505ac6d8b513592768b14cfac44f90e40253b188 Mon Sep 17 00:00:00 2001 From: Graham Markall Date: Thu, 28 Aug 2025 15:12:51 +0100 Subject: [PATCH 10/10] Remove impl_ret_untracked from cudaimpl --- numba_cuda/numba/cuda/cudaimpl.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/numba_cuda/numba/cuda/cudaimpl.py b/numba_cuda/numba/cuda/cudaimpl.py index e15ddda91..97d8f1bdb 100644 --- a/numba_cuda/numba/cuda/cudaimpl.py +++ b/numba_cuda/numba/cuda/cudaimpl.py @@ -1007,10 +1007,3 @@ def cuda_dispatcher_const(context, builder, ty, pyval): # NumPy register_ufuncs(ufunc_db.get_ufuncs(), lower) - - -def impl_ret_untracked(ctx, builder, retty, ret): - """ - The return type is not a NRT object. - """ - return ret
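
The registry pattern recommended in the PATCH 09 commit message, and already
followed by patches 05 and 06 for the vendored `numbers`, `cmathimpl`, and
`mathimpl` modules, looks roughly like the sketch below. The module name
`example` and its registry name are illustrative only, not part of this patch
series:

    # numba_cuda/numba/cuda/cpython/example.py
    from numba.core.imputils import Registry

    registry = Registry("example")
    lower = registry.lower  # decorate lowering functions with this

    # numba_cuda/numba/cuda/target.py, in load_additional_registries()
    from numba.cuda.cpython import example  # noqa: F401
    self.install_registry(example.registry)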