diff --git a/numba_cuda/numba/cuda/core/inline_closurecall.py b/numba_cuda/numba/cuda/core/inline_closurecall.py index bc311491e..77734bf1d 100644 --- a/numba_cuda/numba/cuda/core/inline_closurecall.py +++ b/numba_cuda/numba/cuda/core/inline_closurecall.py @@ -34,7 +34,7 @@ compute_live_variables, ) from numba.core.imputils import impl_ret_untracked -from numba.core.extending import intrinsic +from numba.cuda.extending import intrinsic from numba.core.typing import signature from numba.cuda.core import postproc, rewrites @@ -1102,7 +1102,7 @@ def codegen(context, builder, sig, args): def codegen(context, builder, sig, args): (value,) = args intp_t = context.get_value_type(types.intp) - from numba.cpython.listobj import ListIterInstance + from numba.cuda.cpython.listobj import ListIterInstance iterobj = ListIterInstance(context, builder, sig.args[0], value) return impl_ret_untracked(context, builder, intp_t, iterobj.size) diff --git a/numba_cuda/numba/cuda/core/pythonapi.py b/numba_cuda/numba/cuda/core/pythonapi.py index eca67bce6..0c2b39121 100644 --- a/numba_cuda/numba/cuda/core/pythonapi.py +++ b/numba_cuda/numba/cuda/core/pythonapi.py @@ -209,10 +209,7 @@ def __init__(self, context, builder): self.longlong = ir.IntType(ctypes.sizeof(ctypes.c_ulonglong) * 8) self.ulonglong = self.longlong self.double = ir.DoubleType() - if config.USE_LEGACY_TYPE_SYSTEM: - self.py_ssize_t = self.context.get_value_type(types.intp) - else: - self.py_ssize_t = self.context.get_value_type(types.c_intp) + self.py_ssize_t = self.context.get_value_type(types.intp) self.cstring = ir.PointerType(ir.IntType(8)) self.gil_state = ir.IntType(_helperlib.py_gil_state_size * 8) self.py_buffer_t = ir.ArrayType( diff --git a/numba_cuda/numba/cuda/core/unsafe/bytes.py b/numba_cuda/numba/cuda/core/unsafe/bytes.py index aedd21f12..b205b1a16 100644 --- a/numba_cuda/numba/cuda/core/unsafe/bytes.py +++ b/numba_cuda/numba/cuda/core/unsafe/bytes.py @@ -6,7 +6,7 @@ operations with bytes and workarounds for limitations enforced in userland. """ -from numba.core.extending import intrinsic +from numba.cuda.extending import intrinsic from llvmlite import ir from numba.core import types from numba.cuda import cgutils diff --git a/numba_cuda/numba/cuda/core/unsafe/eh.py b/numba_cuda/numba/cuda/core/unsafe/eh.py index 6c60ce4ec..416f818ec 100644 --- a/numba_cuda/numba/cuda/core/unsafe/eh.py +++ b/numba_cuda/numba/cuda/core/unsafe/eh.py @@ -7,7 +7,7 @@ from numba.core import types, errors from numba.cuda import cgutils -from numba.core.extending import intrinsic +from numba.cuda.extending import intrinsic @intrinsic diff --git a/numba_cuda/numba/cuda/core/unsafe/refcount.py b/numba_cuda/numba/cuda/core/unsafe/refcount.py index a176a4fdd..844367394 100644 --- a/numba_cuda/numba/cuda/core/unsafe/refcount.py +++ b/numba_cuda/numba/cuda/core/unsafe/refcount.py @@ -9,7 +9,7 @@ from numba.core import types from numba.cuda import cgutils -from numba.core.extending import intrinsic +from numba.cuda.extending import intrinsic _word_type = ir.IntType(64) _pointer_type = ir.PointerType(ir.IntType(8)) diff --git a/numba_cuda/numba/cuda/cpython/charseq.py b/numba_cuda/numba/cuda/cpython/charseq.py new file mode 100644 index 000000000..daba8ae5b --- /dev/null +++ b/numba_cuda/numba/cuda/cpython/charseq.py @@ -0,0 +1,1214 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +"""Implements operations on bytes and str (unicode) array items.""" + +import operator +import numpy as np +from llvmlite import ir + +from numba.core import types +from numba.cuda import cgutils +from numba.core.extending import overload, overload_method, register_jitable +from numba.core.imputils import Registry +from numba.cuda.cgutils import is_nonelike +from numba.cuda.cpython import unicode +from numba.cuda.extending import intrinsic + +registry = Registry("charseq") +lower = registry.lower +lower_cast = registry.lower_cast + +# bytes and str arrays items are of type CharSeq and UnicodeCharSeq, +# respectively. See numpy/types/npytypes.py for CharSeq, +# UnicodeCharSeq definitions. The corresponding data models are +# defined in numpy/datamodel/models.py. Boxing/unboxing of item types +# are defined in numpy/targets/boxing.py, see box_unicodecharseq, +# unbox_unicodecharseq, box_charseq, unbox_charseq. + +s1_dtype = np.dtype("S1") +assert s1_dtype.itemsize == 1 +bytes_type = types.Bytes(types.uint8, 1, "C", readonly=True) + +# Currently, NumPy supports only UTF-32 arrays but this may change in +# future and the approach used here for supporting str arrays may need +# a revision depending on how NumPy will support UTF-8 and UTF-16 +# arrays. +u1_dtype = np.dtype("U1") +unicode_byte_width = u1_dtype.itemsize +unicode_uint = {1: np.uint8, 2: np.uint16, 4: np.uint32}[unicode_byte_width] +unicode_kind = { + 1: unicode.PY_UNICODE_1BYTE_KIND, + 2: unicode.PY_UNICODE_2BYTE_KIND, + 4: unicode.PY_UNICODE_4BYTE_KIND, +}[unicode_byte_width] + + +# this is modified version of numba.unicode.make_deref_codegen +def make_deref_codegen(bitsize): + def codegen(context, builder, signature, args): + data, idx = args + rawptr = cgutils.alloca_once_value(builder, value=data) + ptr = builder.bitcast(rawptr, ir.IntType(bitsize).as_pointer()) + ch = builder.load(builder.gep(ptr, [idx])) + return builder.zext(ch, ir.IntType(32)) + + return codegen + + +@intrinsic +def deref_uint8(typingctx, data, offset): + sig = types.uint32(data, types.intp) + return sig, make_deref_codegen(8) + + +@intrinsic +def deref_uint16(typingctx, data, offset): + sig = types.uint32(data, types.intp) + return sig, make_deref_codegen(16) + + +@intrinsic +def deref_uint32(typingctx, data, offset): + sig = types.uint32(data, types.intp) + return sig, make_deref_codegen(32) + + +@register_jitable(_nrt=False) +def charseq_get_code(a, i): + """Access i-th item of CharSeq object via code value""" + return deref_uint8(a, i) + + +@register_jitable +def charseq_get_value(a, i): + """Access i-th item of CharSeq object via code value. 
+ + null code is interpreted as IndexError + """ + code = charseq_get_code(a, i) + if code == 0: + raise IndexError("index out of range") + return code + + +@register_jitable(_nrt=False) +def unicode_charseq_get_code(a, i): + """Access i-th item of UnicodeCharSeq object via code value""" + if unicode_byte_width == 4: + return deref_uint32(a, i) + elif unicode_byte_width == 2: + return deref_uint16(a, i) + elif unicode_byte_width == 1: + return deref_uint8(a, i) + else: + raise NotImplementedError( + "unicode_charseq_get_code: unicode_byte_width not in [1, 2, 4]" + ) + + +@register_jitable +def unicode_get_code(a, i): + """Access i-th item of UnicodeType object.""" + return unicode._get_code_point(a, i) + + +@register_jitable +def bytes_get_code(a, i): + """Access i-th item of Bytes object.""" + return a[i] + + +def _get_code_impl(a): + if isinstance(a, types.CharSeq): + return charseq_get_code + elif isinstance(a, types.Bytes): + return bytes_get_code + elif isinstance(a, types.UnicodeCharSeq): + return unicode_charseq_get_code + elif isinstance(a, types.UnicodeType): + return unicode_get_code + + +def _same_kind(a, b): + for t in [ + (types.CharSeq, types.Bytes), + (types.UnicodeCharSeq, types.UnicodeType), + ]: + if isinstance(a, t) and isinstance(b, t): + return True + return False + + +def _is_bytes(a): + return isinstance(a, (types.CharSeq, types.Bytes)) + + +def is_default(x, default): + return x == default or isinstance(x, types.Omitted) + + +@register_jitable +def unicode_charseq_get_value(a, i): + """Access i-th item of UnicodeCharSeq object via unicode value + + null code is interpreted as IndexError + """ + code = unicode_charseq_get_code(a, i) + if code == 0: + raise IndexError("index out of range") + # Return numpy equivalent of `chr(code)` + return np.array(code, unicode_uint).view(u1_dtype)[()] + + +# +# CAST +# +# Currently, the following casting operations are supported: +# Bytes -> CharSeq (ex: a=np.array(b'abc'); a[()] = b'123') +# UnicodeType -> UnicodeCharSeq (ex: a=np.array('abc'); a[()] = '123') +# CharSeq -> Bytes (ex: a=np.array(b'abc'); b = bytes(a[()])) +# UnicodeType -> Bytes (ex: str('123')._to_bytes()) +# +# The following casting operations can be implemented when required: +# Bytes -> UnicodeCharSeq (ex: a=np.array('abc'); a[()] = b'123') +# UnicodeType -> CharSeq (ex: a=np.array(b'abc'); a[()] = '123') +# UnicodeType -> Bytes (ex: bytes('123', 'utf8')) +# + + +@lower_cast(types.Bytes, types.CharSeq) +def bytes_to_charseq(context, builder, fromty, toty, val): + barr = cgutils.create_struct_proxy(fromty)(context, builder, value=val) + src = builder.bitcast(barr.data, ir.IntType(8).as_pointer()) + src_length = barr.nitems + + lty = context.get_value_type(toty) + dstint_t = ir.IntType(8) + dst_ptr = cgutils.alloca_once(builder, lty) + dst = builder.bitcast(dst_ptr, dstint_t.as_pointer()) + + dst_length = ir.Constant(src_length.type, toty.count) + is_shorter_value = builder.icmp_unsigned("<", src_length, dst_length) + count = builder.select(is_shorter_value, src_length, dst_length) + with builder.if_then(is_shorter_value): + cgutils.memset( + builder, dst, ir.Constant(src_length.type, toty.count), 0 + ) + with cgutils.for_range(builder, count) as loop: + in_ptr = builder.gep(src, [loop.index]) + in_val = builder.zext(builder.load(in_ptr), dstint_t) + builder.store(in_val, builder.gep(dst, [loop.index])) + + return builder.load(dst_ptr) + + +def _make_constant_bytes(context, builder, nbytes): + bstr_ctor = cgutils.create_struct_proxy(bytes_type) + bstr = 
bstr_ctor(context, builder) + + if isinstance(nbytes, int): + nbytes = ir.Constant(bstr.nitems.type, nbytes) + + bstr.meminfo = context.nrt.meminfo_alloc(builder, nbytes) + bstr.nitems = nbytes + bstr.itemsize = ir.Constant(bstr.itemsize.type, 1) + bstr.data = context.nrt.meminfo_data(builder, bstr.meminfo) + bstr.parent = cgutils.get_null_value(bstr.parent.type) + # bstr.shape and bstr.strides are not used + bstr.shape = cgutils.get_null_value(bstr.shape.type) + bstr.strides = cgutils.get_null_value(bstr.strides.type) + return bstr + + +@lower_cast(types.CharSeq, types.Bytes) +def charseq_to_bytes(context, builder, fromty, toty, val): + bstr = _make_constant_bytes(context, builder, val.type.count) + rawptr = cgutils.alloca_once_value(builder, value=val) + ptr = builder.bitcast(rawptr, bstr.data.type) + cgutils.memcpy(builder, bstr.data, ptr, bstr.nitems) + return bstr + + +@lower_cast(types.UnicodeType, types.Bytes) +def unicode_to_bytes_cast(context, builder, fromty, toty, val): + uni_str = cgutils.create_struct_proxy(fromty)(context, builder, value=val) + src1 = builder.bitcast(uni_str.data, ir.IntType(8).as_pointer()) + notkind1 = builder.icmp_unsigned( + "!=", uni_str.kind, ir.Constant(uni_str.kind.type, 1) + ) + src_length = uni_str.length + + with builder.if_then(notkind1): + context.call_conv.return_user_exc( + builder, + ValueError, + ("cannot cast higher than 8-bit unicode_type to bytes",), + ) + + bstr = _make_constant_bytes(context, builder, src_length) + cgutils.memcpy(builder, bstr.data, src1, bstr.nitems) + return bstr + + +@intrinsic +def _unicode_to_bytes(typingctx, s): + # used in _to_bytes method + assert s == types.unicode_type + sig = bytes_type(s) + + def codegen(context, builder, signature, args): + return unicode_to_bytes_cast( + context, builder, s, bytes_type, args[0] + )._getvalue() + + return sig, codegen + + +@lower_cast(types.UnicodeType, types.UnicodeCharSeq) +def unicode_to_unicode_charseq(context, builder, fromty, toty, val): + uni_str = cgutils.create_struct_proxy(fromty)(context, builder, value=val) + src1 = builder.bitcast(uni_str.data, ir.IntType(8).as_pointer()) + src2 = builder.bitcast(uni_str.data, ir.IntType(16).as_pointer()) + src4 = builder.bitcast(uni_str.data, ir.IntType(32).as_pointer()) + kind1 = builder.icmp_unsigned( + "==", uni_str.kind, ir.Constant(uni_str.kind.type, 1) + ) + kind2 = builder.icmp_unsigned( + "==", uni_str.kind, ir.Constant(uni_str.kind.type, 2) + ) + kind4 = builder.icmp_unsigned( + "==", uni_str.kind, ir.Constant(uni_str.kind.type, 4) + ) + src_length = uni_str.length + + lty = context.get_value_type(toty) + dstint_t = ir.IntType(8 * unicode_byte_width) + dst_ptr = cgutils.alloca_once(builder, lty) + dst = builder.bitcast(dst_ptr, dstint_t.as_pointer()) + + dst_length = ir.Constant(src_length.type, toty.count) + is_shorter_value = builder.icmp_unsigned("<", src_length, dst_length) + count = builder.select(is_shorter_value, src_length, dst_length) + with builder.if_then(is_shorter_value): + cgutils.memset( + builder, + dst, + ir.Constant(src_length.type, toty.count * unicode_byte_width), + 0, + ) + + with builder.if_then(kind1): + with cgutils.for_range(builder, count) as loop: + in_ptr = builder.gep(src1, [loop.index]) + in_val = builder.zext(builder.load(in_ptr), dstint_t) + builder.store(in_val, builder.gep(dst, [loop.index])) + + with builder.if_then(kind2): + if unicode_byte_width >= 2: + with cgutils.for_range(builder, count) as loop: + in_ptr = builder.gep(src2, [loop.index]) + in_val = 
builder.zext(builder.load(in_ptr), dstint_t) + builder.store(in_val, builder.gep(dst, [loop.index])) + else: + context.call_conv.return_user_exc( + builder, + ValueError, + ( + "cannot cast 16-bit unicode_type to %s-bit %s" + % (unicode_byte_width * 8, toty) + ), + ) + + with builder.if_then(kind4): + if unicode_byte_width >= 4: + with cgutils.for_range(builder, count) as loop: + in_ptr = builder.gep(src4, [loop.index]) + in_val = builder.zext(builder.load(in_ptr), dstint_t) + builder.store(in_val, builder.gep(dst, [loop.index])) + else: + context.call_conv.return_user_exc( + builder, + ValueError, + ( + "cannot cast 32-bit unicode_type to %s-bit %s" + % (unicode_byte_width * 8, toty) + ), + ) + + return builder.load(dst_ptr) + + +# +# Operations on bytes/str array items +# +# Implementation note: while some operations need +# CharSeq/UnicodeCharSeq specific implementations (getitem, len, str, +# etc), many operations can be supported by casting +# CharSeq/UnicodeCharSeq objects to Bytes/UnicodeType objects and +# re-use existing operations. +# +# However, in numba more operations are implemented for UnicodeType +# than for Bytes objects, hence the support for operations with bytes +# array items will be less complete than for str arrays. Although, in +# some cases (hash, contains, etc) the UnicodeType implementations can +# be reused for Bytes objects via using `_to_str` method. +# + + +@overload(operator.getitem) +def charseq_getitem(s, i): + get_value = None + if isinstance(i, types.Integer): + if isinstance(s, types.CharSeq): + get_value = charseq_get_value + if isinstance(s, types.UnicodeCharSeq): + get_value = unicode_charseq_get_value + if get_value is not None: + max_i = s.count + msg = "index out of range [0, %s]" % (max_i - 1) + + def getitem_impl(s, i): + if i < max_i and i >= 0: + return get_value(s, i) + raise IndexError(msg) + + return getitem_impl + + +@overload(len) +def charseq_len(s): + if isinstance(s, (types.CharSeq, types.UnicodeCharSeq)): + get_code = _get_code_impl(s) + n = s.count + if n == 0: + + def len_impl(s): + return 0 + + return len_impl + else: + + def len_impl(s): + # return the index of the last non-null value (numpy + # behavior) + i = n + code = 0 + while code == 0: + i = i - 1 + if i < 0: + break + code = get_code(s, i) + return i + 1 + + return len_impl + + +@overload(operator.add) +@overload(operator.iadd) +def charseq_concat(a, b): + if not _same_kind(a, b): + return + if isinstance(a, types.UnicodeCharSeq) and isinstance(b, types.UnicodeType): + + def impl(a, b): + return str(a) + b + + return impl + if isinstance(b, types.UnicodeCharSeq) and isinstance(a, types.UnicodeType): + + def impl(a, b): + return a + str(b) + + return impl + if isinstance(a, types.UnicodeCharSeq) and isinstance( + b, types.UnicodeCharSeq + ): + + def impl(a, b): + return str(a) + str(b) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)) and isinstance( + b, (types.CharSeq, types.Bytes) + ): + + def impl(a, b): + return (a._to_str() + b._to_str())._to_bytes() + + return impl + + +@overload(operator.mul) +def charseq_repeat(a, b): + if isinstance(a, types.UnicodeCharSeq): + + def wrap(a, b): + return str(a) * b + + return wrap + if isinstance(b, types.UnicodeCharSeq): + + def wrap(a, b): + return a * str(b) + + return wrap + if isinstance(a, (types.CharSeq, types.Bytes)): + + def wrap(a, b): + return (a._to_str() * b)._to_bytes() + + return wrap + if isinstance(b, (types.CharSeq, types.Bytes)): + + def wrap(a, b): + return (a * b._to_str())._to_bytes() + + return 
wrap + + +@overload(operator.not_) +def charseq_not(a): + if isinstance(a, (types.UnicodeCharSeq, types.CharSeq, types.Bytes)): + + def impl(a): + return len(a) == 0 + + return impl + + +@overload(operator.eq) +def charseq_eq(a, b): + if not _same_kind(a, b): + return + left_code = _get_code_impl(a) + right_code = _get_code_impl(b) + if left_code is not None and right_code is not None: + + def eq_impl(a, b): + n = len(a) + if n != len(b): + return False + for i in range(n): + if left_code(a, i) != right_code(b, i): + return False + return True + + return eq_impl + + +@overload(operator.ne) +def charseq_ne(a, b): + if not _same_kind(a, b): + return + left_code = _get_code_impl(a) + right_code = _get_code_impl(b) + if left_code is not None and right_code is not None: + + def ne_impl(a, b): + return not (a == b) + + return ne_impl + + +@overload(operator.lt) +def charseq_lt(a, b): + if not _same_kind(a, b): + return + left_code = _get_code_impl(a) + right_code = _get_code_impl(b) + if left_code is not None and right_code is not None: + + def lt_impl(a, b): + na = len(a) + nb = len(b) + n = min(na, nb) + for i in range(n): + ca, cb = left_code(a, i), right_code(b, i) + if ca != cb: + return ca < cb + return na < nb + + return lt_impl + + +@overload(operator.gt) +def charseq_gt(a, b): + if not _same_kind(a, b): + return + left_code = _get_code_impl(a) + right_code = _get_code_impl(b) + if left_code is not None and right_code is not None: + + def gt_impl(a, b): + return b < a + + return gt_impl + + +@overload(operator.le) +def charseq_le(a, b): + if not _same_kind(a, b): + return + left_code = _get_code_impl(a) + right_code = _get_code_impl(b) + if left_code is not None and right_code is not None: + + def le_impl(a, b): + return not (a > b) + + return le_impl + + +@overload(operator.ge) +def charseq_ge(a, b): + if not _same_kind(a, b): + return + left_code = _get_code_impl(a) + right_code = _get_code_impl(b) + if left_code is not None and right_code is not None: + + def ge_impl(a, b): + return not (a < b) + + return ge_impl + + +@overload(operator.contains) +def charseq_contains(a, b): + if not _same_kind(a, b): + return + left_code = _get_code_impl(a) + right_code = _get_code_impl(b) + if left_code is not None and right_code is not None: + if _is_bytes(a): + + def contains_impl(a, b): + # Ideally, `return bytes(b) in bytes(a)` would be used + # here, but numba Bytes does not implement + # contains. So, using `unicode_type` implementation + # here: + return b._to_str() in a._to_str() + else: + + def contains_impl(a, b): + return str(b) in str(a) + + return contains_impl + + +@overload_method(types.UnicodeCharSeq, "isascii") +@overload_method(types.CharSeq, "isascii") +@overload_method(types.Bytes, "isascii") +def charseq_isascii(s): + get_code = _get_code_impl(s) + + def impl(s): + for i in range(len(s)): + if get_code(s, i) > 127: + return False + return True + + return impl + + +@overload_method(types.UnicodeCharSeq, "_get_kind") +@overload_method(types.CharSeq, "_get_kind") +def charseq_get_kind(s): + get_code = _get_code_impl(s) + + def impl(s): + max_code = 0 + for i in range(len(s)): + code = get_code(s, i) + if code > max_code: + max_code = code + if max_code > 0xFFFF: + return unicode.PY_UNICODE_4BYTE_KIND + if max_code > 0xFF: + return unicode.PY_UNICODE_2BYTE_KIND + return unicode.PY_UNICODE_1BYTE_KIND + + return impl + + +@overload_method(types.UnicodeType, "_to_bytes") +def unicode_to_bytes_mth(s): + """Convert unicode_type object to Bytes object. 
+ + Note: The usage of _to_bytes method can be eliminated once all + Python bytes operations are implemented for numba Bytes objects. + + """ + + def impl(s): + return _unicode_to_bytes(s) + + return impl + + +@overload_method(types.CharSeq, "_to_str") +@overload_method(types.Bytes, "_to_str") +def charseq_to_str_mth(s): + """Convert bytes array item or bytes instance to UTF-8 str. + + Note: The usage of _to_str method can be eliminated once all + Python bytes operations are implemented for numba Bytes objects. + """ + get_code = _get_code_impl(s) + + def tostr_impl(s): + n = len(s) + is_ascii = s.isascii() + result = unicode._empty_string( + unicode.PY_UNICODE_1BYTE_KIND, n, is_ascii + ) + for i in range(n): + code = get_code(s, i) + unicode._set_code_point(result, i, code) + return result + + return tostr_impl + + +@overload_method(types.UnicodeCharSeq, "__str__") +def charseq_str(s): + get_code = _get_code_impl(s) + + def str_impl(s): + n = len(s) + kind = s._get_kind() + is_ascii = kind == 1 and s.isascii() + result = unicode._empty_string(kind, n, is_ascii) + for i in range(n): + code = get_code(s, i) + unicode._set_code_point(result, i, code) + return result + + return str_impl + + +@overload(bytes) +def charseq_bytes(s): + if isinstance(s, types.CharSeq): + return lambda s: s + + +@overload_method(types.UnicodeCharSeq, "__hash__") +def unicode_charseq_hash(s): + def impl(s): + return hash(str(s)) + + return impl + + +@overload_method(types.CharSeq, "__hash__") +def charseq_hash(s): + def impl(s): + # Ideally, `return hash(bytes(s))` would be used here but + # numba Bytes does not implement hash (yet). However, for a + # UTF-8 string `s`, we have hash(bytes(s)) == hash(s), hence, + # we can convert CharSeq object to unicode_type and reuse its + # hash implementation: + return hash(s._to_str()) + + return impl + + +@overload_method(types.UnicodeCharSeq, "isupper") +def unicode_charseq_isupper(s): + def impl(s): + # workaround unicode_type.isupper bug: it returns int value + return not not str(s).isupper() + + return impl + + +@overload_method(types.CharSeq, "isupper") +def charseq_isupper(s): + def impl(s): + # return bytes(s).isupper() # TODO: implement isupper for Bytes + return not not s._to_str().isupper() + + return impl + + +@overload_method(types.UnicodeCharSeq, "upper") +def unicode_charseq_upper(s): + def impl(s): + return str(s).upper() + + return impl + + +@overload_method(types.CharSeq, "upper") +def charseq_upper(s): + def impl(s): + # return bytes(s).upper() # TODO: implement upper for Bytes + return s._to_str().upper()._to_bytes() + + return impl + + +@overload_method(types.UnicodeCharSeq, "find") +@overload_method(types.CharSeq, "find") +@overload_method(types.Bytes, "find") +def unicode_charseq_find(a, b): + if isinstance(a, types.UnicodeCharSeq): + if isinstance(b, types.UnicodeCharSeq): + + def impl(a, b): + return str(a).find(str(b)) + + return impl + if isinstance(b, types.UnicodeType): + + def impl(a, b): + return str(a).find(b) + + return impl + if isinstance(a, types.CharSeq): + if isinstance(b, (types.CharSeq, types.Bytes)): + + def impl(a, b): + return a._to_str().find(b._to_str()) + + return impl + if isinstance(a, types.UnicodeType): + if isinstance(b, types.UnicodeCharSeq): + + def impl(a, b): + return a.find(str(b)) + + return impl + if isinstance(a, types.Bytes): + if isinstance(b, types.CharSeq): + + def impl(a, b): + return a._to_str().find(b._to_str()) + + return impl + + +@overload_method(types.UnicodeCharSeq, "rfind") 
+@overload_method(types.CharSeq, "rfind") +@overload_method(types.Bytes, "rfind") +def unicode_charseq_rfind(a, b): + if isinstance(a, types.UnicodeCharSeq): + if isinstance(b, types.UnicodeCharSeq): + + def impl(a, b): + return str(a).rfind(str(b)) + + return impl + if isinstance(b, types.UnicodeType): + + def impl(a, b): + return str(a).rfind(b) + + return impl + if isinstance(a, types.CharSeq): + if isinstance(b, (types.CharSeq, types.Bytes)): + + def impl(a, b): + return a._to_str().rfind(b._to_str()) + + return impl + if isinstance(a, types.UnicodeType): + if isinstance(b, types.UnicodeCharSeq): + + def impl(a, b): + return a.rfind(str(b)) + + return impl + if isinstance(a, types.Bytes): + if isinstance(b, types.CharSeq): + + def impl(a, b): + return a._to_str().rfind(b._to_str()) + + return impl + + +@overload_method(types.UnicodeCharSeq, "startswith") +@overload_method(types.CharSeq, "startswith") +@overload_method(types.Bytes, "startswith") +def unicode_charseq_startswith(a, b): + if isinstance(a, types.UnicodeCharSeq): + if isinstance(b, types.UnicodeCharSeq): + + def impl(a, b): + return str(a).startswith(str(b)) + + return impl + if isinstance(b, types.UnicodeType): + + def impl(a, b): + return str(a).startswith(b) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)): + if isinstance(b, (types.CharSeq, types.Bytes)): + + def impl(a, b): + return a._to_str().startswith(b._to_str()) + + return impl + + +@overload_method(types.UnicodeCharSeq, "endswith") +@overload_method(types.CharSeq, "endswith") +@overload_method(types.Bytes, "endswith") +def unicode_charseq_endswith(a, b): + if isinstance(a, types.UnicodeCharSeq): + if isinstance(b, types.UnicodeCharSeq): + + def impl(a, b): + return str(a).endswith(str(b)) + + return impl + if isinstance(b, types.UnicodeType): + + def impl(a, b): + return str(a).endswith(b) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)): + if isinstance(b, (types.CharSeq, types.Bytes)): + + def impl(a, b): + return a._to_str().endswith(b._to_str()) + + return impl + + +@register_jitable +def _map_bytes(seq): + return [s._to_bytes() for s in seq] + + +@overload_method(types.UnicodeCharSeq, "split") +@overload_method(types.CharSeq, "split") +@overload_method(types.Bytes, "split") +def unicode_charseq_split(a, sep=None, maxsplit=-1): + if not ( + maxsplit == -1 + or isinstance( + maxsplit, (types.Omitted, types.Integer, types.IntegerLiteral) + ) + ): + return None + if isinstance(a, types.UnicodeCharSeq): + if isinstance(sep, types.UnicodeCharSeq): + + def impl(a, sep=None, maxsplit=-1): + return str(a).split(sep=str(sep), maxsplit=maxsplit) + + return impl + if isinstance(sep, types.UnicodeType): + + def impl(a, sep=None, maxsplit=-1): + return str(a).split(sep=sep, maxsplit=maxsplit) + + return impl + if is_nonelike(sep): + if is_default(maxsplit, -1): + + def impl(a, sep=None, maxsplit=-1): + return str(a).split() + else: + + def impl(a, sep=None, maxsplit=-1): + return str(a).split(maxsplit=maxsplit) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)): + if isinstance(sep, (types.CharSeq, types.Bytes)): + + def impl(a, sep=None, maxsplit=-1): + return _map_bytes( + a._to_str().split(sep._to_str(), maxsplit=maxsplit) + ) + + return impl + if is_nonelike(sep): + if is_default(maxsplit, -1): + + def impl(a, sep=None, maxsplit=-1): + return _map_bytes(a._to_str().split()) + else: + + def impl(a, sep=None, maxsplit=-1): + return _map_bytes(a._to_str().split(maxsplit=maxsplit)) + + return impl + + +# NOT IMPLEMENTED: 
rsplit + + +@overload_method(types.UnicodeCharSeq, "ljust") +@overload_method(types.CharSeq, "ljust") +@overload_method(types.Bytes, "ljust") +def unicode_charseq_ljust(a, width, fillchar=" "): + if isinstance(a, types.UnicodeCharSeq): + if is_default(fillchar, " "): + + def impl(a, width, fillchar=" "): + return str(a).ljust(width) + + return impl + elif isinstance(fillchar, types.UnicodeCharSeq): + + def impl(a, width, fillchar=" "): + return str(a).ljust(width, str(fillchar)) + + return impl + elif isinstance(fillchar, types.UnicodeType): + + def impl(a, width, fillchar=" "): + return str(a).ljust(width, fillchar) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)): + if is_default(fillchar, " ") or is_default(fillchar, b" "): + + def impl(a, width, fillchar=" "): + return a._to_str().ljust(width)._to_bytes() + + return impl + elif isinstance(fillchar, (types.CharSeq, types.Bytes)): + + def impl(a, width, fillchar=" "): + return a._to_str().ljust(width, fillchar._to_str())._to_bytes() + + return impl + + +@overload_method(types.UnicodeCharSeq, "rjust") +@overload_method(types.CharSeq, "rjust") +@overload_method(types.Bytes, "rjust") +def unicode_charseq_rjust(a, width, fillchar=" "): + if isinstance(a, types.UnicodeCharSeq): + if is_default(fillchar, " "): + + def impl(a, width, fillchar=" "): + return str(a).rjust(width) + + return impl + elif isinstance(fillchar, types.UnicodeCharSeq): + + def impl(a, width, fillchar=" "): + return str(a).rjust(width, str(fillchar)) + + return impl + elif isinstance(fillchar, types.UnicodeType): + + def impl(a, width, fillchar=" "): + return str(a).rjust(width, fillchar) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)): + if is_default(fillchar, " ") or is_default(fillchar, b" "): + + def impl(a, width, fillchar=" "): + return a._to_str().rjust(width)._to_bytes() + + return impl + elif isinstance(fillchar, (types.CharSeq, types.Bytes)): + + def impl(a, width, fillchar=" "): + return a._to_str().rjust(width, fillchar._to_str())._to_bytes() + + return impl + + +@overload_method(types.UnicodeCharSeq, "center") +@overload_method(types.CharSeq, "center") +@overload_method(types.Bytes, "center") +def unicode_charseq_center(a, width, fillchar=" "): + if isinstance(a, types.UnicodeCharSeq): + if is_default(fillchar, " "): + + def impl(a, width, fillchar=" "): + return str(a).center(width) + + return impl + elif isinstance(fillchar, types.UnicodeCharSeq): + + def impl(a, width, fillchar=" "): + return str(a).center(width, str(fillchar)) + + return impl + elif isinstance(fillchar, types.UnicodeType): + + def impl(a, width, fillchar=" "): + return str(a).center(width, fillchar) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)): + if is_default(fillchar, " ") or is_default(fillchar, b" "): + + def impl(a, width, fillchar=" "): + return a._to_str().center(width)._to_bytes() + + return impl + elif isinstance(fillchar, (types.CharSeq, types.Bytes)): + + def impl(a, width, fillchar=" "): + return a._to_str().center(width, fillchar._to_str())._to_bytes() + + return impl + + +@overload_method(types.UnicodeCharSeq, "zfill") +@overload_method(types.CharSeq, "zfill") +@overload_method(types.Bytes, "zfill") +def unicode_charseq_zfill(a, width): + if isinstance(a, types.UnicodeCharSeq): + + def impl(a, width): + return str(a).zfill(width) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)): + + def impl(a, width): + return a._to_str().zfill(width)._to_bytes() + + return impl + + +@overload_method(types.UnicodeCharSeq, 
"lstrip") +@overload_method(types.CharSeq, "lstrip") +@overload_method(types.Bytes, "lstrip") +def unicode_charseq_lstrip(a, chars=None): + if isinstance(a, types.UnicodeCharSeq): + if is_nonelike(chars): + + def impl(a, chars=None): + return str(a).lstrip() + + return impl + elif isinstance(chars, types.UnicodeCharSeq): + + def impl(a, chars=None): + return str(a).lstrip(str(chars)) + + return impl + elif isinstance(chars, types.UnicodeType): + + def impl(a, chars=None): + return str(a).lstrip(chars) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)): + if is_nonelike(chars): + + def impl(a, chars=None): + return a._to_str().lstrip()._to_bytes() + + return impl + elif isinstance(chars, (types.CharSeq, types.Bytes)): + + def impl(a, chars=None): + return a._to_str().lstrip(chars._to_str())._to_bytes() + + return impl + + +@overload_method(types.UnicodeCharSeq, "rstrip") +@overload_method(types.CharSeq, "rstrip") +@overload_method(types.Bytes, "rstrip") +def unicode_charseq_rstrip(a, chars=None): + if isinstance(a, types.UnicodeCharSeq): + if is_nonelike(chars): + + def impl(a, chars=None): + return str(a).rstrip() + + return impl + elif isinstance(chars, types.UnicodeCharSeq): + + def impl(a, chars=None): + return str(a).rstrip(str(chars)) + + return impl + elif isinstance(chars, types.UnicodeType): + + def impl(a, chars=None): + return str(a).rstrip(chars) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)): + if is_nonelike(chars): + + def impl(a, chars=None): + return a._to_str().rstrip()._to_bytes() + + return impl + elif isinstance(chars, (types.CharSeq, types.Bytes)): + + def impl(a, chars=None): + return a._to_str().rstrip(chars._to_str())._to_bytes() + + return impl + + +@overload_method(types.UnicodeCharSeq, "strip") +@overload_method(types.CharSeq, "strip") +@overload_method(types.Bytes, "strip") +def unicode_charseq_strip(a, chars=None): + if isinstance(a, types.UnicodeCharSeq): + if is_nonelike(chars): + + def impl(a, chars=None): + return str(a).strip() + + return impl + elif isinstance(chars, types.UnicodeCharSeq): + + def impl(a, chars=None): + return str(a).strip(str(chars)) + + return impl + elif isinstance(chars, types.UnicodeType): + + def impl(a, chars=None): + return str(a).strip(chars) + + return impl + if isinstance(a, (types.CharSeq, types.Bytes)): + if is_nonelike(chars): + + def impl(a, chars=None): + return a._to_str().strip()._to_bytes() + + return impl + elif isinstance(chars, (types.CharSeq, types.Bytes)): + + def impl(a, chars=None): + return a._to_str().strip(chars._to_str())._to_bytes() + + return impl + + +@overload_method(types.UnicodeCharSeq, "join") +@overload_method(types.CharSeq, "join") +@overload_method(types.Bytes, "join") +def unicode_charseq_join(a, parts): + if isinstance(a, types.UnicodeCharSeq): + # assuming parts contains UnicodeCharSeq or UnicodeType objects + def impl(a, parts): + _parts = [str(p) for p in parts] + return str(a).join(_parts) + + return impl + + if isinstance(a, (types.CharSeq, types.Bytes)): + # assuming parts contains CharSeq or Bytes objects + def impl(a, parts): + _parts = [p._to_str() for p in parts] + return a._to_str().join(_parts)._to_bytes() + + return impl diff --git a/numba_cuda/numba/cuda/cpython/iterators.py b/numba_cuda/numba/cuda/cpython/iterators.py new file mode 100644 index 000000000..b5615accd --- /dev/null +++ b/numba_cuda/numba/cuda/cpython/iterators.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-2-Clause + +""" +Implementation of various iterable and iterator types. +""" + +from numba.core import types +from numba.cuda import cgutils +from numba.core.imputils import ( + iternext_impl, + call_iternext, + call_getiter, + impl_ret_borrowed, + impl_ret_new_ref, + RefType, + Registry, +) + + +registry = Registry("iterators") +lower = registry.lower +lower_cast = registry.lower_cast +lower_constant = registry.lower_constant +lower_getattr = registry.lower_getattr + + +@lower("getiter", types.IteratorType) +def iterator_getiter(context, builder, sig, args): + [it] = args + return impl_ret_borrowed(context, builder, sig.return_type, it) + + +# ------------------------------------------------------------------------------- +# builtin `enumerate` implementation + + +@lower(enumerate, types.IterableType) +@lower(enumerate, types.IterableType, types.Integer) +def make_enumerate_object(context, builder, sig, args): + assert ( + len(args) == 1 or len(args) == 2 + ) # enumerate(it) or enumerate(it, start) + srcty = sig.args[0] + + if len(args) == 1: + src = args[0] + start_val = context.get_constant(types.intp, 0) + elif len(args) == 2: + src = args[0] + start_val = context.cast(builder, args[1], sig.args[1], types.intp) + + iterobj = call_getiter(context, builder, srcty, src) + + enum = context.make_helper(builder, sig.return_type) + + countptr = cgutils.alloca_once(builder, start_val.type) + builder.store(start_val, countptr) + + enum.count = countptr + enum.iter = iterobj + + res = enum._getvalue() + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +@lower("iternext", types.EnumerateType) +@iternext_impl(RefType.NEW) +def iternext_enumerate(context, builder, sig, args, result): + [enumty] = sig.args + [enum] = args + + enum = context.make_helper(builder, enumty, value=enum) + + count = builder.load(enum.count) + ncount = builder.add(count, context.get_constant(types.intp, 1)) + builder.store(ncount, enum.count) + + srcres = call_iternext(context, builder, enumty.source_type, enum.iter) + is_valid = srcres.is_valid() + result.set_valid(is_valid) + + with builder.if_then(is_valid): + srcval = srcres.yielded_value() + result.yield_( + context.make_tuple(builder, enumty.yield_type, [count, srcval]) + ) + + +# ------------------------------------------------------------------------------- +# builtin `zip` implementation + + +@lower(zip, types.VarArg(types.Any)) +def make_zip_object(context, builder, sig, args): + zip_type = sig.return_type + + assert len(args) == len(zip_type.source_types) + + zipobj = context.make_helper(builder, zip_type) + + for i, (arg, srcty) in enumerate(zip(args, sig.args)): + zipobj[i] = call_getiter(context, builder, srcty, arg) + + res = zipobj._getvalue() + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +@lower("iternext", types.ZipType) +@iternext_impl(RefType.NEW) +def iternext_zip(context, builder, sig, args, result): + [zip_type] = sig.args + [zipobj] = args + + zipobj = context.make_helper(builder, zip_type, value=zipobj) + + if len(zipobj) == 0: + # zip() is an empty iterator + result.set_exhausted() + return + + p_ret_tup = cgutils.alloca_once( + builder, context.get_value_type(zip_type.yield_type) + ) + p_is_valid = cgutils.alloca_once_value(builder, value=cgutils.true_bit) + + for i, (iterobj, srcty) in enumerate(zip(zipobj, zip_type.source_types)): + is_valid = builder.load(p_is_valid) + # Avoid calling the remaining iternext if a iterator has been exhausted + with 
builder.if_then(is_valid): + srcres = call_iternext(context, builder, srcty, iterobj) + is_valid = builder.and_(is_valid, srcres.is_valid()) + builder.store(is_valid, p_is_valid) + val = srcres.yielded_value() + ptr = cgutils.gep_inbounds(builder, p_ret_tup, 0, i) + builder.store(val, ptr) + + is_valid = builder.load(p_is_valid) + result.set_valid(is_valid) + + with builder.if_then(is_valid): + result.yield_(builder.load(p_ret_tup)) + + +# ------------------------------------------------------------------------------- +# generator implementation + + +@lower("iternext", types.Generator) +@iternext_impl(RefType.BORROWED) +def iternext_zip(context, builder, sig, args, result): # noqa: F811 + (genty,) = sig.args + (gen,) = args + impl = context.get_generator_impl(genty) + status, retval = impl(context, builder, sig, args) + context.add_linking_libs(getattr(impl, "libs", ())) + + with cgutils.if_likely(builder, status.is_ok): + result.set_valid(True) + result.yield_(retval) + with cgutils.if_unlikely(builder, status.is_stop_iteration): + result.set_exhausted() + with cgutils.if_unlikely( + builder, + builder.and_(status.is_error, builder.not_(status.is_stop_iteration)), + ): + context.call_conv.return_status_propagate(builder, status) diff --git a/numba_cuda/numba/cuda/cpython/listobj.py b/numba_cuda/numba/cuda/cpython/listobj.py new file mode 100644 index 000000000..d00c0295c --- /dev/null +++ b/numba_cuda/numba/cuda/cpython/listobj.py @@ -0,0 +1,1362 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +Support for native homogeneous lists. +""" + +import operator + +from llvmlite import ir +from numba.core import types, errors +from numba.cuda import cgutils +from numba.core.imputils import ( + Registry, + iternext_impl, + impl_ret_borrowed, + impl_ret_new_ref, + RefType, +) +from numba.core.extending import overload_method, overload +from numba.misc import quicksort +from numba.cpython import slicing +from numba import literal_unroll + +registry = Registry("listobj") +lower = registry.lower +lower_cast = registry.lower_cast + + +def get_list_payload(context, builder, list_type, value): + """ + Given a list value and type, get its payload structure (as a + reference, so that mutations are seen by all). + """ + payload_type = types.ListPayload(list_type) + payload = context.nrt.meminfo_data(builder, value.meminfo) + ptrty = context.get_data_type(payload_type).as_pointer() + payload = builder.bitcast(payload, ptrty) + return context.make_data_helper(builder, payload_type, ref=payload) + + +def get_itemsize(context, list_type): + """ + Return the item size for the given list type. + """ + llty = context.get_data_type(list_type.dtype) + return context.get_abi_sizeof(llty) + + +class _ListPayloadMixin(object): + @property + def size(self): + return self._payload.size + + @size.setter + def size(self, value): + self._payload.size = value + + @property + def dirty(self): + return self._payload.dirty + + @property + def data(self): + return self._payload._get_ptr_by_name("data") + + def _gep(self, idx): + return cgutils.gep(self._builder, self.data, idx) + + def getitem(self, idx): + ptr = self._gep(idx) + data_item = self._builder.load(ptr) + return self._datamodel.from_data(self._builder, data_item) + + def fix_index(self, idx): + """ + Fix negative indices by adding the size to them. Positive + indices are left untouched. 
+ """ + is_negative = self._builder.icmp_signed( + "<", idx, ir.Constant(idx.type, 0) + ) + wrapped_index = self._builder.add(idx, self.size) + return self._builder.select(is_negative, wrapped_index, idx) + + def is_out_of_bounds(self, idx): + """ + Return whether the index is out of bounds. + """ + underflow = self._builder.icmp_signed( + "<", idx, ir.Constant(idx.type, 0) + ) + overflow = self._builder.icmp_signed(">=", idx, self.size) + return self._builder.or_(underflow, overflow) + + def clamp_index(self, idx): + """ + Clamp the index in [0, size]. + """ + builder = self._builder + idxptr = cgutils.alloca_once_value(builder, idx) + + zero = ir.Constant(idx.type, 0) + size = self.size + + underflow = self._builder.icmp_signed("<", idx, zero) + with builder.if_then(underflow, likely=False): + builder.store(zero, idxptr) + overflow = self._builder.icmp_signed(">=", idx, size) + with builder.if_then(overflow, likely=False): + builder.store(size, idxptr) + + return builder.load(idxptr) + + def guard_index(self, idx, msg): + """ + Raise an error if the index is out of bounds. + """ + with self._builder.if_then(self.is_out_of_bounds(idx), likely=False): + self._context.call_conv.return_user_exc( + self._builder, IndexError, (msg,) + ) + + def fix_slice(self, slice): + """ + Fix slice start and stop to be valid (inclusive and exclusive, resp) + indexing bounds. + """ + return slicing.fix_slice(self._builder, slice, self.size) + + def incref_value(self, val): + "Incref an element value" + self._context.nrt.incref(self._builder, self.dtype, val) + + def decref_value(self, val): + "Decref an element value" + self._context.nrt.decref(self._builder, self.dtype, val) + + +class ListPayloadAccessor(_ListPayloadMixin): + """ + A helper object to access the list attributes given the pointer to the + payload type. + """ + + def __init__(self, context, builder, list_type, payload_ptr): + self._context = context + self._builder = builder + self._ty = list_type + self._datamodel = context.data_model_manager[list_type.dtype] + payload_type = types.ListPayload(list_type) + ptrty = context.get_data_type(payload_type).as_pointer() + payload_ptr = builder.bitcast(payload_ptr, ptrty) + payload = context.make_data_helper( + builder, payload_type, ref=payload_ptr + ) + self._payload = payload + + +class ListInstance(_ListPayloadMixin): + def __init__(self, context, builder, list_type, list_val): + self._context = context + self._builder = builder + self._ty = list_type + self._list = context.make_helper(builder, list_type, list_val) + self._itemsize = get_itemsize(context, list_type) + self._datamodel = context.data_model_manager[list_type.dtype] + + @property + def dtype(self): + return self._ty.dtype + + @property + def _payload(self): + # This cannot be cached as it can be reallocated + return get_list_payload( + self._context, self._builder, self._ty, self._list + ) + + @property + def parent(self): + return self._list.parent + + @parent.setter + def parent(self, value): + self._list.parent = value + + @property + def value(self): + return self._list._getvalue() + + @property + def meminfo(self): + return self._list.meminfo + + def set_dirty(self, val): + if self._ty.reflected: + self._payload.dirty = cgutils.true_bit if val else cgutils.false_bit + + def clear_value(self, idx): + """Remove the value at the location""" + self.decref_value(self.getitem(idx)) + # it's necessary for the dtor which just decref every slot on it. 
+ self.zfill(idx, self._builder.add(idx, idx.type(1))) + + def setitem(self, idx, val, incref, decref_old_value=True): + # Decref old data + if decref_old_value: + self.decref_value(self.getitem(idx)) + + ptr = self._gep(idx) + data_item = self._datamodel.as_data(self._builder, val) + self._builder.store(data_item, ptr) + self.set_dirty(True) + if incref: + # Incref the underlying data + self.incref_value(val) + + def inititem(self, idx, val, incref=True): + ptr = self._gep(idx) + data_item = self._datamodel.as_data(self._builder, val) + self._builder.store(data_item, ptr) + if incref: + self.incref_value(val) + + def zfill(self, start, stop): + """Zero-fill the memory at index *start* to *stop* + + *stop* MUST not be smaller than *start*. + """ + builder = self._builder + base = self._gep(start) + end = self._gep(stop) + intaddr_t = self._context.get_value_type(types.intp) + size = builder.sub( + builder.ptrtoint(end, intaddr_t), builder.ptrtoint(base, intaddr_t) + ) + cgutils.memset(builder, base, size, ir.IntType(8)(0)) + + @classmethod + def allocate_ex(cls, context, builder, list_type, nitems): + """ + Allocate a ListInstance with its storage. + Return a (ok, instance) tuple where *ok* is a LLVM boolean and + *instance* is a ListInstance object (the object's contents are + only valid when *ok* is true). + """ + intp_t = context.get_value_type(types.intp) + + if isinstance(nitems, int): + nitems = ir.Constant(intp_t, nitems) + + payload_type = context.get_data_type(types.ListPayload(list_type)) + payload_size = context.get_abi_sizeof(payload_type) + + itemsize = get_itemsize(context, list_type) + # Account for the fact that the payload struct contains one entry + payload_size -= itemsize + + ok = cgutils.alloca_once_value(builder, cgutils.true_bit) + self = cls(context, builder, list_type, None) + + # Total allocation size = + nitems * itemsize + allocsize, ovf = cgutils.muladd_with_overflow( + builder, + nitems, + ir.Constant(intp_t, itemsize), + ir.Constant(intp_t, payload_size), + ) + with builder.if_then(ovf, likely=False): + builder.store(cgutils.false_bit, ok) + + with builder.if_then(builder.load(ok), likely=True): + meminfo = context.nrt.meminfo_new_varsize_dtor_unchecked( + builder, size=allocsize, dtor=self.get_dtor() + ) + with builder.if_else( + cgutils.is_null(builder, meminfo), likely=False + ) as (if_error, if_ok): + with if_error: + builder.store(cgutils.false_bit, ok) + with if_ok: + self._list.meminfo = meminfo + self._list.parent = context.get_constant_null( + types.pyobject + ) + self._payload.allocated = nitems + self._payload.size = ir.Constant(intp_t, 0) # for safety + self._payload.dirty = cgutils.false_bit + # Zero the allocated region + self.zfill(self.size.type(0), nitems) + + return builder.load(ok), self + + def define_dtor(self): + "Define the destructor if not already defined" + context = self._context + builder = self._builder + mod = builder.module + # Declare dtor + fnty = ir.FunctionType(ir.VoidType(), [cgutils.voidptr_t]) + fn = cgutils.get_or_insert_function( + mod, fnty, ".dtor.list.{}".format(self.dtype) + ) + if not fn.is_declaration: + # End early if the dtor is already defined + return fn + fn.linkage = "linkonce_odr" + # Populate the dtor + builder = ir.IRBuilder(fn.append_basic_block()) + base_ptr = fn.args[0] # void* + + # get payload + payload = ListPayloadAccessor(context, builder, self._ty, base_ptr) + + # Loop over all data to decref + intp = payload.size.type + with cgutils.for_range_slice( + builder, start=intp(0), stop=payload.size, 
step=intp(1), intp=intp + ) as (idx, _): + val = payload.getitem(idx) + context.nrt.decref(builder, self.dtype, val) + builder.ret_void() + return fn + + def get_dtor(self): + """ "Get the element dtor function pointer as void pointer. + + It's safe to be called multiple times. + """ + # Define and set the Dtor + dtor = self.define_dtor() + dtor_fnptr = self._builder.bitcast(dtor, cgutils.voidptr_t) + return dtor_fnptr + + @classmethod + def allocate(cls, context, builder, list_type, nitems): + """ + Allocate a ListInstance with its storage. Same as allocate_ex(), + but return an initialized *instance*. If allocation failed, + control is transferred to the caller using the target's current + call convention. + """ + ok, self = cls.allocate_ex(context, builder, list_type, nitems) + with builder.if_then(builder.not_(ok), likely=False): + context.call_conv.return_user_exc( + builder, MemoryError, ("cannot allocate list",) + ) + return self + + @classmethod + def from_meminfo(cls, context, builder, list_type, meminfo): + """ + Allocate a new list instance pointing to an existing payload + (a meminfo pointer). + Note the parent field has to be filled by the caller. + """ + self = cls(context, builder, list_type, None) + self._list.meminfo = meminfo + self._list.parent = context.get_constant_null(types.pyobject) + context.nrt.incref(builder, list_type, self.value) + # Payload is part of the meminfo, no need to touch it + return self + + def resize(self, new_size): + """ + Ensure the list is properly sized for the new size. + """ + + def _payload_realloc(new_allocated): + payload_type = context.get_data_type(types.ListPayload(self._ty)) + payload_size = context.get_abi_sizeof(payload_type) + # Account for the fact that the payload struct contains one entry + payload_size -= itemsize + + allocsize, ovf = cgutils.muladd_with_overflow( + builder, + new_allocated, + ir.Constant(intp_t, itemsize), + ir.Constant(intp_t, payload_size), + ) + with builder.if_then(ovf, likely=False): + context.call_conv.return_user_exc( + builder, MemoryError, ("cannot resize list",) + ) + + ptr = context.nrt.meminfo_varsize_realloc_unchecked( + builder, self._list.meminfo, size=allocsize + ) + cgutils.guard_memory_error( + context, builder, ptr, "cannot resize list" + ) + self._payload.allocated = new_allocated + + context = self._context + builder = self._builder + intp_t = new_size.type + + itemsize = get_itemsize(context, self._ty) + allocated = self._payload.allocated + + two = ir.Constant(intp_t, 2) + eight = ir.Constant(intp_t, 8) + + # allocated < new_size + is_too_small = builder.icmp_signed("<", allocated, new_size) + # (allocated >> 2) > new_size + is_too_large = builder.icmp_signed( + ">", builder.ashr(allocated, two), new_size + ) + + with builder.if_then(is_too_large, likely=False): + # Exact downsize to requested size + # NOTE: is_too_large must be aggressive enough to avoid repeated + # upsizes and downsizes when growing a list. + _payload_realloc(new_size) + + with builder.if_then(is_too_small, likely=False): + # Upsize with moderate over-allocation (size + size >> 2 + 8) + new_allocated = builder.add( + eight, builder.add(new_size, builder.ashr(new_size, two)) + ) + _payload_realloc(new_allocated) + self.zfill(self.size, new_allocated) + + self._payload.size = new_size + self.set_dirty(True) + + def move(self, dest_idx, src_idx, count): + """ + Move `count` elements from `src_idx` to `dest_idx`. 
+ """ + dest_ptr = self._gep(dest_idx) + src_ptr = self._gep(src_idx) + cgutils.raw_memmove( + self._builder, dest_ptr, src_ptr, count, itemsize=self._itemsize + ) + + self.set_dirty(True) + + +class ListIterInstance(_ListPayloadMixin): + def __init__(self, context, builder, iter_type, iter_val): + self._context = context + self._builder = builder + self._ty = iter_type + self._iter = context.make_helper(builder, iter_type, iter_val) + self._datamodel = context.data_model_manager[iter_type.yield_type] + + @classmethod + def from_list(cls, context, builder, iter_type, list_val): + list_inst = ListInstance( + context, builder, iter_type.container, list_val + ) + self = cls(context, builder, iter_type, None) + index = context.get_constant(types.intp, 0) + self._iter.index = cgutils.alloca_once_value(builder, index) + self._iter.meminfo = list_inst.meminfo + return self + + @property + def _payload(self): + # This cannot be cached as it can be reallocated + return get_list_payload( + self._context, self._builder, self._ty.container, self._iter + ) + + @property + def value(self): + return self._iter._getvalue() + + @property + def index(self): + return self._builder.load(self._iter.index) + + @index.setter + def index(self, value): + self._builder.store(value, self._iter.index) + + +# ------------------------------------------------------------------------------- +# Constructors + + +def build_list(context, builder, list_type, items): + """ + Build a list of the given type, containing the given items. + """ + nitems = len(items) + inst = ListInstance.allocate(context, builder, list_type, nitems) + # Populate list + inst.size = context.get_constant(types.intp, nitems) + for i, val in enumerate(items): + inst.setitem(context.get_constant(types.intp, i), val, incref=True) + + return impl_ret_new_ref(context, builder, list_type, inst.value) + + +@lower(list, types.IterableType) +def list_constructor(context, builder, sig, args): + def list_impl(iterable): + res = [] + res.extend(iterable) + return res + + return context.compile_internal(builder, list_impl, sig, args) + + +@lower(list) +def list_constructor(context, builder, sig, args): # noqa: F811 + list_type = sig.return_type + list_len = 0 + inst = ListInstance.allocate(context, builder, list_type, list_len) + return impl_ret_new_ref(context, builder, list_type, inst.value) + + +# ------------------------------------------------------------------------------- +# Various operations + + +@lower(len, types.List) +def list_len(context, builder, sig, args): + inst = ListInstance(context, builder, sig.args[0], args[0]) + return inst.size + + +@lower("getiter", types.List) +def getiter_list(context, builder, sig, args): + inst = ListIterInstance.from_list( + context, builder, sig.return_type, args[0] + ) + return impl_ret_borrowed(context, builder, sig.return_type, inst.value) + + +@lower("iternext", types.ListIter) +@iternext_impl(RefType.BORROWED) +def iternext_listiter(context, builder, sig, args, result): + inst = ListIterInstance(context, builder, sig.args[0], args[0]) + + index = inst.index + nitems = inst.size + is_valid = builder.icmp_signed("<", index, nitems) + result.set_valid(is_valid) + + with builder.if_then(is_valid): + result.yield_(inst.getitem(index)) + inst.index = builder.add(index, context.get_constant(types.intp, 1)) + + +@lower(operator.getitem, types.List, types.Integer) +def getitem_list(context, builder, sig, args): + inst = ListInstance(context, builder, sig.args[0], args[0]) + index = args[1] + + index = 
inst.fix_index(index) + inst.guard_index(index, msg="getitem out of range") + result = inst.getitem(index) + + return impl_ret_borrowed(context, builder, sig.return_type, result) + + +@lower(operator.setitem, types.List, types.Integer, types.Any) +def setitem_list(context, builder, sig, args): + inst = ListInstance(context, builder, sig.args[0], args[0]) + index = args[1] + value = args[2] + + index = inst.fix_index(index) + inst.guard_index(index, msg="setitem out of range") + inst.setitem(index, value, incref=True) + return context.get_dummy_value() + + +@lower(operator.getitem, types.List, types.SliceType) +def getslice_list(context, builder, sig, args): + inst = ListInstance(context, builder, sig.args[0], args[0]) + slice = context.make_helper(builder, sig.args[1], args[1]) + slicing.guard_invalid_slice(context, builder, sig.args[1], slice) + inst.fix_slice(slice) + + # Allocate result and populate it + result_size = slicing.get_slice_length(builder, slice) + result = ListInstance.allocate( + context, builder, sig.return_type, result_size + ) + result.size = result_size + with cgutils.for_range_slice_generic( + builder, slice.start, slice.stop, slice.step + ) as (pos_range, neg_range): + with pos_range as (idx, count): + value = inst.getitem(idx) + result.inititem(count, value, incref=True) + with neg_range as (idx, count): + value = inst.getitem(idx) + result.inititem(count, value, incref=True) + + return impl_ret_new_ref(context, builder, sig.return_type, result.value) + + +@lower(operator.setitem, types.List, types.SliceType, types.Any) +def setitem_list(context, builder, sig, args): # noqa: F811 + dest = ListInstance(context, builder, sig.args[0], args[0]) + src = ListInstance(context, builder, sig.args[2], args[2]) + + slice = context.make_helper(builder, sig.args[1], args[1]) + slicing.guard_invalid_slice(context, builder, sig.args[1], slice) + dest.fix_slice(slice) + + src_size = src.size + avail_size = slicing.get_slice_length(builder, slice) + size_delta = builder.sub(src.size, avail_size) + + zero = ir.Constant(size_delta.type, 0) + one = ir.Constant(size_delta.type, 1) + + with builder.if_else(builder.icmp_signed("==", slice.step, one)) as ( + then, + otherwise, + ): + with then: + # Slice step == 1 => we can resize + + # Compute the real stop, e.g. for dest[2:0] = [...] 
+ real_stop = builder.add(slice.start, avail_size) + # Size of the list tail, after the end of slice + tail_size = builder.sub(dest.size, real_stop) + + with builder.if_then(builder.icmp_signed(">", size_delta, zero)): + # Grow list then move list tail + dest.resize(builder.add(dest.size, size_delta)) + dest.move( + builder.add(real_stop, size_delta), real_stop, tail_size + ) + + with builder.if_then(builder.icmp_signed("<", size_delta, zero)): + # Move list tail then shrink list + dest.move( + builder.add(real_stop, size_delta), real_stop, tail_size + ) + dest.resize(builder.add(dest.size, size_delta)) + + dest_offset = slice.start + + with cgutils.for_range(builder, src_size) as loop: + value = src.getitem(loop.index) + dest.setitem( + builder.add(loop.index, dest_offset), value, incref=True + ) + + with otherwise: + with builder.if_then(builder.icmp_signed("!=", size_delta, zero)): + msg = "cannot resize extended list slice with step != 1" + context.call_conv.return_user_exc(builder, ValueError, (msg,)) + + with cgutils.for_range_slice_generic( + builder, slice.start, slice.stop, slice.step + ) as (pos_range, neg_range): + with pos_range as (index, count): + value = src.getitem(count) + dest.setitem(index, value, incref=True) + with neg_range as (index, count): + value = src.getitem(count) + dest.setitem(index, value, incref=True) + + return context.get_dummy_value() + + +@lower(operator.delitem, types.List, types.Integer) +def delitem_list_index(context, builder, sig, args): + def list_delitem_impl(lst, i): + lst.pop(i) + + return context.compile_internal(builder, list_delitem_impl, sig, args) + + +@lower(operator.delitem, types.List, types.SliceType) +def delitem_list(context, builder, sig, args): + inst = ListInstance(context, builder, sig.args[0], args[0]) + slice = context.make_helper(builder, sig.args[1], args[1]) + + slicing.guard_invalid_slice(context, builder, sig.args[1], slice) + inst.fix_slice(slice) + + slice_len = slicing.get_slice_length(builder, slice) + + one = ir.Constant(slice_len.type, 1) + + with builder.if_then( + builder.icmp_signed("!=", slice.step, one), likely=False + ): + msg = "unsupported del list[start:stop:step] with step != 1" + context.call_conv.return_user_exc(builder, NotImplementedError, (msg,)) + + # Compute the real stop, e.g. for dest[2:0] + start = slice.start + real_stop = builder.add(start, slice_len) + # Decref the removed range + with cgutils.for_range_slice(builder, start, real_stop, start.type(1)) as ( + idx, + _, + ): + inst.decref_value(inst.getitem(idx)) + + # Size of the list tail, after the end of slice + tail_size = builder.sub(inst.size, real_stop) + inst.move(start, real_stop, tail_size) + inst.resize(builder.sub(inst.size, slice_len)) + + return context.get_dummy_value() + + +# XXX should there be a specific module for Sequence or collection base classes? 
+ + +@lower(operator.contains, types.Sequence, types.Any) +def in_seq(context, builder, sig, args): + def seq_contains_impl(lst, value): + for elem in lst: + if elem == value: + return True + return False + + return context.compile_internal(builder, seq_contains_impl, sig, args) + + +@lower(bool, types.Sequence) +def sequence_bool(context, builder, sig, args): + def sequence_bool_impl(seq): + return len(seq) != 0 + + return context.compile_internal(builder, sequence_bool_impl, sig, args) + + +@overload(operator.truth) +def sequence_truth(seq): + if isinstance(seq, types.Sequence): + + def impl(seq): + return len(seq) != 0 + + return impl + + +@lower(operator.add, types.List, types.List) +def list_add(context, builder, sig, args): + a = ListInstance(context, builder, sig.args[0], args[0]) + b = ListInstance(context, builder, sig.args[1], args[1]) + + a_size = a.size + b_size = b.size + nitems = builder.add(a_size, b_size) + dest = ListInstance.allocate(context, builder, sig.return_type, nitems) + dest.size = nitems + + with cgutils.for_range(builder, a_size) as loop: + value = a.getitem(loop.index) + value = context.cast(builder, value, a.dtype, dest.dtype) + dest.setitem(loop.index, value, incref=True) + with cgutils.for_range(builder, b_size) as loop: + value = b.getitem(loop.index) + value = context.cast(builder, value, b.dtype, dest.dtype) + dest.setitem(builder.add(loop.index, a_size), value, incref=True) + + return impl_ret_new_ref(context, builder, sig.return_type, dest.value) + + +@lower(operator.iadd, types.List, types.List) +def list_add_inplace(context, builder, sig, args): + assert sig.args[0].dtype == sig.return_type.dtype + dest = _list_extend_list(context, builder, sig, args) + + return impl_ret_borrowed(context, builder, sig.return_type, dest.value) + + +@lower(operator.mul, types.List, types.Integer) +@lower(operator.mul, types.Integer, types.List) +def list_mul(context, builder, sig, args): + if isinstance(sig.args[0], types.List): + list_idx, int_idx = 0, 1 + else: + list_idx, int_idx = 1, 0 + src = ListInstance(context, builder, sig.args[list_idx], args[list_idx]) + src_size = src.size + + mult = args[int_idx] + zero = ir.Constant(mult.type, 0) + mult = builder.select(cgutils.is_neg_int(builder, mult), zero, mult) + nitems = builder.mul(mult, src_size) + + dest = ListInstance.allocate(context, builder, sig.return_type, nitems) + dest.size = nitems + + with cgutils.for_range_slice(builder, zero, nitems, src_size, inc=True) as ( + dest_offset, + _, + ): + with cgutils.for_range(builder, src_size) as loop: + value = src.getitem(loop.index) + dest.setitem( + builder.add(loop.index, dest_offset), value, incref=True + ) + + return impl_ret_new_ref(context, builder, sig.return_type, dest.value) + + +@lower(operator.imul, types.List, types.Integer) +def list_mul_inplace(context, builder, sig, args): + inst = ListInstance(context, builder, sig.args[0], args[0]) + src_size = inst.size + + mult = args[1] + zero = ir.Constant(mult.type, 0) + mult = builder.select(cgutils.is_neg_int(builder, mult), zero, mult) + nitems = builder.mul(mult, src_size) + + inst.resize(nitems) + + with cgutils.for_range_slice( + builder, src_size, nitems, src_size, inc=True + ) as (dest_offset, _): + with cgutils.for_range(builder, src_size) as loop: + value = inst.getitem(loop.index) + inst.setitem( + builder.add(loop.index, dest_offset), value, incref=True + ) + + return impl_ret_borrowed(context, builder, sig.return_type, inst.value) + + +# 
------------------------------------------------------------------------------- +# Comparisons + + +@lower(operator.is_, types.List, types.List) +def list_is(context, builder, sig, args): + a = ListInstance(context, builder, sig.args[0], args[0]) + b = ListInstance(context, builder, sig.args[1], args[1]) + ma = builder.ptrtoint(a.meminfo, cgutils.intp_t) + mb = builder.ptrtoint(b.meminfo, cgutils.intp_t) + return builder.icmp_signed("==", ma, mb) + + +@lower(operator.eq, types.List, types.List) +def list_eq(context, builder, sig, args): + aty, bty = sig.args + a = ListInstance(context, builder, aty, args[0]) + b = ListInstance(context, builder, bty, args[1]) + + a_size = a.size + same_size = builder.icmp_signed("==", a_size, b.size) + + res = cgutils.alloca_once_value(builder, same_size) + + with builder.if_then(same_size): + with cgutils.for_range(builder, a_size) as loop: + v = a.getitem(loop.index) + w = b.getitem(loop.index) + itemres = context.generic_compare( + builder, operator.eq, (aty.dtype, bty.dtype), (v, w) + ) + with builder.if_then(builder.not_(itemres)): + # Exit early + builder.store(cgutils.false_bit, res) + loop.do_break() + + return builder.load(res) + + +def all_list(*args): + return all([isinstance(typ, types.List) for typ in args]) + + +@overload(operator.ne) +def impl_list_ne(a, b): + if not all_list(a, b): + return + + def list_ne_impl(a, b): + return not (a == b) + + return list_ne_impl + + +@overload(operator.le) +def impl_list_le(a, b): + if not all_list(a, b): + return + + def list_le_impl(a, b): + m = len(a) + n = len(b) + for i in range(min(m, n)): + if a[i] < b[i]: + return True + elif a[i] > b[i]: + return False + return m <= n + + return list_le_impl + + +@overload(operator.lt) +def impl_list_lt(a, b): + if not all_list(a, b): + return + + def list_lt_impl(a, b): + m = len(a) + n = len(b) + for i in range(min(m, n)): + if a[i] < b[i]: + return True + elif a[i] > b[i]: + return False + return m < n + + return list_lt_impl + + +@overload(operator.ge) +def impl_list_ge(a, b): + if not all_list(a, b): + return + + def list_ge_impl(a, b): + return b <= a + + return list_ge_impl + + +@overload(operator.gt) +def impl_list_gt(a, b): + if not all_list(a, b): + return + + def list_gt_impl(a, b): + return b < a + + return list_gt_impl + + +# ------------------------------------------------------------------------------- +# Methods + + +@lower("list.append", types.List, types.Any) +def list_append(context, builder, sig, args): + inst = ListInstance(context, builder, sig.args[0], args[0]) + item = args[1] + + n = inst.size + new_size = builder.add(n, ir.Constant(n.type, 1)) + inst.resize(new_size) + inst.setitem(n, item, incref=True) + + return context.get_dummy_value() + + +@lower("list.clear", types.List) +def list_clear(context, builder, sig, args): + inst = ListInstance(context, builder, sig.args[0], args[0]) + inst.resize(context.get_constant(types.intp, 0)) + + return context.get_dummy_value() + + +@overload_method(types.List, "copy") +def list_copy(lst): + def list_copy_impl(lst): + return list(lst) + + return list_copy_impl + + +@overload_method(types.List, "count") +def list_count(lst, value): + def list_count_impl(lst, value): + res = 0 + for elem in lst: + if elem == value: + res += 1 + return res + + return list_count_impl + + +def _list_extend_list(context, builder, sig, args): + src = ListInstance(context, builder, sig.args[1], args[1]) + dest = ListInstance(context, builder, sig.args[0], args[0]) + + src_size = src.size + dest_size = dest.size + nitems = 
builder.add(src_size, dest_size) + dest.resize(nitems) + dest.size = nitems + + with cgutils.for_range(builder, src_size) as loop: + value = src.getitem(loop.index) + value = context.cast(builder, value, src.dtype, dest.dtype) + dest.setitem(builder.add(loop.index, dest_size), value, incref=True) + + return dest + + +@lower("list.extend", types.List, types.IterableType) +def list_extend(context, builder, sig, args): + if isinstance(sig.args[1], types.List): + # Specialize for list operands, for speed. + _list_extend_list(context, builder, sig, args) + return context.get_dummy_value() + + def list_extend(lst, iterable): + # Speed hack to avoid NRT refcount operations inside the loop + meth = lst.append + for v in iterable: + meth(v) + + return context.compile_internal(builder, list_extend, sig, args) + + +intp_max = types.intp.maxval + + +@overload_method(types.List, "index") +def list_index(lst, value, start=0, stop=intp_max): + if not isinstance(start, (int, types.Integer, types.Omitted)): + raise errors.TypingError(f'arg "start" must be an Integer. Got {start}') + if not isinstance(stop, (int, types.Integer, types.Omitted)): + raise errors.TypingError(f'arg "stop" must be an Integer. Got {stop}') + + def list_index_impl(lst, value, start=0, stop=intp_max): + n = len(lst) + if start < 0: + start += n + if start < 0: + start = 0 + if stop < 0: + stop += n + if stop > n: + stop = n + for i in range(start, stop): + if lst[i] == value: + return i + # XXX references are leaked when raising + raise ValueError("value not in list") + + return list_index_impl + + +@lower("list.insert", types.List, types.Integer, types.Any) +def list_insert(context, builder, sig, args): + inst = ListInstance(context, builder, sig.args[0], args[0]) + index = inst.fix_index(args[1]) + index = inst.clamp_index(index) + value = args[2] + + n = inst.size + one = ir.Constant(n.type, 1) + new_size = builder.add(n, one) + inst.resize(new_size) + inst.move(builder.add(index, one), index, builder.sub(n, index)) + inst.setitem(index, value, incref=True, decref_old_value=False) + + return context.get_dummy_value() + + +@lower("list.pop", types.List) +def list_pop(context, builder, sig, args): + inst = ListInstance(context, builder, sig.args[0], args[0]) + + n = inst.size + cgutils.guard_zero(context, builder, n, (IndexError, "pop from empty list")) + n = builder.sub(n, ir.Constant(n.type, 1)) + res = inst.getitem(n) + inst.incref_value(res) # incref the pop'ed element + inst.clear_value(n) # clear the storage space + inst.resize(n) + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +@lower("list.pop", types.List, types.Integer) +def list_pop(context, builder, sig, args): # noqa: F811 + inst = ListInstance(context, builder, sig.args[0], args[0]) + idx = inst.fix_index(args[1]) + + n = inst.size + cgutils.guard_zero(context, builder, n, (IndexError, "pop from empty list")) + inst.guard_index(idx, "pop index out of range") + + res = inst.getitem(idx) + + one = ir.Constant(n.type, 1) + n = builder.sub(n, ir.Constant(n.type, 1)) + inst.move(idx, builder.add(idx, one), builder.sub(n, idx)) + inst.resize(n) + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +@overload_method(types.List, "remove") +def list_remove(lst, value): + def list_remove_impl(lst, value): + for i in range(len(lst)): + if lst[i] == value: + lst.pop(i) + return + # XXX references are leaked when raising + raise ValueError("list.remove(x): x not in list") + + return list_remove_impl + + +@overload_method(types.List, "reverse") 
+def list_reverse(lst): + def list_reverse_impl(lst): + for a in range(0, len(lst) // 2): + b = -a - 1 + lst[a], lst[b] = lst[b], lst[a] + + return list_reverse_impl + + +# ----------------------------------------------------------------------------- +# Sorting + + +def gt(a, b): + return a > b + + +sort_forwards = quicksort.make_jit_quicksort().run_quicksort +sort_backwards = quicksort.make_jit_quicksort(lt=gt).run_quicksort + +arg_sort_forwards = quicksort.make_jit_quicksort( + is_argsort=True, is_list=True +).run_quicksort +arg_sort_backwards = quicksort.make_jit_quicksort( + is_argsort=True, lt=gt, is_list=True +).run_quicksort + + +def _sort_check_reverse(reverse): + if isinstance(reverse, types.Omitted): + rty = reverse.value + elif isinstance(reverse, types.Optional): + rty = reverse.type + else: + rty = reverse + if not isinstance(rty, (types.Boolean, types.Integer, int, bool)): + msg = "an integer is required for 'reverse' (got type %s)" % reverse + raise errors.TypingError(msg) + return rty + + +def _sort_check_key(key): + if isinstance(key, types.Optional): + msg = ( + "Key must concretely be None or a Numba JIT compiled function, " + "an Optional (union of None and a value) was found" + ) + raise errors.TypingError(msg) + if not (cgutils.is_nonelike(key) or isinstance(key, types.Dispatcher)): + msg = "Key must be None or a Numba JIT compiled function" + raise errors.TypingError(msg) + + +@overload_method(types.List, "sort") +def ol_list_sort(lst, key=None, reverse=False): + _sort_check_key(key) + _sort_check_reverse(reverse) + + if cgutils.is_nonelike(key): + KEY = False + sort_f = sort_forwards + sort_b = sort_backwards + elif isinstance(key, types.Dispatcher): + KEY = True + sort_f = arg_sort_forwards + sort_b = arg_sort_backwards + + def impl(lst, key=None, reverse=False): + if KEY is True: + _lst = [key(x) for x in lst] + else: + _lst = lst + if reverse is False or reverse == 0: + tmp = sort_f(_lst) + else: + tmp = sort_b(_lst) + if KEY is True: + lst[:] = [lst[i] for i in tmp] + + return impl + + +@overload(sorted) +def ol_sorted(iterable, key=None, reverse=False): + if not isinstance(iterable, types.IterableType): + return False + + _sort_check_key(key) + _sort_check_reverse(reverse) + + def impl(iterable, key=None, reverse=False): + lst = list(iterable) + lst.sort(key=key, reverse=reverse) + return lst + + return impl + + +# ----------------------------------------------------------------------------- +# Implicit casting + + +@lower_cast(types.List, types.List) +def list_to_list(context, builder, fromty, toty, val): + # Casting from non-reflected to reflected + assert fromty.dtype == toty.dtype + return val + + +# ----------------------------------------------------------------------------- +# Implementations for types.LiteralList +# ----------------------------------------------------------------------------- + +_banned_error = errors.TypingError("Cannot mutate a literal list") + + +# Things that mutate literal lists are banned +@overload_method(types.LiteralList, "append") +def literal_list_banned_append(lst, obj): + raise _banned_error + + +@overload_method(types.LiteralList, "extend") +def literal_list_banned_extend(lst, iterable): + raise _banned_error + + +@overload_method(types.LiteralList, "insert") +def literal_list_banned_insert(lst, index, obj): + raise _banned_error + + +@overload_method(types.LiteralList, "remove") +def literal_list_banned_remove(lst, value): + raise _banned_error + + +@overload_method(types.LiteralList, "pop") +def 
literal_list_banned_pop(lst, index=-1): + raise _banned_error + + +@overload_method(types.LiteralList, "clear") +def literal_list_banned_clear(lst): + raise _banned_error + + +@overload_method(types.LiteralList, "sort") +def literal_list_banned_sort(lst, key=None, reverse=False): + raise _banned_error + + +@overload_method(types.LiteralList, "reverse") +def literal_list_banned_reverse(lst): + raise _banned_error + + +_index_end = types.intp.maxval + + +@overload_method(types.LiteralList, "index") +def literal_list_index(lst, x, start=0, end=_index_end): + # TODO: To make this work, need consts as slice for start/end so as to + # be able to statically analyse the bounds, then its a just loop body + # versioning based iteration along with enumerate to find the item + if isinstance(lst, types.LiteralList): + msg = "list.index is unsupported for literal lists" + raise errors.TypingError(msg) + + +@overload_method(types.LiteralList, "count") +def literal_list_count(lst, x): + if isinstance(lst, types.LiteralList): + + def impl(lst, x): + count = 0 + for val in literal_unroll(lst): + if val == x: + count += 1 + return count + + return impl + + +@overload_method(types.LiteralList, "copy") +def literal_list_count(lst): # noqa: F811 + if isinstance(lst, types.LiteralList): + + def impl(lst): + return lst # tuples are immutable, as is this, so just return it + + return impl + + +@overload(operator.delitem) +def literal_list_delitem(lst, index): + if isinstance(lst, types.LiteralList): + raise _banned_error + + +@overload(operator.setitem) +def literal_list_setitem(lst, index, value): + if isinstance(lst, types.LiteralList): + raise errors.TypingError("Cannot mutate a literal list") + + +@overload(operator.getitem) +def literal_list_getitem(lst, *args): + if not isinstance(lst, types.LiteralList): + return + msg = ( + "Cannot __getitem__ on a literal list, return type cannot be " + "statically determined." + ) + raise errors.TypingError(msg) + + +@overload(len) +def literal_list_len(lst): + if not isinstance(lst, types.LiteralList): + return + l = lst.count + return lambda lst: l + + +@overload(operator.contains) +def literal_list_contains(lst, item): + if isinstance(lst, types.LiteralList): + + def impl(lst, item): + for val in literal_unroll(lst): + if val == item: + return True + return False + + return impl + + +@lower_cast(types.LiteralList, types.LiteralList) +def literallist_to_literallist(context, builder, fromty, toty, val): + if len(fromty) != len(toty): + # Disallowed by typing layer + raise NotImplementedError + + olditems = cgutils.unpack_tuple(builder, val, len(fromty)) + items = [ + context.cast(builder, v, f, t) + for v, f, t in zip(olditems, fromty, toty) + ] + return context.make_tuple(builder, toty, items) diff --git a/numba_cuda/numba/cuda/cpython/slicing.py b/numba_cuda/numba/cuda/cpython/slicing.py new file mode 100644 index 000000000..f5ad77482 --- /dev/null +++ b/numba_cuda/numba/cuda/cpython/slicing.py @@ -0,0 +1,322 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +Implement slices and various slice computations. 
+""" + +from llvmlite import ir +from numba.core import types +from numba.cuda import cgutils +from numba.core.imputils import impl_ret_untracked, Registry + +registry = Registry("slicing") +lower = registry.lower +lower_getattr = registry.lower_getattr +lower_cast = registry.lower_cast +lower_constant = registry.lower_constant + + +def fix_index(builder, idx, size): + """ + Fix negative index by adding *size* to it. Positive + indices are left untouched. + """ + is_negative = builder.icmp_signed("<", idx, ir.Constant(size.type, 0)) + wrapped_index = builder.add(idx, size) + return builder.select(is_negative, wrapped_index, idx) + + +def fix_slice(builder, slice, size): + """ + Fix *slice* start and stop to be valid (inclusive and exclusive, resp) + indexing bounds for a sequence of the given *size*. + """ + # See PySlice_GetIndicesEx() + zero = ir.Constant(size.type, 0) + minus_one = ir.Constant(size.type, -1) + + def fix_bound(bound_name, lower_repl, upper_repl): + bound = getattr(slice, bound_name) + bound = fix_index(builder, bound, size) + # Store value + setattr(slice, bound_name, bound) + # Still negative? => clamp to lower_repl + underflow = builder.icmp_signed("<", bound, zero) + with builder.if_then(underflow, likely=False): + setattr(slice, bound_name, lower_repl) + # Greater than size? => clamp to upper_repl + overflow = builder.icmp_signed(">=", bound, size) + with builder.if_then(overflow, likely=False): + setattr(slice, bound_name, upper_repl) + + with builder.if_else(cgutils.is_neg_int(builder, slice.step)) as ( + if_neg_step, + if_pos_step, + ): + with if_pos_step: + # < 0 => 0; >= size => size + fix_bound("start", zero, size) + fix_bound("stop", zero, size) + with if_neg_step: + # < 0 => -1; >= size => size - 1 + lower = minus_one + upper = builder.add(size, minus_one) + fix_bound("start", lower, upper) + fix_bound("stop", lower, upper) + + +def get_slice_length(builder, slicestruct): + """ + Given a slice, compute the number of indices it spans, i.e. the + number of iterations that for_range_slice() will execute. + + Pseudo-code: + assert step != 0 + if step > 0: + if stop <= start: + return 0 + else: + return (stop - start - 1) // step + 1 + else: + if stop >= start: + return 0 + else: + return (stop - start + 1) // step + 1 + + (see PySlice_GetIndicesEx() in CPython) + """ + start = slicestruct.start + stop = slicestruct.stop + step = slicestruct.step + one = ir.Constant(start.type, 1) + zero = ir.Constant(start.type, 0) + + is_step_negative = cgutils.is_neg_int(builder, step) + delta = builder.sub(stop, start) + + # Nominal case + pos_dividend = builder.sub(delta, one) + neg_dividend = builder.add(delta, one) + dividend = builder.select(is_step_negative, neg_dividend, pos_dividend) + nominal_length = builder.add(one, builder.sdiv(dividend, step)) + + # Catch zero length + is_zero_length = builder.select( + is_step_negative, + builder.icmp_signed(">=", delta, zero), + builder.icmp_signed("<=", delta, zero), + ) + + # Clamp to 0 if is_zero_length + return builder.select(is_zero_length, zero, nominal_length) + + +def get_slice_bounds(builder, slicestruct): + """ + Return the [lower, upper) indexing bounds of a slice. + """ + start = slicestruct.start + stop = slicestruct.stop + zero = start.type(0) + one = start.type(1) + # This is a bit pessimal, e.g. 
it will return [1, 5) instead + # of [1, 4) for `1:5:2` + is_step_negative = builder.icmp_signed("<", slicestruct.step, zero) + lower = builder.select(is_step_negative, builder.add(stop, one), start) + upper = builder.select(is_step_negative, builder.add(start, one), stop) + return lower, upper + + +def fix_stride(builder, slice, stride): + """ + Fix the given stride for the slice's step. + """ + return builder.mul(slice.step, stride) + + +def guard_invalid_slice(context, builder, typ, slicestruct): + """ + Guard against *slicestruct* having a zero step (and raise ValueError). + """ + if typ.has_step: + cgutils.guard_null( + context, + builder, + slicestruct.step, + (ValueError, "slice step cannot be zero"), + ) + + +def get_defaults(context): + """ + Get the default values for a slice's members: + (start for positive step, start for negative step, + stop for positive step, stop for negative step, step) + """ + maxint = (1 << (context.address_size - 1)) - 1 + return (0, maxint, maxint, -maxint - 1, 1) + + +# --------------------------------------------------------------------------- +# The slice structure + + +@lower(slice, types.VarArg(types.Any)) +def slice_constructor_impl(context, builder, sig, args): + ( + default_start_pos, + default_start_neg, + default_stop_pos, + default_stop_neg, + default_step, + ) = [context.get_constant(types.intp, x) for x in get_defaults(context)] + + slice_args = [None] * 3 + + # Fetch non-None arguments + if len(args) == 1 and sig.args[0] is not types.none: + slice_args[1] = args[0] + else: + for i, (ty, val) in enumerate(zip(sig.args, args)): + if ty is not types.none: + slice_args[i] = val + + # Fill omitted arguments + def get_arg_value(i, default): + val = slice_args[i] + if val is None: + return default + else: + return val + + step = get_arg_value(2, default_step) + is_step_negative = builder.icmp_signed( + "<", step, context.get_constant(types.intp, 0) + ) + default_stop = builder.select( + is_step_negative, default_stop_neg, default_stop_pos + ) + default_start = builder.select( + is_step_negative, default_start_neg, default_start_pos + ) + stop = get_arg_value(1, default_stop) + start = get_arg_value(0, default_start) + + ty = sig.return_type + sli = context.make_helper(builder, sig.return_type) + sli.start = start + sli.stop = stop + sli.step = step + + res = sli._getvalue() + return impl_ret_untracked(context, builder, sig.return_type, res) + + +@lower_getattr(types.SliceType, "start") +def slice_start_impl(context, builder, typ, value): + sli = context.make_helper(builder, typ, value) + return sli.start + + +@lower_getattr(types.SliceType, "stop") +def slice_stop_impl(context, builder, typ, value): + sli = context.make_helper(builder, typ, value) + return sli.stop + + +@lower_getattr(types.SliceType, "step") +def slice_step_impl(context, builder, typ, value): + if typ.has_step: + sli = context.make_helper(builder, typ, value) + return sli.step + else: + return context.get_constant(types.intp, 1) + + +@lower("slice.indices", types.SliceType, types.Integer) +def slice_indices(context, builder, sig, args): + length = args[1] + sli = context.make_helper(builder, sig.args[0], args[0]) + + with builder.if_then(cgutils.is_neg_int(builder, length), likely=False): + context.call_conv.return_user_exc( + builder, ValueError, ("length should not be negative",) + ) + with builder.if_then( + cgutils.is_scalar_zero(builder, sli.step), likely=False + ): + context.call_conv.return_user_exc( + builder, ValueError, ("slice step cannot be zero",) + ) + + 
fix_slice(builder, sli, length) + + return context.make_tuple( + builder, sig.return_type, (sli.start, sli.stop, sli.step) + ) + + +def make_slice_from_constant(context, builder, ty, pyval): + sli = context.make_helper(builder, ty) + lty = context.get_value_type(types.intp) + + ( + default_start_pos, + default_start_neg, + default_stop_pos, + default_stop_neg, + default_step, + ) = [context.get_constant(types.intp, x) for x in get_defaults(context)] + + step = pyval.step + if step is None: + step_is_neg = False + step = default_step + else: + step_is_neg = step < 0 + step = lty(step) + + start = pyval.start + if start is None: + if step_is_neg: + start = default_start_neg + else: + start = default_start_pos + else: + start = lty(start) + + stop = pyval.stop + if stop is None: + if step_is_neg: + stop = default_stop_neg + else: + stop = default_stop_pos + else: + stop = lty(stop) + + sli.start = start + sli.stop = stop + sli.step = step + + return sli._getvalue() + + +@lower_constant(types.SliceType) +def constant_slice(context, builder, ty, pyval): + if isinstance(ty, types.Literal): + typ = ty.literal_type + else: + typ = ty + + return make_slice_from_constant(context, builder, typ, pyval) + + +@lower_cast(types.misc.SliceLiteral, types.SliceType) +def cast_from_literal(context, builder, fromty, toty, val): + return make_slice_from_constant( + context, + builder, + toty, + fromty.literal_value, + ) diff --git a/numba_cuda/numba/cuda/cpython/unicode.py b/numba_cuda/numba/cuda/cpython/unicode.py new file mode 100644 index 000000000..4b9d6535a --- /dev/null +++ b/numba_cuda/numba/cuda/cpython/unicode.py @@ -0,0 +1,2863 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +import sys +import operator + +import numpy as np +from llvmlite.ir import IntType, Constant + +from numba.cuda.cgutils import is_nonelike +from numba.core.extending import ( + NativeValue, + overload, + overload_method, + register_jitable, + models, +) +from numba.cuda.core.pythonapi import box, unbox +from numba.cuda.extending import make_attribute_wrapper, intrinsic +from numba.cuda.models import register_model +from numba.core.imputils import ( + iternext_impl, + impl_ret_new_ref, + RefType, + Registry, +) +from numba.core.datamodel import register_default, StructModel +from numba.core import types +from numba.cuda import cgutils +from numba.cuda.utils import PYVERSION +from numba.cuda.core.pythonapi import ( + PY_UNICODE_1BYTE_KIND, + PY_UNICODE_2BYTE_KIND, + PY_UNICODE_4BYTE_KIND, +) +from numba._helperlib import c_helpers +from numba.cpython.hashing import _Py_hash_t +from numba.cuda.core.unsafe.bytes import memcpy_region +from numba.core.errors import TypingError +from numba.cuda.cpython.unicode_support import ( + _Py_TOUPPER, + _Py_TOLOWER, + _Py_UCS4, + _Py_ISALNUM, + _PyUnicode_ToUpperFull, + _PyUnicode_ToLowerFull, + _PyUnicode_ToFoldedFull, + _PyUnicode_ToTitleFull, + _PyUnicode_IsPrintable, + _PyUnicode_IsSpace, + _Py_ISSPACE, + _PyUnicode_IsXidStart, + _PyUnicode_IsXidContinue, + _PyUnicode_IsCased, + _PyUnicode_IsCaseIgnorable, + _PyUnicode_IsUppercase, + _PyUnicode_IsLowercase, + _PyUnicode_IsLineBreak, + _Py_ISLINEBREAK, + _Py_ISLINEFEED, + _Py_ISCARRIAGERETURN, + _PyUnicode_IsTitlecase, + _Py_ISLOWER, + _Py_ISUPPER, + _Py_TAB, + _Py_LINEFEED, + _Py_CARRIAGE_RETURN, + _Py_SPACE, + _PyUnicode_IsAlpha, + _PyUnicode_IsNumeric, + _Py_ISALPHA, + _PyUnicode_IsDigit, + _PyUnicode_IsDecimalDigit, +) +from 
numba.cuda.cpython import slicing + +registry = Registry("unicode") +lower = registry.lower +lower_cast = registry.lower_cast +lower_constant = registry.lower_constant +lower_getattr = registry.lower_getattr + +if PYVERSION in ((3, 9), (3, 10), (3, 11)): + from numba.core.pythonapi import PY_UNICODE_WCHAR_KIND + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L84-L85 # noqa: E501 +_MAX_UNICODE = 0x10FFFF + +# https://github.com/python/cpython/blob/1960eb005e04b7ad8a91018088cfdb0646bc1ca0/Objects/stringlib/fastsearch.h#L31 # noqa: E501 +_BLOOM_WIDTH = types.intp.bitwidth + +# DATA MODEL + + +@register_model(types.UnicodeType) +class UnicodeModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("data", types.voidptr), + ("length", types.intp), + ("kind", types.int32), + ("is_ascii", types.uint32), + ("hash", _Py_hash_t), + ("meminfo", types.MemInfoPointer(types.voidptr)), + # A pointer to the owner python str/unicode object + ("parent", types.pyobject), + ] + models.StructModel.__init__(self, dmm, fe_type, members) + + +make_attribute_wrapper(types.UnicodeType, "data", "_data") +make_attribute_wrapper(types.UnicodeType, "length", "_length") +make_attribute_wrapper(types.UnicodeType, "kind", "_kind") +make_attribute_wrapper(types.UnicodeType, "is_ascii", "_is_ascii") +make_attribute_wrapper(types.UnicodeType, "hash", "_hash") + + +@register_default(types.UnicodeIteratorType) +class UnicodeIteratorModel(StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("index", types.EphemeralPointer(types.uintp)), + ("data", fe_type.data), + ] + super(UnicodeIteratorModel, self).__init__(dmm, fe_type, members) + + +# CAST + + +def compile_time_get_string_data(obj): + """Get string data from a python string for use at compile-time to embed + the string data into the LLVM module. + """ + from ctypes import ( + CFUNCTYPE, + c_void_p, + c_int, + c_uint, + c_ssize_t, + c_ubyte, + py_object, + POINTER, + byref, + ) + + extract_unicode_fn = c_helpers["extract_unicode"] + proto = CFUNCTYPE( + c_void_p, + py_object, + POINTER(c_ssize_t), + POINTER(c_int), + POINTER(c_uint), + POINTER(c_ssize_t), + ) + fn = proto(extract_unicode_fn) + length = c_ssize_t() + kind = c_int() + is_ascii = c_uint() + hashv = c_ssize_t() + data = fn(obj, byref(length), byref(kind), byref(is_ascii), byref(hashv)) + if data is None: + raise ValueError("cannot extract unicode data from the given string") + length = length.value + kind = kind.value + is_ascii = is_ascii.value + nbytes = (length + 1) * _kind_to_byte_width(kind) + out = (c_ubyte * nbytes).from_address(data) + return bytes(out), length, kind, is_ascii, hashv.value + + +def make_string_from_constant(context, builder, typ, literal_string): + """ + Get string data by `compile_time_get_string_data()` and return a + unicode_type LLVM value + """ + databytes, length, kind, is_ascii, hashv = compile_time_get_string_data( + literal_string + ) + mod = builder.module + gv = context.insert_const_bytes(mod, databytes) + uni_str = cgutils.create_struct_proxy(typ)(context, builder) + uni_str.data = gv + uni_str.length = uni_str.length.type(length) + uni_str.kind = uni_str.kind.type(kind) + uni_str.is_ascii = uni_str.is_ascii.type(is_ascii) + # Set hash to -1 to indicate that it should be computed. + # We cannot bake in the hash value because of hashseed randomization. 
+ uni_str.hash = uni_str.hash.type(-1) + return uni_str._getvalue() + + +@lower_cast(types.StringLiteral, types.unicode_type) +def cast_from_literal(context, builder, fromty, toty, val): + return make_string_from_constant( + context, + builder, + toty, + fromty.literal_value, + ) + + +# CONSTANT + + +@lower_constant(types.unicode_type) +def constant_unicode(context, builder, typ, pyval): + return make_string_from_constant(context, builder, typ, pyval) + + +# BOXING + + +@unbox(types.UnicodeType) +def unbox_unicode_str(typ, obj, c): + """ + Convert a unicode str object to a native unicode structure. + """ + ok, data, length, kind, is_ascii, hashv = ( + c.pyapi.string_as_string_size_and_kind(obj) + ) + uni_str = cgutils.create_struct_proxy(typ)(c.context, c.builder) + uni_str.data = data + uni_str.length = length + uni_str.kind = kind + uni_str.is_ascii = is_ascii + uni_str.hash = hashv + uni_str.meminfo = c.pyapi.nrt_meminfo_new_from_pyobject( + data, # the borrowed data pointer + obj, # the owner pyobject; the call will incref it. + ) + uni_str.parent = obj + + is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred()) + return NativeValue(uni_str._getvalue(), is_error=is_error) + + +@box(types.UnicodeType) +def box_unicode_str(typ, val, c): + """ + Convert a native unicode structure to a unicode string + """ + uni_str = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val) + res = c.pyapi.string_from_kind_and_data( + uni_str.kind, uni_str.data, uni_str.length + ) + # hash isn't needed now, just compute it so it ends up in the unicodeobject + # hash cache, cpython doesn't always do this, depends how a string was + # created it's safe, just burns the cycles required to hash on @box + c.pyapi.object_hash(res) + c.context.nrt.decref(c.builder, typ, val) + return res + + +# HELPER FUNCTIONS + + +def make_deref_codegen(bitsize): + def codegen(context, builder, signature, args): + data, idx = args + ptr = builder.bitcast(data, IntType(bitsize).as_pointer()) + ch = builder.load(builder.gep(ptr, [idx])) + return builder.zext(ch, IntType(32)) + + return codegen + + +@intrinsic +def deref_uint8(typingctx, data, offset): + sig = types.uint32(types.voidptr, types.intp) + return sig, make_deref_codegen(8) + + +@intrinsic +def deref_uint16(typingctx, data, offset): + sig = types.uint32(types.voidptr, types.intp) + return sig, make_deref_codegen(16) + + +@intrinsic +def deref_uint32(typingctx, data, offset): + sig = types.uint32(types.voidptr, types.intp) + return sig, make_deref_codegen(32) + + +@intrinsic +def _malloc_string(typingctx, kind, char_bytes, length, is_ascii): + """make empty string with data buffer of size alloc_bytes. 
+ + Must set length and kind values for string after it is returned + """ + + def details(context, builder, signature, args): + [kind_val, char_bytes_val, length_val, is_ascii_val] = args + + # fill the struct + uni_str_ctor = cgutils.create_struct_proxy(types.unicode_type) + uni_str = uni_str_ctor(context, builder) + # add null padding character + nbytes_val = builder.mul( + char_bytes_val, + builder.add(length_val, Constant(length_val.type, 1)), + ) + uni_str.meminfo = context.nrt.meminfo_alloc(builder, nbytes_val) + uni_str.kind = kind_val + uni_str.is_ascii = is_ascii_val + uni_str.length = length_val + # empty string has hash value -1 to indicate "need to compute hash" + uni_str.hash = context.get_constant(_Py_hash_t, -1) + uni_str.data = context.nrt.meminfo_data(builder, uni_str.meminfo) + # Set parent to NULL + uni_str.parent = cgutils.get_null_value(uni_str.parent.type) + return uni_str._getvalue() + + sig = types.unicode_type(types.int32, types.intp, types.intp, types.uint32) + return sig, details + + +@register_jitable +def _empty_string(kind, length, is_ascii=0): + char_width = _kind_to_byte_width(kind) + s = _malloc_string(kind, char_width, length, is_ascii) + _set_code_point(s, length, np.uint32(0)) # Write NULL character + return s + + +# Disable RefCt for performance. +@register_jitable(_nrt=False) +def _get_code_point(a, i): + if a._kind == PY_UNICODE_1BYTE_KIND: + return deref_uint8(a._data, i) + elif a._kind == PY_UNICODE_2BYTE_KIND: + return deref_uint16(a._data, i) + elif a._kind == PY_UNICODE_4BYTE_KIND: + return deref_uint32(a._data, i) + else: + # there's also a wchar kind, but that's one of the above, + # so skipping for this example + return 0 + + +#### + + +def make_set_codegen(bitsize): + def codegen(context, builder, signature, args): + data, idx, ch = args + if bitsize < 32: + ch = builder.trunc(ch, IntType(bitsize)) + ptr = builder.bitcast(data, IntType(bitsize).as_pointer()) + builder.store(ch, builder.gep(ptr, [idx])) + return context.get_dummy_value() + + return codegen + + +@intrinsic +def set_uint8(typingctx, data, idx, ch): + sig = types.void(types.voidptr, types.int64, types.uint32) + return sig, make_set_codegen(8) + + +@intrinsic +def set_uint16(typingctx, data, idx, ch): + sig = types.void(types.voidptr, types.int64, types.uint32) + return sig, make_set_codegen(16) + + +@intrinsic +def set_uint32(typingctx, data, idx, ch): + sig = types.void(types.voidptr, types.int64, types.uint32) + return sig, make_set_codegen(32) + + +@register_jitable(_nrt=False) +def _set_code_point(a, i, ch): + # WARNING: This method is very dangerous: + # * Assumes that data contents can be changed (only allowed for new + # strings) + # * Assumes that the kind of unicode string is sufficiently wide to + # accept ch. Will truncate ch to make it fit. 
+ # * Assumes that i is within the valid boundaries of the function + if a._kind == PY_UNICODE_1BYTE_KIND: + set_uint8(a._data, i, ch) + elif a._kind == PY_UNICODE_2BYTE_KIND: + set_uint16(a._data, i, ch) + elif a._kind == PY_UNICODE_4BYTE_KIND: + set_uint32(a._data, i, ch) + else: + raise AssertionError( + "Unexpected unicode representation in _set_code_point" + ) + + +if PYVERSION in ((3, 12), (3, 13)): + + @register_jitable + def _pick_kind(kind1, kind2): + if kind1 == PY_UNICODE_1BYTE_KIND: + return kind2 + elif kind1 == PY_UNICODE_2BYTE_KIND: + if kind2 == PY_UNICODE_4BYTE_KIND: + return kind2 + else: + return kind1 + elif kind1 == PY_UNICODE_4BYTE_KIND: + return kind1 + else: + raise AssertionError( + "Unexpected unicode representation in _pick_kind" + ) +elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + + @register_jitable + def _pick_kind(kind1, kind2): + if kind1 == PY_UNICODE_WCHAR_KIND or kind2 == PY_UNICODE_WCHAR_KIND: + raise AssertionError("PY_UNICODE_WCHAR_KIND unsupported") + + if kind1 == PY_UNICODE_1BYTE_KIND: + return kind2 + elif kind1 == PY_UNICODE_2BYTE_KIND: + if kind2 == PY_UNICODE_4BYTE_KIND: + return kind2 + else: + return kind1 + elif kind1 == PY_UNICODE_4BYTE_KIND: + return kind1 + else: + raise AssertionError( + "Unexpected unicode representation in _pick_kind" + ) +else: + raise NotImplementedError(PYVERSION) + + +@register_jitable +def _pick_ascii(is_ascii1, is_ascii2): + if is_ascii1 == 1 and is_ascii2 == 1: + return types.uint32(1) + return types.uint32(0) + + +if PYVERSION in ((3, 12), (3, 13)): + + @register_jitable + def _kind_to_byte_width(kind): + if kind == PY_UNICODE_1BYTE_KIND: + return 1 + elif kind == PY_UNICODE_2BYTE_KIND: + return 2 + elif kind == PY_UNICODE_4BYTE_KIND: + return 4 + else: + raise AssertionError("Unexpected unicode encoding encountered") +elif PYVERSION in ((3, 9), (3, 10), (3, 11)): + + @register_jitable + def _kind_to_byte_width(kind): + if kind == PY_UNICODE_1BYTE_KIND: + return 1 + elif kind == PY_UNICODE_2BYTE_KIND: + return 2 + elif kind == PY_UNICODE_4BYTE_KIND: + return 4 + elif kind == PY_UNICODE_WCHAR_KIND: + raise AssertionError("PY_UNICODE_WCHAR_KIND unsupported") + else: + raise AssertionError("Unexpected unicode encoding encountered") +else: + raise NotImplementedError(PYVERSION) + + +@register_jitable(_nrt=False) +def _cmp_region(a, a_offset, b, b_offset, n): + if n == 0: + return 0 + elif a_offset + n > a._length: + return -1 + elif b_offset + n > b._length: + return 1 + + for i in range(n): + a_chr = _get_code_point(a, a_offset + i) + b_chr = _get_code_point(b, b_offset + i) + if a_chr < b_chr: + return -1 + elif a_chr > b_chr: + return 1 + + return 0 + + +@register_jitable +def _codepoint_to_kind(cp): + """ + Compute the minimum unicode kind needed to hold a given codepoint + """ + if cp < 256: + return PY_UNICODE_1BYTE_KIND + elif cp < 65536: + return PY_UNICODE_2BYTE_KIND + else: + # Maximum code point of Unicode 6.0: 0x10ffff (1,114,111) + MAX_UNICODE = 0x10FFFF + if cp > MAX_UNICODE: + msg = "Invalid codepoint. 
Found value greater than Unicode maximum" + raise ValueError(msg) + return PY_UNICODE_4BYTE_KIND + + +@register_jitable +def _codepoint_is_ascii(ch): + """ + Returns true if a codepoint is in the ASCII range + """ + return ch < 128 + + +# PUBLIC API + + +@overload(len) +def unicode_len(s): + if isinstance(s, types.UnicodeType): + + def len_impl(s): + return s._length + + return len_impl + + +@overload(operator.eq) +def unicode_eq(a, b): + if not (a.is_internal and b.is_internal): + return + if isinstance(a, types.Optional): + check_a = a.type + else: + check_a = a + if isinstance(b, types.Optional): + check_b = b.type + else: + check_b = b + accept = (types.UnicodeType, types.StringLiteral, types.UnicodeCharSeq) + a_unicode = isinstance(check_a, accept) + b_unicode = isinstance(check_b, accept) + if a_unicode and b_unicode: + + def eq_impl(a, b): + # handle Optionals at runtime + a_none = a is None + b_none = b is None + if a_none or b_none: + if a_none and b_none: + return True + else: + return False + # the str() is for UnicodeCharSeq, it's a nop else + a = str(a) + b = str(b) + if len(a) != len(b): + return False + return _cmp_region(a, 0, b, 0, len(a)) == 0 + + return eq_impl + elif a_unicode ^ b_unicode: + # one of the things is unicode, everything compares False + def eq_impl(a, b): + return False + + return eq_impl + + +@overload(operator.ne) +def unicode_ne(a, b): + if not (a.is_internal and b.is_internal): + return + accept = (types.UnicodeType, types.StringLiteral, types.UnicodeCharSeq) + a_unicode = isinstance(a, accept) + b_unicode = isinstance(b, accept) + if a_unicode and b_unicode: + + def ne_impl(a, b): + return not (a == b) + + return ne_impl + elif a_unicode ^ b_unicode: + # one of the things is unicode, everything compares True + def eq_impl(a, b): + return True + + return eq_impl + + +@overload(operator.lt) +def unicode_lt(a, b): + a_unicode = isinstance(a, (types.UnicodeType, types.StringLiteral)) + b_unicode = isinstance(b, (types.UnicodeType, types.StringLiteral)) + if a_unicode and b_unicode: + + def lt_impl(a, b): + minlen = min(len(a), len(b)) + eqcode = _cmp_region(a, 0, b, 0, minlen) + if eqcode == -1: + return True + elif eqcode == 0: + return len(a) < len(b) + return False + + return lt_impl + + +@overload(operator.gt) +def unicode_gt(a, b): + a_unicode = isinstance(a, (types.UnicodeType, types.StringLiteral)) + b_unicode = isinstance(b, (types.UnicodeType, types.StringLiteral)) + if a_unicode and b_unicode: + + def gt_impl(a, b): + minlen = min(len(a), len(b)) + eqcode = _cmp_region(a, 0, b, 0, minlen) + if eqcode == 1: + return True + elif eqcode == 0: + return len(a) > len(b) + return False + + return gt_impl + + +@overload(operator.le) +def unicode_le(a, b): + a_unicode = isinstance(a, (types.UnicodeType, types.StringLiteral)) + b_unicode = isinstance(b, (types.UnicodeType, types.StringLiteral)) + if a_unicode and b_unicode: + + def le_impl(a, b): + return not (a > b) + + return le_impl + + +@overload(operator.ge) +def unicode_ge(a, b): + a_unicode = isinstance(a, (types.UnicodeType, types.StringLiteral)) + b_unicode = isinstance(b, (types.UnicodeType, types.StringLiteral)) + if a_unicode and b_unicode: + + def ge_impl(a, b): + return not (a < b) + + return ge_impl + + +@overload(operator.contains) +def unicode_contains(a, b): + if isinstance(a, types.UnicodeType) and isinstance(b, types.UnicodeType): + + def contains_impl(a, b): + # note parameter swap: contains(a, b) == b in a + return _find(a, b) > -1 + + return contains_impl + + +def 
unicode_idx_check_type(ty, name): + """Check object belongs to one of specific types + ty: type + Type of the object + name: str + Name of the object + """ + thety = ty + # if the type is omitted, the concrete type is the value + if isinstance(ty, types.Omitted): + thety = ty.value + # if the type is optional, the concrete type is the captured type + elif isinstance(ty, types.Optional): + thety = ty.type + + accepted = (types.Integer, types.NoneType) + if thety is not None and not isinstance(thety, accepted): + raise TypingError('"{}" must be {}, not {}'.format(name, accepted, ty)) + + +def unicode_sub_check_type(ty, name): + """Check object belongs to unicode type""" + if not isinstance(ty, types.UnicodeType): + msg = '"{}" must be {}, not {}'.format(name, types.UnicodeType, ty) + raise TypingError(msg) + + +# FAST SEARCH algorithm implementation from cpython + + +@register_jitable +def _bloom_add(mask, ch): + mask |= 1 << (ch & (_BLOOM_WIDTH - 1)) + return mask + + +@register_jitable +def _bloom_check(mask, ch): + return mask & (1 << (ch & (_BLOOM_WIDTH - 1))) + + +# https://github.com/python/cpython/blob/1960eb005e04b7ad8a91018088cfdb0646bc1ca0/Objects/stringlib/fastsearch.h#L550 # noqa: E501 +@register_jitable +def _default_find(data, substr, start, end): + """Left finder.""" + m = len(substr) + if m == 0: + return start + + gap = mlast = m - 1 + last = _get_code_point(substr, mlast) + + zero = types.intp(0) + mask = _bloom_add(zero, last) + for i in range(mlast): + ch = _get_code_point(substr, i) + mask = _bloom_add(mask, ch) + if ch == last: + gap = mlast - i - 1 + + i = start + while i <= end - m: + ch = _get_code_point(data, mlast + i) + if ch == last: + j = 0 + while j < mlast: + haystack_ch = _get_code_point(data, i + j) + needle_ch = _get_code_point(substr, j) + if haystack_ch != needle_ch: + break + j += 1 + if j == mlast: + # got a match + return i + + ch = _get_code_point(data, mlast + i + 1) + if _bloom_check(mask, ch) == 0: + i += m + else: + i += gap + else: + ch = _get_code_point(data, mlast + i + 1) + if _bloom_check(mask, ch) == 0: + i += m + i += 1 + + return -1 + + +@register_jitable +def _default_rfind(data, substr, start, end): + """Right finder.""" + m = len(substr) + if m == 0: + return end + + skip = mlast = m - 1 + mfirst = _get_code_point(substr, 0) + mask = _bloom_add(0, mfirst) + i = mlast + while i > 0: + ch = _get_code_point(substr, i) + mask = _bloom_add(mask, ch) + if ch == mfirst: + skip = i - 1 + i -= 1 + + i = end - m + while i >= start: + ch = _get_code_point(data, i) + if ch == mfirst: + j = mlast + while j > 0: + haystack_ch = _get_code_point(data, i + j) + needle_ch = _get_code_point(substr, j) + if haystack_ch != needle_ch: + break + j -= 1 + + if j == 0: + # got a match + return i + + ch = _get_code_point(data, i - 1) + if i > start and _bloom_check(mask, ch) == 0: + i -= m + else: + i -= skip + + else: + ch = _get_code_point(data, i - 1) + if i > start and _bloom_check(mask, ch) == 0: + i -= m + i -= 1 + + return -1 + + +def generate_finder(find_func): + """Generate finder either left or right.""" + + def impl(data, substr, start=None, end=None): + length = len(data) + sub_length = len(substr) + if start is None: + start = 0 + if end is None: + end = length + + start, end = _adjust_indices(length, start, end) + if end - start < sub_length: + return -1 + + return find_func(data, substr, start, end) + + return impl + + +_find = register_jitable(generate_finder(_default_find)) +_rfind = register_jitable(generate_finder(_default_rfind)) + + 
+@overload_method(types.UnicodeType, "find") +def unicode_find(data, substr, start=None, end=None): + """Implements str.find()""" + if isinstance(substr, types.UnicodeCharSeq): + + def find_impl(data, substr, start=None, end=None): + return data.find(str(substr)) + + return find_impl + + unicode_idx_check_type(start, "start") + unicode_idx_check_type(end, "end") + unicode_sub_check_type(substr, "substr") + + return _find + + +@overload_method(types.UnicodeType, "rfind") +def unicode_rfind(data, substr, start=None, end=None): + """Implements str.rfind()""" + if isinstance(substr, types.UnicodeCharSeq): + + def rfind_impl(data, substr, start=None, end=None): + return data.rfind(str(substr)) + + return rfind_impl + + unicode_idx_check_type(start, "start") + unicode_idx_check_type(end, "end") + unicode_sub_check_type(substr, "substr") + + return _rfind + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12831-L12857 # noqa: E501 +@overload_method(types.UnicodeType, "rindex") +def unicode_rindex(s, sub, start=None, end=None): + """Implements str.rindex()""" + unicode_idx_check_type(start, "start") + unicode_idx_check_type(end, "end") + unicode_sub_check_type(sub, "sub") + + def rindex_impl(s, sub, start=None, end=None): + result = s.rfind(sub, start, end) + if result < 0: + raise ValueError("substring not found") + + return result + + return rindex_impl + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11692-L11718 # noqa: E501 +@overload_method(types.UnicodeType, "index") +def unicode_index(s, sub, start=None, end=None): + """Implements str.index()""" + unicode_idx_check_type(start, "start") + unicode_idx_check_type(end, "end") + unicode_sub_check_type(sub, "sub") + + def index_impl(s, sub, start=None, end=None): + result = s.find(sub, start, end) + if result < 0: + raise ValueError("substring not found") + + return result + + return index_impl + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12922-L12976 # noqa: E501 +@overload_method(types.UnicodeType, "partition") +def unicode_partition(data, sep): + """Implements str.partition()""" + thety = sep + # if the type is omitted, the concrete type is the value + if isinstance(sep, types.Omitted): + thety = sep.value + # if the type is optional, the concrete type is the captured type + elif isinstance(sep, types.Optional): + thety = sep.type + + accepted = (types.UnicodeType, types.UnicodeCharSeq) + if thety is not None and not isinstance(thety, accepted): + msg = '"{}" must be {}, not {}'.format("sep", accepted, sep) + raise TypingError(msg) + + def impl(data, sep): + # https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/partition.h#L7-L60 # noqa: E501 + sep = str(sep) + empty_str = _empty_string(data._kind, 0, data._is_ascii) + sep_length = len(sep) + if data._kind < sep._kind or len(data) < sep_length: + return data, empty_str, empty_str + + if sep_length == 0: + raise ValueError("empty separator") + + pos = data.find(sep) + if pos < 0: + return data, empty_str, empty_str + + return data[0:pos], sep, data[pos + sep_length : len(data)] + + return impl + + +@overload_method(types.UnicodeType, "count") +def unicode_count(src, sub, start=None, end=None): + _count_args_types_check(start) + _count_args_types_check(end) + + if isinstance(sub, types.UnicodeType): + + def count_impl(src, sub, start=None, end=None): + count = 0 
+ src_len = len(src) + sub_len = len(sub) + + start = _normalize_slice_idx_count(start, src_len, 0) + end = _normalize_slice_idx_count(end, src_len, src_len) + + if end - start < 0 or start > src_len: + return 0 + + src = src[start:end] + src_len = len(src) + start, end = 0, src_len + if sub_len == 0: + return src_len + 1 + + while start + sub_len <= src_len: + if src[start : start + sub_len] == sub: + count += 1 + start += sub_len + else: + start += 1 + return count + + return count_impl + error_msg = "The substring must be a UnicodeType, not {}" + raise TypingError(error_msg.format(type(sub))) + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12979-L13033 # noqa: E501 +@overload_method(types.UnicodeType, "rpartition") +def unicode_rpartition(data, sep): + """Implements str.rpartition()""" + thety = sep + # if the type is omitted, the concrete type is the value + if isinstance(sep, types.Omitted): + thety = sep.value + # if the type is optional, the concrete type is the captured type + elif isinstance(sep, types.Optional): + thety = sep.type + + accepted = (types.UnicodeType, types.UnicodeCharSeq) + if thety is not None and not isinstance(thety, accepted): + msg = '"{}" must be {}, not {}'.format("sep", accepted, sep) + raise TypingError(msg) + + def impl(data, sep): + # https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/partition.h#L62-L115 # noqa: E501 + sep = str(sep) + empty_str = _empty_string(data._kind, 0, data._is_ascii) + sep_length = len(sep) + if data._kind < sep._kind or len(data) < sep_length: + return empty_str, empty_str, data + + if sep_length == 0: + raise ValueError("empty separator") + + pos = data.rfind(sep) + if pos < 0: + return empty_str, empty_str, data + + return data[0:pos], sep, data[pos + sep_length : len(data)] + + return impl + + +# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L9342-L9354 # noqa: E501 +@register_jitable +def _adjust_indices(length, start, end): + if end > length: + end = length + if end < 0: + end += length + if end < 0: + end = 0 + if start < 0: + start += length + if start < 0: + start = 0 + + return start, end + + +@overload_method(types.UnicodeType, "startswith") +def unicode_startswith(s, prefix, start=None, end=None): + if not is_nonelike(start) and not isinstance(start, types.Integer): + raise TypingError( + "When specified, the arg 'start' must be an Integer or None" + ) + + if not is_nonelike(end) and not isinstance(end, types.Integer): + raise TypingError( + "When specified, the arg 'end' must be an Integer or None" + ) + + if isinstance(prefix, types.UniTuple) and isinstance( + prefix.dtype, types.UnicodeType + ): + + def startswith_tuple_impl(s, prefix, start=None, end=None): + for item in prefix: + if s.startswith(item, start, end): + return True + return False + + return startswith_tuple_impl + + elif isinstance(prefix, types.UnicodeCharSeq): + + def startswith_char_seq_impl(s, prefix, start=None, end=None): + return s.startswith(str(prefix), start, end) + + return startswith_char_seq_impl + + elif isinstance(prefix, types.UnicodeType): + + def startswith_unicode_impl(s, prefix, start=None, end=None): + length, prefix_length = len(s), len(prefix) + if start is None: + start = 0 + if end is None: + end = length + + start, end = _adjust_indices(length, start, end) + if end - start < prefix_length: + return False + + if prefix_length == 0: + return True + + s_slice = 
s[start:end] + + return _cmp_region(s_slice, 0, prefix, 0, prefix_length) == 0 + + return startswith_unicode_impl + + else: + raise TypingError( + "The arg 'prefix' should be a string or a tuple of strings" + ) + + +@overload_method(types.UnicodeType, "endswith") +def unicode_endswith(s, substr, start=None, end=None): + if not ( + start is None + or isinstance(start, (types.Omitted, types.Integer, types.NoneType)) + ): + raise TypingError("The arg must be a Integer or None") + + if not ( + end is None + or isinstance(end, (types.Omitted, types.Integer, types.NoneType)) + ): + raise TypingError("The arg must be a Integer or None") + + if isinstance(substr, (types.Tuple, types.UniTuple)): + + def endswith_impl(s, substr, start=None, end=None): + for item in substr: + if s.endswith(item, start, end) is True: + return True + + return False + + return endswith_impl + + if isinstance(substr, types.UnicodeType): + + def endswith_impl(s, substr, start=None, end=None): + length = len(s) + sub_length = len(substr) + if start is None: + start = 0 + if end is None: + end = length + + start, end = _adjust_indices(length, start, end) + if end - start < sub_length: + return False + + if sub_length == 0: + return True + + s = s[start:end] + offset = len(s) - sub_length + + return _cmp_region(s, offset, substr, 0, sub_length) == 0 + + return endswith_impl + + if isinstance(substr, types.UnicodeCharSeq): + + def endswith_impl(s, substr, start=None, end=None): + return s.endswith(str(substr), start, end) + + return endswith_impl + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11519-L11595 # noqa: E501 +@overload_method(types.UnicodeType, "expandtabs") +def unicode_expandtabs(data, tabsize=8): + """Implements str.expandtabs()""" + thety = tabsize + # if the type is omitted, the concrete type is the value + if isinstance(tabsize, types.Omitted): + thety = tabsize.value + # if the type is optional, the concrete type is the captured type + elif isinstance(tabsize, types.Optional): + thety = tabsize.type + + accepted = (types.Integer, int) + if thety is not None and not isinstance(thety, accepted): + raise TypingError( + '"tabsize" must be {}, not {}'.format(accepted, tabsize) + ) + + def expandtabs_impl(data, tabsize=8): + length = len(data) + j = line_pos = 0 + found = False + for i in range(length): + code_point = _get_code_point(data, i) + if code_point == _Py_TAB: + found = True + if tabsize > 0: + # cannot overflow + incr = tabsize - (line_pos % tabsize) + if j > sys.maxsize - incr: + raise OverflowError("new string is too long") + line_pos += incr + j += incr + else: + if j > sys.maxsize - 1: + raise OverflowError("new string is too long") + line_pos += 1 + j += 1 + if code_point in (_Py_LINEFEED, _Py_CARRIAGE_RETURN): + line_pos = 0 + + if not found: + return data + + res = _empty_string(data._kind, j, data._is_ascii) + j = line_pos = 0 + for i in range(length): + code_point = _get_code_point(data, i) + if code_point == _Py_TAB: + if tabsize > 0: + incr = tabsize - (line_pos % tabsize) + line_pos += incr + for idx in range(j, j + incr): + _set_code_point(res, idx, _Py_SPACE) + j += incr + else: + line_pos += 1 + _set_code_point(res, j, code_point) + j += 1 + if code_point in (_Py_LINEFEED, _Py_CARRIAGE_RETURN): + line_pos = 0 + + return res + + return expandtabs_impl + + +@overload_method(types.UnicodeType, "split") +def unicode_split(a, sep=None, maxsplit=-1): + if not ( + maxsplit == -1 + or isinstance( + maxsplit, (types.Omitted, 
types.Integer, types.IntegerLiteral) + ) + ): + return None # fail typing if maxsplit is not an integer + + if isinstance(sep, types.UnicodeCharSeq): + + def split_impl(a, sep=None, maxsplit=-1): + return a.split(str(sep), maxsplit=maxsplit) + + return split_impl + + if isinstance(sep, types.UnicodeType): + + def split_impl(a, sep=None, maxsplit=-1): + a_len = len(a) + sep_len = len(sep) + + if sep_len == 0: + raise ValueError("empty separator") + + parts = [] + last = 0 + idx = 0 + + if sep_len == 1 and maxsplit == -1: + sep_code_point = _get_code_point(sep, 0) + for idx in range(a_len): + if _get_code_point(a, idx) == sep_code_point: + parts.append(a[last:idx]) + last = idx + 1 + else: + split_count = 0 + + while idx < a_len and ( + maxsplit == -1 or split_count < maxsplit + ): + if _cmp_region(a, idx, sep, 0, sep_len) == 0: + parts.append(a[last:idx]) + idx += sep_len + last = idx + split_count += 1 + else: + idx += 1 + + if last <= a_len: + parts.append(a[last:]) + + return parts + + return split_impl + elif ( + sep is None + or isinstance(sep, types.NoneType) + or getattr(sep, "value", False) is None + ): + + def split_whitespace_impl(a, sep=None, maxsplit=-1): + a_len = len(a) + + parts = [] + last = 0 + idx = 0 + split_count = 0 + in_whitespace_block = True + + for idx in range(a_len): + code_point = _get_code_point(a, idx) + is_whitespace = _PyUnicode_IsSpace(code_point) + if in_whitespace_block: + if is_whitespace: + pass # keep consuming space + else: + last = idx # this is the start of the next string + in_whitespace_block = False + else: + if not is_whitespace: + pass # keep searching for whitespace transition + else: + parts.append(a[last:idx]) + in_whitespace_block = True + split_count += 1 + if maxsplit != -1 and split_count == maxsplit: + break + + if last <= a_len and not in_whitespace_block: + parts.append(a[last:]) + + return parts + + return split_whitespace_impl + + +def generate_rsplit_whitespace_impl(isspace_func): + """Generate whitespace rsplit func based on either ascii or unicode""" + + def rsplit_whitespace_impl(data, sep=None, maxsplit=-1): + # https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/split.h#L192-L240 # noqa: E501 + if maxsplit < 0: + maxsplit = sys.maxsize + + result = [] + i = len(data) - 1 + while maxsplit > 0: + while i >= 0: + code_point = _get_code_point(data, i) + if not isspace_func(code_point): + break + i -= 1 + if i < 0: + break + j = i + i -= 1 + while i >= 0: + code_point = _get_code_point(data, i) + if isspace_func(code_point): + break + i -= 1 + result.append(data[i + 1 : j + 1]) + maxsplit -= 1 + + if i >= 0: + # Only occurs when maxsplit was reached + # Skip any remaining whitespace and copy to beginning of string + while i >= 0: + code_point = _get_code_point(data, i) + if not isspace_func(code_point): + break + i -= 1 + if i >= 0: + result.append(data[0 : i + 1]) + + return result[::-1] + + return rsplit_whitespace_impl + + +unicode_rsplit_whitespace_impl = register_jitable( + generate_rsplit_whitespace_impl(_PyUnicode_IsSpace) +) +ascii_rsplit_whitespace_impl = register_jitable( + generate_rsplit_whitespace_impl(_Py_ISSPACE) +) + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L13095-L13108 # noqa: E501 +@overload_method(types.UnicodeType, "rsplit") +def unicode_rsplit(data, sep=None, maxsplit=-1): + """Implements str.unicode_rsplit()""" + + def _unicode_rsplit_check_type(ty, name, accepted): + """Check object belongs to one 
of specified types""" + thety = ty + # if the type is omitted, the concrete type is the value + if isinstance(ty, types.Omitted): + thety = ty.value + # if the type is optional, the concrete type is the captured type + elif isinstance(ty, types.Optional): + thety = ty.type + + if thety is not None and not isinstance(thety, accepted): + raise TypingError( + '"{}" must be {}, not {}'.format(name, accepted, ty) + ) + + _unicode_rsplit_check_type( + sep, "sep", (types.UnicodeType, types.UnicodeCharSeq, types.NoneType) + ) + _unicode_rsplit_check_type(maxsplit, "maxsplit", (types.Integer, int)) + + if sep is None or isinstance(sep, (types.NoneType, types.Omitted)): + + def rsplit_whitespace_impl(data, sep=None, maxsplit=-1): + if data._is_ascii: + return ascii_rsplit_whitespace_impl(data, sep, maxsplit) + return unicode_rsplit_whitespace_impl(data, sep, maxsplit) + + return rsplit_whitespace_impl + + def rsplit_impl(data, sep=None, maxsplit=-1): + sep = str(sep) + # https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/split.h#L286-L333 # noqa: E501 + if data._kind < sep._kind or len(data) < len(sep): + return [data] + + def _rsplit_char(data, ch, maxsplit): + # https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/split.h#L242-L284 # noqa: E501 + result = [] + ch_code_point = _get_code_point(ch, 0) + i = j = len(data) - 1 + while i >= 0 and maxsplit > 0: + data_code_point = _get_code_point(data, i) + if data_code_point == ch_code_point: + result.append(data[i + 1 : j + 1]) + j = i = i - 1 + maxsplit -= 1 + i -= 1 + if j >= -1: + result.append(data[0 : j + 1]) + + return result[::-1] + + if maxsplit < 0: + maxsplit = sys.maxsize + + sep_length = len(sep) + + if sep_length == 0: + raise ValueError("empty separator") + if sep_length == 1: + return _rsplit_char(data, sep, maxsplit) + + result = [] + j = len(data) + while maxsplit > 0: + pos = data.rfind(sep, start=0, end=j) + if pos < 0: + break + result.append(data[pos + sep_length : j]) + j = pos + maxsplit -= 1 + + result.append(data[0:j]) + + return result[::-1] + + return rsplit_impl + + +@overload_method(types.UnicodeType, "center") +def unicode_center(string, width, fillchar=" "): + if not isinstance(width, types.Integer): + raise TypingError("The width must be an Integer") + + if isinstance(fillchar, types.UnicodeCharSeq): + + def center_impl(string, width, fillchar=" "): + return string.center(width, str(fillchar)) + + return center_impl + + if not ( + fillchar == " " + or isinstance(fillchar, (types.Omitted, types.UnicodeType)) + ): + raise TypingError("The fillchar must be a UnicodeType") + + def center_impl(string, width, fillchar=" "): + str_len = len(string) + fillchar_len = len(fillchar) + + if fillchar_len != 1: + raise ValueError( + "The fill character must be exactly one character long" + ) + + if width <= str_len: + return string + + allmargin = width - str_len + lmargin = (allmargin // 2) + (allmargin & width & 1) + rmargin = allmargin - lmargin + + l_string = fillchar * lmargin + if lmargin == rmargin: + return l_string + string + l_string + else: + return l_string + string + (fillchar * rmargin) + + return center_impl + + +def gen_unicode_Xjust(STRING_FIRST): + def unicode_Xjust(string, width, fillchar=" "): + if not isinstance(width, types.Integer): + raise TypingError("The width must be an Integer") + + if isinstance(fillchar, types.UnicodeCharSeq): + if STRING_FIRST: + + def ljust_impl(string, width, fillchar=" "): + return 
string.ljust(width, str(fillchar)) + + return ljust_impl + else: + + def rjust_impl(string, width, fillchar=" "): + return string.rjust(width, str(fillchar)) + + return rjust_impl + + if not ( + fillchar == " " + or isinstance(fillchar, (types.Omitted, types.UnicodeType)) + ): + raise TypingError("The fillchar must be a UnicodeType") + + def impl(string, width, fillchar=" "): + str_len = len(string) + fillchar_len = len(fillchar) + + if fillchar_len != 1: + raise ValueError( + "The fill character must be exactly one character long" + ) + + if width <= str_len: + return string + + newstr = fillchar * (width - str_len) + if STRING_FIRST: + return string + newstr + else: + return newstr + string + + return impl + + return unicode_Xjust + + +overload_method(types.UnicodeType, "rjust")(gen_unicode_Xjust(False)) +overload_method(types.UnicodeType, "ljust")(gen_unicode_Xjust(True)) + + +def generate_splitlines_func(is_line_break_func): + """Generate splitlines performer based on ascii or unicode line breaks.""" + + def impl(data, keepends): + # https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/stringlib/split.h#L335-L389 # noqa: E501 + length = len(data) + result = [] + i = j = 0 + while i < length: + # find a line and append it + while i < length: + code_point = _get_code_point(data, i) + if is_line_break_func(code_point): + break + i += 1 + + # skip the line break reading CRLF as one line break + eol = i + if i < length: + if i + 1 < length: + cur_cp = _get_code_point(data, i) + next_cp = _get_code_point(data, i + 1) + if _Py_ISCARRIAGERETURN(cur_cp) and _Py_ISLINEFEED(next_cp): + i += 1 + i += 1 + if keepends: + eol = i + + result.append(data[j:eol]) + j = i + + return result + + return impl + + +_ascii_splitlines = register_jitable(generate_splitlines_func(_Py_ISLINEBREAK)) +_unicode_splitlines = register_jitable( + generate_splitlines_func(_PyUnicode_IsLineBreak) +) + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L10196-L10229 # noqa: E501 +@overload_method(types.UnicodeType, "splitlines") +def unicode_splitlines(data, keepends=False): + """Implements str.splitlines()""" + thety = keepends + # if the type is omitted, the concrete type is the value + if isinstance(keepends, types.Omitted): + thety = keepends.value + # if the type is optional, the concrete type is the captured type + elif isinstance(keepends, types.Optional): + thety = keepends.type + + accepted = (types.Integer, int, types.Boolean, bool) + if thety is not None and not isinstance(thety, accepted): + raise TypingError( + '"{}" must be {}, not {}'.format("keepends", accepted, keepends) + ) + + def splitlines_impl(data, keepends=False): + if data._is_ascii: + return _ascii_splitlines(data, keepends) + + return _unicode_splitlines(data, keepends) + + return splitlines_impl + + +@register_jitable +def join_list(sep, parts): + parts_len = len(parts) + if parts_len == 0: + return "" + + # Precompute size and char_width of result + sep_len = len(sep) + length = (parts_len - 1) * sep_len + kind = sep._kind + is_ascii = sep._is_ascii + for p in parts: + length += len(p) + kind = _pick_kind(kind, p._kind) + is_ascii = _pick_ascii(is_ascii, p._is_ascii) + + result = _empty_string(kind, length, is_ascii) + + # populate string + part = parts[0] + _strncpy(result, 0, part, 0, len(part)) + dst_offset = len(part) + for idx in range(1, parts_len): + _strncpy(result, dst_offset, sep, 0, sep_len) + dst_offset += sep_len + part = parts[idx] + 
_strncpy(result, dst_offset, part, 0, len(part)) + dst_offset += len(part) + + return result + + +@overload_method(types.UnicodeType, "join") +def unicode_join(sep, parts): + if isinstance(parts, types.List): + if isinstance(parts.dtype, types.UnicodeType): + + def join_list_impl(sep, parts): + return join_list(sep, parts) + + return join_list_impl + elif isinstance(parts.dtype, types.UnicodeCharSeq): + + def join_list_impl(sep, parts): + _parts = [str(p) for p in parts] + return join_list(sep, _parts) + + return join_list_impl + else: + pass # lists of any other type not supported + elif isinstance(parts, types.IterableType): + + def join_iter_impl(sep, parts): + parts_list = [p for p in parts] + return sep.join(parts_list) + + return join_iter_impl + elif isinstance(parts, types.UnicodeType): + # Temporary workaround until UnicodeType is iterable + def join_str_impl(sep, parts): + parts_list = [parts[i] for i in range(len(parts))] + return join_list(sep, parts_list) + + return join_str_impl + + +@overload_method(types.UnicodeType, "zfill") +def unicode_zfill(string, width): + if not isinstance(width, types.Integer): + raise TypingError(" must be an Integer") + + def zfill_impl(string, width): + str_len = len(string) + + if width <= str_len: + return string + + first_char = string[0] if str_len else "" + padding = "0" * (width - str_len) + + if first_char in ["+", "-"]: + newstr = first_char + padding + string[1:] + else: + newstr = padding + string + + return newstr + + return zfill_impl + + +# ------------------------------------------------------------------------------ +# Strip functions +# ------------------------------------------------------------------------------ +@register_jitable +def unicode_strip_left_bound(string, chars): + str_len = len(string) + + i = 0 + if chars is not None: + for i in range(str_len): + if string[i] not in chars: + return i + else: + for i in range(str_len): + if not _PyUnicode_IsSpace(string[i]): + return i + + return str_len + + +@register_jitable +def unicode_strip_right_bound(string, chars): + str_len = len(string) + i = 0 + if chars is not None: + for i in range(str_len - 1, -1, -1): + if string[i] not in chars: + i += 1 + break + else: + for i in range(str_len - 1, -1, -1): + if not _PyUnicode_IsSpace(string[i]): + i += 1 + break + + return i + + +def unicode_strip_types_check(chars): + if isinstance(chars, types.Optional): + chars = chars.type # catch optional type with invalid non-None type + if not ( + chars is None + or isinstance(chars, (types.Omitted, types.UnicodeType, types.NoneType)) + ): + raise TypingError("The arg must be a UnicodeType or None") + + +def _count_args_types_check(arg): + if isinstance(arg, types.Optional): + arg = arg.type + if not ( + arg is None + or isinstance(arg, (types.Omitted, types.Integer, types.NoneType)) + ): + raise TypingError("The slice indices must be an Integer or None") + + +@overload_method(types.UnicodeType, "lstrip") +def unicode_lstrip(string, chars=None): + if isinstance(chars, types.UnicodeCharSeq): + + def lstrip_impl(string, chars=None): + return string.lstrip(str(chars)) + + return lstrip_impl + + unicode_strip_types_check(chars) + + def lstrip_impl(string, chars=None): + return string[unicode_strip_left_bound(string, chars) :] + + return lstrip_impl + + +@overload_method(types.UnicodeType, "rstrip") +def unicode_rstrip(string, chars=None): + if isinstance(chars, types.UnicodeCharSeq): + + def rstrip_impl(string, chars=None): + return string.rstrip(str(chars)) + + return rstrip_impl + + 
unicode_strip_types_check(chars) + + def rstrip_impl(string, chars=None): + return string[: unicode_strip_right_bound(string, chars)] + + return rstrip_impl + + +@overload_method(types.UnicodeType, "strip") +def unicode_strip(string, chars=None): + if isinstance(chars, types.UnicodeCharSeq): + + def strip_impl(string, chars=None): + return string.strip(str(chars)) + + return strip_impl + + unicode_strip_types_check(chars) + + def strip_impl(string, chars=None): + lb = unicode_strip_left_bound(string, chars) + rb = unicode_strip_right_bound(string, chars) + return string[lb:rb] + + return strip_impl + + +# ------------------------------------------------------------------------------ +# Slice functions +# ------------------------------------------------------------------------------ + + +@register_jitable +def normalize_str_idx(idx, length, is_start=True): + """ + Parameters + ---------- + idx : int or None + the index + length : int + the string length + is_start : bool; optional with defaults to True + Is it the *start* or the *stop* of the slice? + + Returns + ------- + norm_idx : int + normalized index + """ + if idx is None: + if is_start: + return 0 + else: + return length + elif idx < 0: + idx += length + + if idx < 0 or idx >= length: + raise IndexError("string index out of range") + + return idx + + +@register_jitable +def _normalize_slice_idx_count(arg, slice_len, default): + """ + Used for unicode_count + + If arg < -slice_len, returns 0 (prevents circle) + + If arg is within slice, e.g -slice_len <= arg < slice_len + returns its real index via arg % slice_len + + If arg > slice_len, returns arg (in this case count must + return 0 if it is start index) + """ + + if arg is None: + return default + if -slice_len <= arg < slice_len: + return arg % slice_len + return 0 if arg < 0 else arg + + +@intrinsic +def _normalize_slice(typingctx, sliceobj, length): + """Fix slice object.""" + sig = sliceobj(sliceobj, length) + + def codegen(context, builder, sig, args): + [slicetype, lengthtype] = sig.args + [sliceobj, length] = args + slice = context.make_helper(builder, slicetype, sliceobj) + slicing.guard_invalid_slice(context, builder, slicetype, slice) + slicing.fix_slice(builder, slice, length) + return slice._getvalue() + + return sig, codegen + + +@intrinsic +def _slice_span(typingctx, sliceobj): + """Compute the span from the given slice object.""" + sig = types.intp(sliceobj) + + def codegen(context, builder, sig, args): + [slicetype] = sig.args + [sliceobj] = args + slice = context.make_helper(builder, slicetype, sliceobj) + result_size = slicing.get_slice_length(builder, slice) + return result_size + + return sig, codegen + + +@register_jitable(_nrt=False) +def _strncpy(dst, dst_offset, src, src_offset, n): + if src._kind == dst._kind: + byte_width = _kind_to_byte_width(src._kind) + src_byte_offset = byte_width * src_offset + dst_byte_offset = byte_width * dst_offset + nbytes = n * byte_width + memcpy_region( + dst._data, + dst_byte_offset, + src._data, + src_byte_offset, + nbytes, + align=1, + ) + else: + for i in range(n): + _set_code_point( + dst, dst_offset + i, _get_code_point(src, src_offset + i) + ) + + +@intrinsic +def _get_str_slice_view(typingctx, src_t, start_t, length_t): + """Create a slice of a unicode string using a view of its data to avoid + extra allocation. 
+ """ + assert src_t == types.unicode_type + + def codegen(context, builder, sig, args): + src, start, length = args + in_str = cgutils.create_struct_proxy(types.unicode_type)( + context, builder, value=src + ) + view_str = cgutils.create_struct_proxy(types.unicode_type)( + context, builder + ) + view_str.meminfo = in_str.meminfo + view_str.kind = in_str.kind + view_str.is_ascii = in_str.is_ascii + view_str.length = length + # hash value -1 to indicate "need to compute hash" + view_str.hash = context.get_constant(_Py_hash_t, -1) + # get a pointer to start of slice data + bw_typ = context.typing_context.resolve_value_type(_kind_to_byte_width) + bw_sig = bw_typ.get_call_type( + context.typing_context, (types.int32,), {} + ) + bw_impl = context.get_function(bw_typ, bw_sig) + byte_width = bw_impl(builder, (in_str.kind,)) + offset = builder.mul(start, byte_width) + view_str.data = builder.gep(in_str.data, [offset]) + # Set parent pyobject to NULL + view_str.parent = cgutils.get_null_value(view_str.parent.type) + # incref original string + if context.enable_nrt: + context.nrt.incref(builder, sig.args[0], src) + return view_str._getvalue() + + sig = types.unicode_type(types.unicode_type, types.intp, types.intp) + return sig, codegen + + +@overload(operator.getitem) +def unicode_getitem(s, idx): + if isinstance(s, types.UnicodeType): + if isinstance(idx, types.Integer): + + def getitem_char(s, idx): + idx = normalize_str_idx(idx, len(s)) + cp = _get_code_point(s, idx) + kind = _codepoint_to_kind(cp) + if kind == s._kind: + return _get_str_slice_view(s, idx, 1) + else: + is_ascii = _codepoint_is_ascii(cp) + ret = _empty_string(kind, 1, is_ascii) + _set_code_point(ret, 0, cp) + return ret + + return getitem_char + elif isinstance(idx, types.SliceType): + + def getitem_slice(s, idx): + slice_idx = _normalize_slice(idx, len(s)) + span = _slice_span(slice_idx) + + cp = _get_code_point(s, slice_idx.start) + kind = _codepoint_to_kind(cp) + is_ascii = _codepoint_is_ascii(cp) + + # Check slice to see if it's homogeneous in kind + for i in range( + slice_idx.start + slice_idx.step, + slice_idx.stop, + slice_idx.step, + ): + cp = _get_code_point(s, i) + is_ascii &= _codepoint_is_ascii(cp) + new_kind = _codepoint_to_kind(cp) + if kind != new_kind: + kind = _pick_kind(kind, new_kind) + # TODO: it might be possible to break here if the kind + # is PY_UNICODE_4BYTE_KIND but there are potentially + # strings coming from other internal functions that are + # this wide and also actually ASCII (i.e. kind is larger + # than actually required for storing the code point), so + # it's necessary to continue. + + if slice_idx.step == 1 and kind == s._kind: + # Can return a view, the slice has the same kind as the + # string itself and it's a stride slice 1. 
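+                    # No copy is made on this path: _get_str_slice_view
+                    # (defined above) points the result at the parent's data
+                    # buffer and, when NRT is enabled, increfs the source
+                    # string so the borrowed data stays alive.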
+ return _get_str_slice_view(s, slice_idx.start, span) + else: + # It's heterogeneous in kind OR stride != 1 + ret = _empty_string(kind, span, is_ascii) + cur = slice_idx.start + for i in range(span): + _set_code_point(ret, i, _get_code_point(s, cur)) + cur += slice_idx.step + return ret + + return getitem_slice + + +# ------------------------------------------------------------------------------ +# String operations +# ------------------------------------------------------------------------------ + + +@overload(operator.add) +@overload(operator.iadd) +def unicode_concat(a, b): + if isinstance(a, types.UnicodeType) and isinstance(b, types.UnicodeType): + + def concat_impl(a, b): + new_length = a._length + b._length + new_kind = _pick_kind(a._kind, b._kind) + new_ascii = _pick_ascii(a._is_ascii, b._is_ascii) + result = _empty_string(new_kind, new_length, new_ascii) + for i in range(len(a)): + _set_code_point(result, i, _get_code_point(a, i)) + for j in range(len(b)): + _set_code_point(result, len(a) + j, _get_code_point(b, j)) + return result + + return concat_impl + + if isinstance(a, types.UnicodeType) and isinstance(b, types.UnicodeCharSeq): + + def concat_impl(a, b): + return a + str(b) + + return concat_impl + + +@register_jitable +def _repeat_impl(str_arg, mult_arg): + if str_arg == "" or mult_arg < 1: + return "" + elif mult_arg == 1: + return str_arg + else: + new_length = str_arg._length * mult_arg + new_kind = str_arg._kind + result = _empty_string(new_kind, new_length, str_arg._is_ascii) + # make initial copy into result + len_a = len(str_arg) + _strncpy(result, 0, str_arg, 0, len_a) + # loop through powers of 2 for efficient copying + copy_size = len_a + while 2 * copy_size <= new_length: + _strncpy(result, copy_size, result, 0, copy_size) + copy_size *= 2 + + if not 2 * copy_size == new_length: + # if copy_size not an exact multiple it then needs + # to complete the rest of the copies + rest = new_length - copy_size + _strncpy(result, copy_size, result, copy_size - rest, rest) + return result + + +@overload(operator.mul) +def unicode_repeat(a, b): + if isinstance(a, types.UnicodeType) and isinstance(b, types.Integer): + + def wrap(a, b): + return _repeat_impl(a, b) + + return wrap + elif isinstance(a, types.Integer) and isinstance(b, types.UnicodeType): + + def wrap(a, b): + return _repeat_impl(b, a) + + return wrap + + +@overload(operator.not_) +def unicode_not(a): + if isinstance(a, types.UnicodeType): + + def impl(a): + return len(a) == 0 + + return impl + + +@overload_method(types.UnicodeType, "replace") +def unicode_replace(s, old_str, new_str, count=-1): + thety = count + if isinstance(count, types.Omitted): + thety = count.value + elif isinstance(count, types.Optional): + thety = count.type + + if not isinstance(thety, (int, types.Integer)): + raise TypingError( + "Unsupported parameters. The parameters " + "must be Integer. Given count: {}".format(count) + ) + + if not isinstance(old_str, (types.UnicodeType, types.NoneType)): + raise TypingError( + "The object must be a UnicodeType. Given: {}".format(old_str) + ) + + if not isinstance(new_str, types.UnicodeType): + raise TypingError( + "The object must be a UnicodeType. 
Given: {}".format(new_str) + ) + + def impl(s, old_str, new_str, count=-1): + if count == 0: + return s + if old_str == "": + schars = list(s) + if count == -1: + return new_str + new_str.join(schars) + new_str + split_result = [new_str] + min_count = min(len(schars), count) + for i in range(min_count): + split_result.append(schars[i]) + if i + 1 != min_count: + split_result.append(new_str) + else: + split_result.append("".join(schars[(i + 1) :])) + if count > len(schars): + split_result.append(new_str) + return "".join(split_result) + schars = s.split(old_str, count) + result = new_str.join(schars) + return result + + return impl + + +# ------------------------------------------------------------------------------ +# String `is*()` methods +# ------------------------------------------------------------------------------ + + +# generates isalpha/isalnum +def gen_isAlX(ascii_func, unicode_func): + def unicode_isAlX(data): + def impl(data): + length = len(data) + if length == 0: + return False + + if length == 1: + code_point = _get_code_point(data, 0) + if data._is_ascii: + return ascii_func(code_point) + else: + return unicode_func(code_point) + + if data._is_ascii: + for i in range(length): + code_point = _get_code_point(data, i) + if not ascii_func(code_point): + return False + + for i in range(length): + code_point = _get_code_point(data, i) + if not unicode_func(code_point): + return False + + return True + + return impl + + return unicode_isAlX + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11928-L11964 # noqa: E501 +overload_method(types.UnicodeType, "isalpha")( + gen_isAlX(_Py_ISALPHA, _PyUnicode_IsAlpha) +) + +_unicode_is_alnum = register_jitable( + lambda x: (_PyUnicode_IsNumeric(x) or _PyUnicode_IsAlpha(x)) +) + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11975-L12006 # noqa: E501 +overload_method(types.UnicodeType, "isalnum")( + gen_isAlX(_Py_ISALNUM, _unicode_is_alnum) +) + + +def _is_upper(is_lower, is_upper, is_title): + # impl is an approximate translation of: + # https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11794-L11827 # noqa: E501 + # mixed with: + # https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/bytes_methods.c#L218-L242 # noqa: E501 + def impl(a): + l = len(a) + if l == 1: + return is_upper(_get_code_point(a, 0)) != 0 + if l == 0: + return False + cased = False + for idx in range(l): + code_point = _get_code_point(a, idx) + if is_lower(code_point) or is_title(code_point): + return False + elif not cased and is_upper(code_point): + cased = True + return cased + + return impl + + +_always_false = register_jitable(lambda x: False) +_ascii_is_upper = register_jitable( + _is_upper(_Py_ISLOWER, _Py_ISUPPER, _always_false) +) +_unicode_is_upper = register_jitable( + _is_upper( + _PyUnicode_IsLowercase, _PyUnicode_IsUppercase, _PyUnicode_IsTitlecase + ) +) + + +@overload_method(types.UnicodeType, "isupper") +def unicode_isupper(a): + """ + Implements .isupper() + """ + + def impl(a): + if a._is_ascii: + return _ascii_is_upper(a) + else: + return _unicode_is_upper(a) + + return impl + + +@overload_method(types.UnicodeType, "isascii") +def unicode_isascii(data): + """Implements UnicodeType.isascii()""" + + def impl(data): + return data._is_ascii + + return impl + + +@overload_method(types.UnicodeType, "istitle") +def unicode_istitle(data): + """ + 
Implements UnicodeType.istitle() + The algorithm is an approximate translation from CPython: + https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11829-L11885 # noqa: E501 + """ + + def impl(data): + length = len(data) + if length == 1: + char = _get_code_point(data, 0) + return _PyUnicode_IsUppercase(char) or _PyUnicode_IsTitlecase(char) + + if length == 0: + return False + + cased = False + previous_is_cased = False + for idx in range(length): + char = _get_code_point(data, idx) + if _PyUnicode_IsUppercase(char) or _PyUnicode_IsTitlecase(char): + if previous_is_cased: + return False + previous_is_cased = True + cased = True + elif _PyUnicode_IsLowercase(char): + if not previous_is_cased: + return False + previous_is_cased = True + cased = True + else: + previous_is_cased = False + + return cased + + return impl + + +@overload_method(types.UnicodeType, "islower") +def unicode_islower(data): + """ + impl is an approximate translation of: + https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L11900-L11933 # noqa: E501 + mixed with: + https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/bytes_methods.c#L131-L156 # noqa: E501 + """ + + def impl(data): + length = len(data) + if length == 1: + return _PyUnicode_IsLowercase(_get_code_point(data, 0)) + if length == 0: + return False + + cased = False + for idx in range(length): + cp = _get_code_point(data, idx) + if _PyUnicode_IsUppercase(cp) or _PyUnicode_IsTitlecase(cp): + return False + elif not cased and _PyUnicode_IsLowercase(cp): + cased = True + return cased + + return impl + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12126-L12161 # noqa: E501 +@overload_method(types.UnicodeType, "isidentifier") +def unicode_isidentifier(data): + """Implements UnicodeType.isidentifier()""" + + def impl(data): + length = len(data) + if length == 0: + return False + + first_cp = _get_code_point(data, 0) + if not _PyUnicode_IsXidStart(first_cp) and first_cp != 0x5F: + return False + + for i in range(1, length): + code_point = _get_code_point(data, i) + if not _PyUnicode_IsXidContinue(code_point): + return False + + return True + + return impl + + +# generator for simple unicode "isX" methods +def gen_isX(_PyUnicode_IS_func, empty_is_false=True): + def unicode_isX(data): + def impl(data): + length = len(data) + if length == 1: + return _PyUnicode_IS_func(_get_code_point(data, 0)) + + if empty_is_false and length == 0: + return False + + for i in range(length): + code_point = _get_code_point(data, i) + if not _PyUnicode_IS_func(code_point): + return False + + return True + + return impl + + return unicode_isX + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L11896-L11925 # noqa: E501 +overload_method(types.UnicodeType, "isspace")(gen_isX(_PyUnicode_IsSpace)) + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12096-L12124 # noqa: E501 +overload_method(types.UnicodeType, "isnumeric")(gen_isX(_PyUnicode_IsNumeric)) + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12056-L12085 # noqa: E501 +overload_method(types.UnicodeType, "isdigit")(gen_isX(_PyUnicode_IsDigit)) + +# 
https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12017-L12045 # noqa: E501 +overload_method(types.UnicodeType, "isdecimal")( + gen_isX(_PyUnicode_IsDecimalDigit) +) + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L12188-L12213 # noqa: E501 +overload_method(types.UnicodeType, "isprintable")( + gen_isX(_PyUnicode_IsPrintable, False) +) + +# ------------------------------------------------------------------------------ +# String methods that apply a transformation to the characters themselves +# ------------------------------------------------------------------------------ + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L9863-L9908 # noqa: E501 +def case_operation(ascii_func, unicode_func): + """Generate common case operation performer.""" + + def impl(data): + length = len(data) + if length == 0: + return _empty_string(data._kind, length, data._is_ascii) + + if data._is_ascii: + res = _empty_string(data._kind, length, 1) + ascii_func(data, res) + return res + + # https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L9863-L9908 # noqa: E501 + tmp = _empty_string(PY_UNICODE_4BYTE_KIND, 3 * length, data._is_ascii) + # maxchar should be inside of a list to be pass as argument by reference + maxchars = [0] + newlength = unicode_func(data, length, tmp, maxchars) + maxchar = maxchars[0] + newkind = _codepoint_to_kind(maxchar) + res = _empty_string(newkind, newlength, _codepoint_is_ascii(maxchar)) + for i in range(newlength): + _set_code_point(res, i, _get_code_point(tmp, i)) + + return res + + return impl + + +# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L9856-L9883 # noqa: E501 +@register_jitable +def _handle_capital_sigma(data, length, idx): + """This is a translation of the function that handles the capital sigma.""" + c = 0 + j = idx - 1 + while j >= 0: + c = _get_code_point(data, j) + if not _PyUnicode_IsCaseIgnorable(c): + break + j -= 1 + final_sigma = j >= 0 and _PyUnicode_IsCased(c) + if final_sigma: + j = idx + 1 + while j < length: + c = _get_code_point(data, j) + if not _PyUnicode_IsCaseIgnorable(c): + break + j += 1 + final_sigma = j == length or (not _PyUnicode_IsCased(c)) + + return 0x3C2 if final_sigma else 0x3C3 + + +# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L9885-L9895 # noqa: E501 +@register_jitable +def _lower_ucs4(code_point, data, length, idx, mapped): + """This is a translation of the function that lowers a character.""" + if code_point == 0x3A3: + mapped[0] = _handle_capital_sigma(data, length, idx) + return 1 + return _PyUnicode_ToLowerFull(code_point, mapped) + + +# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L9946-L9965 # noqa: E501 +def _gen_unicode_upper_or_lower(lower): + def _do_upper_or_lower(data, length, res, maxchars): + k = 0 + for idx in range(length): + mapped = np.zeros(3, dtype=_Py_UCS4) + code_point = _get_code_point(data, idx) + if lower: + n_res = _lower_ucs4(code_point, data, length, idx, mapped) + else: + # might be needed if call _do_upper_or_lower in unicode_upper + n_res = _PyUnicode_ToUpperFull(code_point, mapped) + for m in mapped[:n_res]: + maxchars[0] = max(maxchars[0], m) + _set_code_point(res, k, m) + k += 1 + return k + + return 
_do_upper_or_lower + + +_unicode_upper = register_jitable(_gen_unicode_upper_or_lower(False)) +_unicode_lower = register_jitable(_gen_unicode_upper_or_lower(True)) + + +def _gen_ascii_upper_or_lower(func): + def _ascii_upper_or_lower(data, res): + for idx in range(len(data)): + code_point = _get_code_point(data, idx) + _set_code_point(res, idx, func(code_point)) + + return _ascii_upper_or_lower + + +_ascii_upper = register_jitable(_gen_ascii_upper_or_lower(_Py_TOUPPER)) +_ascii_lower = register_jitable(_gen_ascii_upper_or_lower(_Py_TOLOWER)) + + +@overload_method(types.UnicodeType, "lower") +def unicode_lower(data): + """Implements .lower()""" + return case_operation(_ascii_lower, _unicode_lower) + + +@overload_method(types.UnicodeType, "upper") +def unicode_upper(data): + """Implements .upper()""" + return case_operation(_ascii_upper, _unicode_upper) + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L9819-L9834 # noqa: E501 +@register_jitable +def _unicode_casefold(data, length, res, maxchars): + k = 0 + mapped = np.zeros(3, dtype=_Py_UCS4) + for idx in range(length): + mapped.fill(0) + code_point = _get_code_point(data, idx) + n_res = _PyUnicode_ToFoldedFull(code_point, mapped) + for m in mapped[:n_res]: + maxchar = maxchars[0] + maxchars[0] = max(maxchar, m) + _set_code_point(res, k, m) + k += 1 + + return k + + +@register_jitable +def _ascii_casefold(data, res): + for idx in range(len(data)): + code_point = _get_code_point(data, idx) + _set_code_point(res, idx, _Py_TOLOWER(code_point)) + + +@overload_method(types.UnicodeType, "casefold") +def unicode_casefold(data): + """Implements str.casefold()""" + return case_operation(_ascii_casefold, _unicode_casefold) + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L9737-L9759 # noqa: E501 +@register_jitable +def _unicode_capitalize(data, length, res, maxchars): + k = 0 + maxchar = 0 + mapped = np.zeros(3, dtype=_Py_UCS4) + code_point = _get_code_point(data, 0) + + n_res = _PyUnicode_ToTitleFull(code_point, mapped) + + for m in mapped[:n_res]: + maxchar = max(maxchar, m) + _set_code_point(res, k, m) + k += 1 + for idx in range(1, length): + mapped.fill(0) + code_point = _get_code_point(data, idx) + n_res = _lower_ucs4(code_point, data, length, idx, mapped) + for m in mapped[:n_res]: + maxchar = max(maxchar, m) + _set_code_point(res, k, m) + k += 1 + maxchars[0] = maxchar + return k + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/bytes_methods.c#L361-L382 # noqa: E501 +@register_jitable +def _ascii_capitalize(data, res): + code_point = _get_code_point(data, 0) + _set_code_point(res, 0, _Py_TOUPPER(code_point)) + for idx in range(1, len(data)): + code_point = _get_code_point(data, idx) + _set_code_point(res, idx, _Py_TOLOWER(code_point)) + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L10765-L10774 # noqa: E501 +@overload_method(types.UnicodeType, "capitalize") +def unicode_capitalize(data): + return case_operation(_ascii_capitalize, _unicode_capitalize) + + +# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L9996-L10021 # noqa: E501 +@register_jitable +def _unicode_title(data, length, res, maxchars): + """This is a translation of the function that titles a unicode string.""" + k = 0 + previous_cased = False + mapped = np.empty(3, dtype=_Py_UCS4) + for idx in 
range(length): + mapped.fill(0) + code_point = _get_code_point(data, idx) + if previous_cased: + n_res = _lower_ucs4(code_point, data, length, idx, mapped) + else: + n_res = _PyUnicode_ToTitleFull(_Py_UCS4(code_point), mapped) + for m in mapped[:n_res]: + (maxchar,) = maxchars + maxchars[0] = max(maxchar, m) + _set_code_point(res, k, m) + k += 1 + previous_cased = _PyUnicode_IsCased(_Py_UCS4(code_point)) + return k + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/bytes_methods.c#L332-L352 # noqa: E501 +@register_jitable +def _ascii_title(data, res): + """Does .title() on an ASCII string""" + previous_is_cased = False + for idx in range(len(data)): + code_point = _get_code_point(data, idx) + if _Py_ISLOWER(code_point): + if not previous_is_cased: + code_point = _Py_TOUPPER(code_point) + previous_is_cased = True + elif _Py_ISUPPER(code_point): + if previous_is_cased: + code_point = _Py_TOLOWER(code_point) + previous_is_cased = True + else: + previous_is_cased = False + _set_code_point(res, idx, code_point) + + +# https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodeobject.c#L10023-L10069 # noqa: E501 +@overload_method(types.UnicodeType, "title") +def unicode_title(data): + """Implements str.title()""" + # https://docs.python.org/3/library/stdtypes.html#str.title + return case_operation(_ascii_title, _unicode_title) + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/bytes_methods.c#L391-L408 # noqa: E501 +@register_jitable +def _ascii_swapcase(data, res): + for idx in range(len(data)): + code_point = _get_code_point(data, idx) + if _Py_ISUPPER(code_point): + code_point = _Py_TOLOWER(code_point) + elif _Py_ISLOWER(code_point): + code_point = _Py_TOUPPER(code_point) + _set_code_point(res, idx, code_point) + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L9761-L9784 # noqa: E501 +@register_jitable +def _unicode_swapcase(data, length, res, maxchars): + k = 0 + maxchar = 0 + mapped = np.empty(3, dtype=_Py_UCS4) + for idx in range(length): + mapped.fill(0) + code_point = _get_code_point(data, idx) + if _PyUnicode_IsUppercase(code_point): + n_res = _lower_ucs4(code_point, data, length, idx, mapped) + elif _PyUnicode_IsLowercase(code_point): + n_res = _PyUnicode_ToUpperFull(code_point, mapped) + else: + n_res = 1 + mapped[0] = code_point + for m in mapped[:n_res]: + maxchar = max(maxchar, m) + _set_code_point(res, k, m) + k += 1 + maxchars[0] = maxchar + return k + + +@overload_method(types.UnicodeType, "swapcase") +def unicode_swapcase(data): + return case_operation(_ascii_swapcase, _unicode_swapcase) + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/bltinmodule.c#L1781-L1824 # noqa: E501 +@overload(ord) +def ol_ord(c): + if isinstance(c, types.UnicodeType): + + def impl(c): + lc = len(c) + if lc != 1: + # CPython does TypeError + raise TypeError("ord() expected a character") + return _get_code_point(c, 0) + + return impl + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L2005-L2028 # noqa: E501 +# This looks a bit different to the cpython implementation but, with the +# exception of a latin1 fast path is logically the same. It finds the "kind" of +# the codepoint `ch`, creates a length 1 string of that kind and then injects +# the code point into the zero position of that string. 
Cpython does similar but +# branches for each kind (this is encapsulated in Numba's _set_code_point). +@register_jitable +def _unicode_char(ch): + assert ch <= _MAX_UNICODE + kind = _codepoint_to_kind(ch) + ret = _empty_string(kind, 1, kind == PY_UNICODE_1BYTE_KIND) + _set_code_point(ret, 0, ch) + return ret + + +_out_of_range_msg = "chr() arg not in range(0x%hx)" % _MAX_UNICODE + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodeobject.c#L3045-L3055 # noqa: E501 +@register_jitable +def _PyUnicode_FromOrdinal(ordinal): + if ordinal < 0 or ordinal > _MAX_UNICODE: + raise ValueError(_out_of_range_msg) + + return _unicode_char(_Py_UCS4(ordinal)) + + +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/bltinmodule.c#L715-L720 # noqa: E501 +@overload(chr) +def ol_chr(i): + if isinstance(i, types.Integer): + + def impl(i): + return _PyUnicode_FromOrdinal(i) + + return impl + + +@overload_method(types.UnicodeType, "__str__") +def unicode_str(s): + return lambda s: s + + +@overload_method(types.UnicodeType, "__repr__") +def unicode_repr(s): + # Can't use f-string as the impl ends up calling str and then repr, which + # then recurses somewhere in imports. + return lambda s: "'" + s + "'" + + +@overload_method(types.Integer, "__str__") +def integer_str(n): + ten = n(10) + + def impl(n): + flag = False + if n < 0: + n = -n + flag = True + if n == 0: + return "0" + length = flag + 1 + int(np.floor(np.log10(n))) + kind = PY_UNICODE_1BYTE_KIND + char_width = _kind_to_byte_width(kind) + s = _malloc_string(kind, char_width, length, True) + if flag: + _set_code_point(s, 0, ord("-")) + idx = length - 1 + while n > 0: + n, digit = divmod(n, ten) + c = ord("0") + digit + _set_code_point(s, idx, c) + idx -= 1 + return s + + return impl + + +@overload_method(types.Integer, "__repr__") +def integer_repr(n): + return lambda n: n.__str__() + + +@overload_method(types.Boolean, "__repr__") +@overload_method(types.Boolean, "__str__") +def boolean_str(b): + return lambda b: "True" if b else "False" + + +# ------------------------------------------------------------------------------ +# iteration +# ------------------------------------------------------------------------------ + + +@lower("getiter", types.UnicodeType) +def getiter_unicode(context, builder, sig, args): + [ty] = sig.args + [data] = args + + iterobj = context.make_helper(builder, sig.return_type) + + # set the index to zero + zero = context.get_constant(types.uintp, 0) + indexptr = cgutils.alloca_once_value(builder, zero) + + iterobj.index = indexptr + + # wire in the unicode type data + iterobj.data = data + + # incref as needed + if context.enable_nrt: + context.nrt.incref(builder, ty, data) + + res = iterobj._getvalue() + return impl_ret_new_ref(context, builder, sig.return_type, res) + + +@lower("iternext", types.UnicodeIteratorType) +# a new ref counted object is put into result._yield so set the new_ref to True! 
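+# iternext_impl(RefType.NEW) supplies the extra `result` argument used below
+# (result.set_valid() / result.yield_()) and records that the yielded item is
+# already a new reference, per the note above.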
+@iternext_impl(RefType.NEW) +def iternext_unicode(context, builder, sig, args, result): + [iterty] = sig.args + [iter] = args + + tyctx = context.typing_context + + # get ref to unicode.__getitem__ + fnty = tyctx.resolve_value_type(operator.getitem) + getitem_sig = fnty.get_call_type( + tyctx, (types.unicode_type, types.uintp), {} + ) + getitem_impl = context.get_function(fnty, getitem_sig) + + # get ref to unicode.__len__ + fnty = tyctx.resolve_value_type(len) + len_sig = fnty.get_call_type(tyctx, (types.unicode_type,), {}) + len_impl = context.get_function(fnty, len_sig) + + # grab unicode iterator struct + iterobj = context.make_helper(builder, iterty, value=iter) + + # find the length of the string + strlen = len_impl(builder, (iterobj.data,)) + + # find the current index + index = builder.load(iterobj.index) + + # see if the index is in range + is_valid = builder.icmp_unsigned("<", index, strlen) + result.set_valid(is_valid) + + with builder.if_then(is_valid): + # return value at index + gotitem = getitem_impl( + builder, + ( + iterobj.data, + index, + ), + ) + result.yield_(gotitem) + + # bump index for next cycle + nindex = cgutils.increment_index(builder, index) + builder.store(nindex, iterobj.index) diff --git a/numba_cuda/numba/cuda/cpython/unicode_support.py b/numba_cuda/numba/cuda/cpython/unicode_support.py new file mode 100644 index 000000000..710e5d611 --- /dev/null +++ b/numba_cuda/numba/cuda/cpython/unicode_support.py @@ -0,0 +1,1597 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-2-Clause + +""" +This module contains support functions for more advanced unicode operations. +This is not a public API and is for Numba internal use only. Most of the +functions are relatively straightforward translations of the functions with the +same name in CPython. +""" + +from collections import namedtuple +from enum import IntEnum + +import llvmlite.ir +import numpy as np + +from numba.core import types +from numba.cuda import cgutils +from numba.core.imputils import impl_ret_untracked + +from numba.core.extending import overload, register_jitable +from numba.cuda.extending import intrinsic +from numba.core.errors import TypingError + +# This is equivalent to the struct `_PyUnicode_TypeRecord defined in CPython's +# Objects/unicodectype.c +typerecord = namedtuple("typerecord", "upper lower title decimal digit flags") + +# The Py_UCS4 type from CPython: +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/unicodeobject.h#L112 # noqa: E501 +_Py_UCS4 = types.uint32 + +# ------------------------------------------------------------------------------ +# Start code related to/from CPython's unicodectype impl +# +# NOTE: the original source at: +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c # noqa: E501 +# contains this statement: +# +# /* +# Unicode character type helpers. +# +# Written by Marc-Andre Lemburg (mal@lemburg.com). +# Modified for Python 2.0 by Fredrik Lundh (fredrik@pythonware.com) +# +# Copyright (c) Corporation for National Research Initiatives. 
+# +# */ + + +# This enum contains the values defined in CPython's Objects/unicodectype.c that +# provide masks for use against the various members of the typerecord +# +# See: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L13-L27 # noqa: E501 +# + + +_Py_TAB = 0x9 +_Py_LINEFEED = 0xA +_Py_CARRIAGE_RETURN = 0xD +_Py_SPACE = 0x20 + + +class _PyUnicode_TyperecordMasks(IntEnum): + ALPHA_MASK = 0x01 + DECIMAL_MASK = 0x02 + DIGIT_MASK = 0x04 + LOWER_MASK = 0x08 + LINEBREAK_MASK = 0x10 + SPACE_MASK = 0x20 + TITLE_MASK = 0x40 + UPPER_MASK = 0x80 + XID_START_MASK = 0x100 + XID_CONTINUE_MASK = 0x200 + PRINTABLE_MASK = 0x400 + NUMERIC_MASK = 0x800 + CASE_IGNORABLE_MASK = 0x1000 + CASED_MASK = 0x2000 + EXTENDED_CASE_MASK = 0x4000 + + +def _PyUnicode_gettyperecord(a): + raise RuntimeError("Calling the Python definition is invalid") + + +@intrinsic +def _gettyperecord_impl(typingctx, codepoint): + """ + Provides the binding to numba_gettyperecord, returns a `typerecord` + namedtuple of properties from the codepoint. + """ + if not isinstance(codepoint, types.Integer): + raise TypingError("codepoint must be an integer") + + def details(context, builder, signature, args): + ll_void = context.get_value_type(types.void) + ll_Py_UCS4 = context.get_value_type(_Py_UCS4) + ll_intc = context.get_value_type(types.intc) + ll_intc_ptr = ll_intc.as_pointer() + ll_uchar = context.get_value_type(types.uchar) + ll_uchar_ptr = ll_uchar.as_pointer() + ll_ushort = context.get_value_type(types.ushort) + ll_ushort_ptr = ll_ushort.as_pointer() + fnty = llvmlite.ir.FunctionType( + ll_void, + [ + ll_Py_UCS4, # code + ll_intc_ptr, # upper + ll_intc_ptr, # lower + ll_intc_ptr, # title + ll_uchar_ptr, # decimal + ll_uchar_ptr, # digit + ll_ushort_ptr, # flags + ], + ) + fn = cgutils.get_or_insert_function( + builder.module, fnty, name="numba_gettyperecord" + ) + upper = cgutils.alloca_once(builder, ll_intc, name="upper") + lower = cgutils.alloca_once(builder, ll_intc, name="lower") + title = cgutils.alloca_once(builder, ll_intc, name="title") + decimal = cgutils.alloca_once(builder, ll_uchar, name="decimal") + digit = cgutils.alloca_once(builder, ll_uchar, name="digit") + flags = cgutils.alloca_once(builder, ll_ushort, name="flags") + + byref = [upper, lower, title, decimal, digit, flags] + builder.call(fn, [args[0]] + byref) + buf = [] + for x in byref: + buf.append(builder.load(x)) + + res = context.make_tuple(builder, signature.return_type, tuple(buf)) + return impl_ret_untracked(context, builder, signature.return_type, res) + + tupty = types.NamedTuple( + [ + types.intc, + types.intc, + types.intc, + types.uchar, + types.uchar, + types.ushort, + ], + typerecord, + ) + sig = tupty(_Py_UCS4) + return sig, details + + +@overload(_PyUnicode_gettyperecord) +def gettyperecord_impl(a): + """ + Provides a _PyUnicode_gettyperecord binding, for convenience it will accept + single character strings and code points. 
+ """ + if isinstance(a, types.UnicodeType): + from numba.cpython.unicode import _get_code_point + + def impl(a): + if len(a) > 1: + msg = "gettyperecord takes a single unicode character" + raise ValueError(msg) + code_point = _get_code_point(a, 0) + data = _gettyperecord_impl(_Py_UCS4(code_point)) + return data + + return impl + if isinstance(a, types.Integer): + return lambda a: _gettyperecord_impl(_Py_UCS4(a)) + + +# whilst it's possible to grab the _PyUnicode_ExtendedCase symbol as it's global +# it is safer to use a defined api: +@intrinsic +def _PyUnicode_ExtendedCase(typingctx, index): + """ + Accessor function for the _PyUnicode_ExtendedCase array, binds to + numba_get_PyUnicode_ExtendedCase which wraps the array and does the lookup + """ + if not isinstance(index, types.Integer): + raise TypingError("Expected an index") + + def details(context, builder, signature, args): + ll_Py_UCS4 = context.get_value_type(_Py_UCS4) + ll_intc = context.get_value_type(types.intc) + fnty = llvmlite.ir.FunctionType(ll_Py_UCS4, [ll_intc]) + fn = cgutils.get_or_insert_function( + builder.module, fnty, name="numba_get_PyUnicode_ExtendedCase" + ) + return builder.call(fn, [args[0]]) + + sig = _Py_UCS4(types.intc) + return sig, details + + +# The following functions are replications of the functions with the same name +# in CPython's Objects/unicodectype.c + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L64-L71 # noqa: E501 +@register_jitable +def _PyUnicode_ToTitlecase(ch): + ctype = _PyUnicode_gettyperecord(ch) + if ctype.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK: + return _PyUnicode_ExtendedCase(ctype.title & 0xFFFF) + return ch + ctype.title + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L76-L81 # noqa: E501 +@register_jitable +def _PyUnicode_IsTitlecase(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.TITLE_MASK != 0 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L86-L91 # noqa: E501 +@register_jitable +def _PyUnicode_IsXidStart(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.XID_START_MASK != 0 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L96-L101 # noqa: E501 +@register_jitable +def _PyUnicode_IsXidContinue(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.XID_CONTINUE_MASK != 0 + + +@register_jitable +def _PyUnicode_ToDecimalDigit(ch): + ctype = _PyUnicode_gettyperecord(ch) + if ctype.flags & _PyUnicode_TyperecordMasks.DECIMAL_MASK: + return ctype.decimal + return -1 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L123-L1128 # noqa: E501 +@register_jitable +def _PyUnicode_ToDigit(ch): + ctype = _PyUnicode_gettyperecord(ch) + if ctype.flags & _PyUnicode_TyperecordMasks.DIGIT_MASK: + return ctype.digit + return -1 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L140-L145 # noqa: E501 +@register_jitable +def _PyUnicode_IsNumeric(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.NUMERIC_MASK != 0 + + +# From: 
https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L160-L165 # noqa: E501 +@register_jitable +def _PyUnicode_IsPrintable(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.PRINTABLE_MASK != 0 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L170-L175 # noqa: E501 +@register_jitable +def _PyUnicode_IsLowercase(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.LOWER_MASK != 0 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L180-L185 # noqa: E501 +@register_jitable +def _PyUnicode_IsUppercase(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.UPPER_MASK != 0 + + +@register_jitable +def _PyUnicode_IsLineBreak(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.LINEBREAK_MASK != 0 + + +@register_jitable +def _PyUnicode_ToUppercase(ch): + raise NotImplementedError + + +@register_jitable +def _PyUnicode_ToLowercase(ch): + raise NotImplementedError + + +# From: https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodectype.c#L211-L225 # noqa: E501 +@register_jitable +def _PyUnicode_ToLowerFull(ch, res): + ctype = _PyUnicode_gettyperecord(ch) + if ctype.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK: + index = ctype.lower & 0xFFFF + n = ctype.lower >> 24 + for i in range(n): + res[i] = _PyUnicode_ExtendedCase(index + i) + return n + res[0] = ch + ctype.lower + return 1 + + +# From: https://github.com/python/cpython/blob/201c8f79450628241574fba940e08107178dc3a5/Objects/unicodectype.c#L227-L241 # noqa: E501 +@register_jitable +def _PyUnicode_ToTitleFull(ch, res): + ctype = _PyUnicode_gettyperecord(ch) + if ctype.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK: + index = ctype.title & 0xFFFF + n = ctype.title >> 24 + for i in range(n): + res[i] = _PyUnicode_ExtendedCase(index + i) + return n + res[0] = ch + ctype.title + return 1 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L243-L257 # noqa: E501 +@register_jitable +def _PyUnicode_ToUpperFull(ch, res): + ctype = _PyUnicode_gettyperecord(ch) + if ctype.flags & _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK: + index = ctype.upper & 0xFFFF + n = ctype.upper >> 24 + for i in range(n): + # Perhaps needed to use unicode._set_code_point() here + res[i] = _PyUnicode_ExtendedCase(index + i) + return n + res[0] = ch + ctype.upper + return 1 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L259-L272 # noqa: E501 +@register_jitable +def _PyUnicode_ToFoldedFull(ch, res): + ctype = _PyUnicode_gettyperecord(ch) + extended_case_mask = _PyUnicode_TyperecordMasks.EXTENDED_CASE_MASK + if ctype.flags & extended_case_mask and (ctype.lower >> 20) & 7: + index = (ctype.lower & 0xFFFF) + (ctype.lower >> 24) + n = (ctype.lower >> 20) & 7 + for i in range(n): + res[i] = _PyUnicode_ExtendedCase(index + i) + return n + return _PyUnicode_ToLowerFull(ch, res) + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L274-L279 # noqa: E501 +@register_jitable +def _PyUnicode_IsCased(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.CASED_MASK != 
0 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L281-L286 # noqa: E501 +@register_jitable +def _PyUnicode_IsCaseIgnorable(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.CASE_IGNORABLE_MASK != 0 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L123-L135 # noqa: E501 +@register_jitable +def _PyUnicode_IsDigit(ch): + if _PyUnicode_ToDigit(ch) < 0: + return 0 + return 1 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L106-L118 # noqa: E501 +@register_jitable +def _PyUnicode_IsDecimalDigit(ch): + if _PyUnicode_ToDecimalDigit(ch) < 0: + return 0 + return 1 + + +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Objects/unicodectype.c#L291-L296 # noqa: E501 +@register_jitable +def _PyUnicode_IsSpace(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.SPACE_MASK != 0 + + +@register_jitable +def _PyUnicode_IsAlpha(ch): + ctype = _PyUnicode_gettyperecord(ch) + return ctype.flags & _PyUnicode_TyperecordMasks.ALPHA_MASK != 0 + + +# End code related to/from CPython's unicodectype impl +# ------------------------------------------------------------------------------ + + +# ------------------------------------------------------------------------------ +# Start code related to/from CPython's pyctype + + +# From the definition in CPython's Include/pyctype.h +# From: https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L5-L11 # noqa: E501 +class _PY_CTF(IntEnum): + LOWER = 0x01 + UPPER = 0x02 + ALPHA = 0x01 | 0x02 + DIGIT = 0x04 + ALNUM = 0x01 | 0x02 | 0x04 + SPACE = 0x08 + XDIGIT = 0x10 + + +# From the definition in CPython's Python/pyctype.c +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/pyctype.c#L5 # noqa: E501 +_Py_ctype_table = np.array( + [ + 0, # 0x0 '\x00' + 0, # 0x1 '\x01' + 0, # 0x2 '\x02' + 0, # 0x3 '\x03' + 0, # 0x4 '\x04' + 0, # 0x5 '\x05' + 0, # 0x6 '\x06' + 0, # 0x7 '\x07' + 0, # 0x8 '\x08' + _PY_CTF.SPACE, # 0x9 '\t' + _PY_CTF.SPACE, # 0xa '\n' + _PY_CTF.SPACE, # 0xb '\v' + _PY_CTF.SPACE, # 0xc '\f' + _PY_CTF.SPACE, # 0xd '\r' + 0, # 0xe '\x0e' + 0, # 0xf '\x0f' + 0, # 0x10 '\x10' + 0, # 0x11 '\x11' + 0, # 0x12 '\x12' + 0, # 0x13 '\x13' + 0, # 0x14 '\x14' + 0, # 0x15 '\x15' + 0, # 0x16 '\x16' + 0, # 0x17 '\x17' + 0, # 0x18 '\x18' + 0, # 0x19 '\x19' + 0, # 0x1a '\x1a' + 0, # 0x1b '\x1b' + 0, # 0x1c '\x1c' + 0, # 0x1d '\x1d' + 0, # 0x1e '\x1e' + 0, # 0x1f '\x1f' + _PY_CTF.SPACE, # 0x20 ' ' + 0, # 0x21 '!' + 0, # 0x22 '"' + 0, # 0x23 '#' + 0, # 0x24 '$' + 0, # 0x25 '%' + 0, # 0x26 '&' + 0, # 0x27 "'" + 0, # 0x28 '(' + 0, # 0x29 ')' + 0, # 0x2a '*' + 0, # 0x2b '+' + 0, # 0x2c ',' + 0, # 0x2d '-' + 0, # 0x2e '.' + 0, # 0x2f '/' + _PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x30 '0' + _PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x31 '1' + _PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x32 '2' + _PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x33 '3' + _PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x34 '4' + _PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x35 '5' + _PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x36 '6' + _PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x37 '7' + _PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x38 '8' + _PY_CTF.DIGIT | _PY_CTF.XDIGIT, # 0x39 '9' + 0, # 0x3a ':' + 0, # 0x3b ';' + 0, # 0x3c '<' + 0, # 0x3d '=' + 0, # 0x3e '>' + 0, # 0x3f '?' 
+ 0, # 0x40 '@' + _PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x41 'A' + _PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x42 'B' + _PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x43 'C' + _PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x44 'D' + _PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x45 'E' + _PY_CTF.UPPER | _PY_CTF.XDIGIT, # 0x46 'F' + _PY_CTF.UPPER, # 0x47 'G' + _PY_CTF.UPPER, # 0x48 'H' + _PY_CTF.UPPER, # 0x49 'I' + _PY_CTF.UPPER, # 0x4a 'J' + _PY_CTF.UPPER, # 0x4b 'K' + _PY_CTF.UPPER, # 0x4c 'L' + _PY_CTF.UPPER, # 0x4d 'M' + _PY_CTF.UPPER, # 0x4e 'N' + _PY_CTF.UPPER, # 0x4f 'O' + _PY_CTF.UPPER, # 0x50 'P' + _PY_CTF.UPPER, # 0x51 'Q' + _PY_CTF.UPPER, # 0x52 'R' + _PY_CTF.UPPER, # 0x53 'S' + _PY_CTF.UPPER, # 0x54 'T' + _PY_CTF.UPPER, # 0x55 'U' + _PY_CTF.UPPER, # 0x56 'V' + _PY_CTF.UPPER, # 0x57 'W' + _PY_CTF.UPPER, # 0x58 'X' + _PY_CTF.UPPER, # 0x59 'Y' + _PY_CTF.UPPER, # 0x5a 'Z' + 0, # 0x5b '[' + 0, # 0x5c '\\' + 0, # 0x5d ']' + 0, # 0x5e '^' + 0, # 0x5f '_' + 0, # 0x60 '`' + _PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x61 'a' + _PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x62 'b' + _PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x63 'c' + _PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x64 'd' + _PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x65 'e' + _PY_CTF.LOWER | _PY_CTF.XDIGIT, # 0x66 'f' + _PY_CTF.LOWER, # 0x67 'g' + _PY_CTF.LOWER, # 0x68 'h' + _PY_CTF.LOWER, # 0x69 'i' + _PY_CTF.LOWER, # 0x6a 'j' + _PY_CTF.LOWER, # 0x6b 'k' + _PY_CTF.LOWER, # 0x6c 'l' + _PY_CTF.LOWER, # 0x6d 'm' + _PY_CTF.LOWER, # 0x6e 'n' + _PY_CTF.LOWER, # 0x6f 'o' + _PY_CTF.LOWER, # 0x70 'p' + _PY_CTF.LOWER, # 0x71 'q' + _PY_CTF.LOWER, # 0x72 'r' + _PY_CTF.LOWER, # 0x73 's' + _PY_CTF.LOWER, # 0x74 't' + _PY_CTF.LOWER, # 0x75 'u' + _PY_CTF.LOWER, # 0x76 'v' + _PY_CTF.LOWER, # 0x77 'w' + _PY_CTF.LOWER, # 0x78 'x' + _PY_CTF.LOWER, # 0x79 'y' + _PY_CTF.LOWER, # 0x7a 'z' + 0, # 0x7b '{' + 0, # 0x7c '|' + 0, # 0x7d '}' + 0, # 0x7e '~' + 0, # 0x7f '\x7f' + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + dtype=np.intc, +) + + +# From the definition in CPython's Python/pyctype.c +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/pyctype.c#L145 # noqa: E501 +_Py_ctype_tolower = np.array( + [ + 0x00, + 0x01, + 0x02, + 0x03, + 0x04, + 0x05, + 0x06, + 0x07, + 0x08, + 0x09, + 0x0A, + 0x0B, + 0x0C, + 0x0D, + 0x0E, + 0x0F, + 0x10, + 0x11, + 0x12, + 0x13, + 0x14, + 0x15, + 0x16, + 0x17, + 0x18, + 0x19, + 0x1A, + 0x1B, + 0x1C, + 0x1D, + 0x1E, + 0x1F, + 0x20, + 0x21, + 0x22, + 0x23, + 0x24, + 0x25, + 0x26, + 0x27, + 0x28, + 0x29, + 0x2A, + 0x2B, + 0x2C, + 0x2D, + 0x2E, + 0x2F, + 0x30, + 0x31, + 0x32, + 0x33, + 0x34, + 0x35, + 0x36, + 0x37, + 0x38, + 0x39, + 0x3A, + 0x3B, + 0x3C, + 0x3D, + 0x3E, + 0x3F, + 0x40, + 0x61, + 0x62, + 0x63, + 0x64, + 0x65, + 0x66, + 0x67, + 0x68, + 0x69, + 0x6A, + 0x6B, + 0x6C, + 0x6D, + 0x6E, + 0x6F, + 0x70, + 0x71, + 0x72, + 0x73, + 0x74, + 0x75, + 0x76, + 0x77, + 0x78, + 0x79, + 0x7A, + 0x5B, + 0x5C, + 0x5D, + 0x5E, + 0x5F, + 0x60, + 0x61, + 0x62, + 0x63, + 0x64, + 0x65, + 0x66, 
+ 0x67, + 0x68, + 0x69, + 0x6A, + 0x6B, + 0x6C, + 0x6D, + 0x6E, + 0x6F, + 0x70, + 0x71, + 0x72, + 0x73, + 0x74, + 0x75, + 0x76, + 0x77, + 0x78, + 0x79, + 0x7A, + 0x7B, + 0x7C, + 0x7D, + 0x7E, + 0x7F, + 0x80, + 0x81, + 0x82, + 0x83, + 0x84, + 0x85, + 0x86, + 0x87, + 0x88, + 0x89, + 0x8A, + 0x8B, + 0x8C, + 0x8D, + 0x8E, + 0x8F, + 0x90, + 0x91, + 0x92, + 0x93, + 0x94, + 0x95, + 0x96, + 0x97, + 0x98, + 0x99, + 0x9A, + 0x9B, + 0x9C, + 0x9D, + 0x9E, + 0x9F, + 0xA0, + 0xA1, + 0xA2, + 0xA3, + 0xA4, + 0xA5, + 0xA6, + 0xA7, + 0xA8, + 0xA9, + 0xAA, + 0xAB, + 0xAC, + 0xAD, + 0xAE, + 0xAF, + 0xB0, + 0xB1, + 0xB2, + 0xB3, + 0xB4, + 0xB5, + 0xB6, + 0xB7, + 0xB8, + 0xB9, + 0xBA, + 0xBB, + 0xBC, + 0xBD, + 0xBE, + 0xBF, + 0xC0, + 0xC1, + 0xC2, + 0xC3, + 0xC4, + 0xC5, + 0xC6, + 0xC7, + 0xC8, + 0xC9, + 0xCA, + 0xCB, + 0xCC, + 0xCD, + 0xCE, + 0xCF, + 0xD0, + 0xD1, + 0xD2, + 0xD3, + 0xD4, + 0xD5, + 0xD6, + 0xD7, + 0xD8, + 0xD9, + 0xDA, + 0xDB, + 0xDC, + 0xDD, + 0xDE, + 0xDF, + 0xE0, + 0xE1, + 0xE2, + 0xE3, + 0xE4, + 0xE5, + 0xE6, + 0xE7, + 0xE8, + 0xE9, + 0xEA, + 0xEB, + 0xEC, + 0xED, + 0xEE, + 0xEF, + 0xF0, + 0xF1, + 0xF2, + 0xF3, + 0xF4, + 0xF5, + 0xF6, + 0xF7, + 0xF8, + 0xF9, + 0xFA, + 0xFB, + 0xFC, + 0xFD, + 0xFE, + 0xFF, + ], + dtype=np.uint8, +) + + +# From the definition in CPython's Python/pyctype.c +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Python/pyctype.c#L180 +_Py_ctype_toupper = np.array( + [ + 0x00, + 0x01, + 0x02, + 0x03, + 0x04, + 0x05, + 0x06, + 0x07, + 0x08, + 0x09, + 0x0A, + 0x0B, + 0x0C, + 0x0D, + 0x0E, + 0x0F, + 0x10, + 0x11, + 0x12, + 0x13, + 0x14, + 0x15, + 0x16, + 0x17, + 0x18, + 0x19, + 0x1A, + 0x1B, + 0x1C, + 0x1D, + 0x1E, + 0x1F, + 0x20, + 0x21, + 0x22, + 0x23, + 0x24, + 0x25, + 0x26, + 0x27, + 0x28, + 0x29, + 0x2A, + 0x2B, + 0x2C, + 0x2D, + 0x2E, + 0x2F, + 0x30, + 0x31, + 0x32, + 0x33, + 0x34, + 0x35, + 0x36, + 0x37, + 0x38, + 0x39, + 0x3A, + 0x3B, + 0x3C, + 0x3D, + 0x3E, + 0x3F, + 0x40, + 0x41, + 0x42, + 0x43, + 0x44, + 0x45, + 0x46, + 0x47, + 0x48, + 0x49, + 0x4A, + 0x4B, + 0x4C, + 0x4D, + 0x4E, + 0x4F, + 0x50, + 0x51, + 0x52, + 0x53, + 0x54, + 0x55, + 0x56, + 0x57, + 0x58, + 0x59, + 0x5A, + 0x5B, + 0x5C, + 0x5D, + 0x5E, + 0x5F, + 0x60, + 0x41, + 0x42, + 0x43, + 0x44, + 0x45, + 0x46, + 0x47, + 0x48, + 0x49, + 0x4A, + 0x4B, + 0x4C, + 0x4D, + 0x4E, + 0x4F, + 0x50, + 0x51, + 0x52, + 0x53, + 0x54, + 0x55, + 0x56, + 0x57, + 0x58, + 0x59, + 0x5A, + 0x7B, + 0x7C, + 0x7D, + 0x7E, + 0x7F, + 0x80, + 0x81, + 0x82, + 0x83, + 0x84, + 0x85, + 0x86, + 0x87, + 0x88, + 0x89, + 0x8A, + 0x8B, + 0x8C, + 0x8D, + 0x8E, + 0x8F, + 0x90, + 0x91, + 0x92, + 0x93, + 0x94, + 0x95, + 0x96, + 0x97, + 0x98, + 0x99, + 0x9A, + 0x9B, + 0x9C, + 0x9D, + 0x9E, + 0x9F, + 0xA0, + 0xA1, + 0xA2, + 0xA3, + 0xA4, + 0xA5, + 0xA6, + 0xA7, + 0xA8, + 0xA9, + 0xAA, + 0xAB, + 0xAC, + 0xAD, + 0xAE, + 0xAF, + 0xB0, + 0xB1, + 0xB2, + 0xB3, + 0xB4, + 0xB5, + 0xB6, + 0xB7, + 0xB8, + 0xB9, + 0xBA, + 0xBB, + 0xBC, + 0xBD, + 0xBE, + 0xBF, + 0xC0, + 0xC1, + 0xC2, + 0xC3, + 0xC4, + 0xC5, + 0xC6, + 0xC7, + 0xC8, + 0xC9, + 0xCA, + 0xCB, + 0xCC, + 0xCD, + 0xCE, + 0xCF, + 0xD0, + 0xD1, + 0xD2, + 0xD3, + 0xD4, + 0xD5, + 0xD6, + 0xD7, + 0xD8, + 0xD9, + 0xDA, + 0xDB, + 0xDC, + 0xDD, + 0xDE, + 0xDF, + 0xE0, + 0xE1, + 0xE2, + 0xE3, + 0xE4, + 0xE5, + 0xE6, + 0xE7, + 0xE8, + 0xE9, + 0xEA, + 0xEB, + 0xEC, + 0xED, + 0xEE, + 0xEF, + 0xF0, + 0xF1, + 0xF2, + 0xF3, + 0xF4, + 0xF5, + 0xF6, + 0xF7, + 0xF8, + 0xF9, + 0xFA, + 0xFB, + 0xFC, + 0xFD, + 0xFE, + 0xFF, + ], + dtype=np.uint8, +) + + +class _PY_CTF_LB(IntEnum): + 
LINE_BREAK = 0x01 + LINE_FEED = 0x02 + CARRIAGE_RETURN = 0x04 + + +_Py_ctype_islinebreak = np.array( + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + _PY_CTF_LB.LINE_BREAK | _PY_CTF_LB.LINE_FEED, # 0xa '\n' + _PY_CTF_LB.LINE_BREAK, # 0xb '\v' + _PY_CTF_LB.LINE_BREAK, # 0xc '\f' + _PY_CTF_LB.LINE_BREAK | _PY_CTF_LB.CARRIAGE_RETURN, # 0xd '\r' + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + _PY_CTF_LB.LINE_BREAK, # 0x1c '\x1c' + _PY_CTF_LB.LINE_BREAK, # 0x1d '\x1d' + _PY_CTF_LB.LINE_BREAK, # 0x1e '\x1e' + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + _PY_CTF_LB.LINE_BREAK, # 0x85 '\x85' + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + dtype=np.intc, +) + + +# Translation of: +# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pymacro.h#L25 # noqa: E501 +@register_jitable +def _Py_CHARMASK(ch): + """ + Equivalent to the CPython macro `Py_CHARMASK()`, masks off all but the + lowest 256 bits of ch. 
+    """
+    return types.uint8(ch) & types.uint8(0xFF)
+
+
+# Translation of:
+# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L30 # noqa: E501
+@register_jitable
+def _Py_TOUPPER(ch):
+    """
+    Equivalent to the CPython macro `Py_TOUPPER()`; converts an ASCII-range
+    code point to its uppercase equivalent.
+    """
+    return _Py_ctype_toupper[_Py_CHARMASK(ch)]
+
+
+# Translation of:
+# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L29 # noqa: E501
+@register_jitable
+def _Py_TOLOWER(ch):
+    """
+    Equivalent to the CPython macro `Py_TOLOWER()`; converts an ASCII-range
+    code point to its lowercase equivalent.
+    """
+    return _Py_ctype_tolower[_Py_CHARMASK(ch)]
+
+
+# Translation of:
+# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L18 # noqa: E501
+@register_jitable
+def _Py_ISLOWER(ch):
+    """
+    Equivalent to the CPython macro `Py_ISLOWER()`
+    """
+    return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.LOWER
+
+
+# Translation of:
+# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L19 # noqa: E501
+@register_jitable
+def _Py_ISUPPER(ch):
+    """
+    Equivalent to the CPython macro `Py_ISUPPER()`
+    """
+    return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.UPPER
+
+
+# Translation of:
+# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L20 # noqa: E501
+@register_jitable
+def _Py_ISALPHA(ch):
+    """
+    Equivalent to the CPython macro `Py_ISALPHA()`
+    """
+    return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.ALPHA
+
+
+# Translation of:
+# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L21 # noqa: E501
+@register_jitable
+def _Py_ISDIGIT(ch):
+    """
+    Equivalent to the CPython macro `Py_ISDIGIT()`
+    """
+    return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.DIGIT
+
+
+# Translation of:
+# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L22 # noqa: E501
+@register_jitable
+def _Py_ISXDIGIT(ch):
+    """
+    Equivalent to the CPython macro `Py_ISXDIGIT()`
+    """
+    return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.XDIGIT
+
+
+# Translation of:
+# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L23 # noqa: E501
+@register_jitable
+def _Py_ISALNUM(ch):
+    """
+    Equivalent to the CPython macro `Py_ISALNUM()`
+    """
+    return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.ALNUM
+
+
+# Translation of:
+# https://github.com/python/cpython/blob/1d4b6ba19466aba0eb91c4ba01ba509acf18c723/Include/pyctype.h#L24 # noqa: E501
+@register_jitable
+def _Py_ISSPACE(ch):
+    """
+    Equivalent to the CPython macro `Py_ISSPACE()`
+    """
+    return _Py_ctype_table[_Py_CHARMASK(ch)] & _PY_CTF.SPACE
+
+
+@register_jitable
+def _Py_ISLINEBREAK(ch):
+    """Check if character is an ASCII line break"""
+    return _Py_ctype_islinebreak[_Py_CHARMASK(ch)] & _PY_CTF_LB.LINE_BREAK
+
+
+@register_jitable
+def _Py_ISLINEFEED(ch):
+    """Check if character is a line feed `\n`"""
+    return _Py_ctype_islinebreak[_Py_CHARMASK(ch)] & _PY_CTF_LB.LINE_FEED
+
+
+@register_jitable
+def _Py_ISCARRIAGERETURN(ch):
+    """Check if character is a carriage return `\r`"""
+    return _Py_ctype_islinebreak[_Py_CHARMASK(ch)] & _PY_CTF_LB.CARRIAGE_RETURN
+
+
+# End code related to/from CPython's pyctype
+# ------------------------------------------------------------------------------
diff --git 
a/numba_cuda/numba/cuda/target.py b/numba_cuda/numba/cuda/target.py index fa7fcc9a3..e0e5ee87a 100644 --- a/numba_cuda/numba/cuda/target.py +++ b/numba_cuda/numba/cuda/target.py @@ -155,11 +155,18 @@ def init(self): def load_additional_registries(self): # side effect of import needed for numba.cpython.*, numba.cuda.cpython.*, the builtins # registry is updated at import time. - from numba.cpython import tupleobj, slicing # noqa: F401 - from numba.cuda.cpython import numbers # noqa: F401 - from numba.cpython import rangeobj, iterators, enumimpl # noqa: F401 - from numba.cpython import unicode, charseq # noqa: F401 - from numba.cuda.cpython import cmathimpl, mathimpl + from numba.cpython import tupleobj # noqa: F401 + from numba.cuda.cpython import ( + numbers, + slicing, + iterators, + listobj, + unicode, + charseq, + cmathimpl, + mathimpl, + ) + from numba.cpython import rangeobj, enumimpl # noqa: F401 from numba.core import optional # noqa: F401 from numba.misc import cffiimpl from numba.np import arrayobj # noqa: F401 @@ -188,6 +195,11 @@ def load_additional_registries(self): self.install_registry(vector_types.impl_registry) self.install_registry(fp16.target_registry) self.install_registry(bf16.target_registry) + self.install_registry(slicing.registry) + self.install_registry(iterators.registry) + self.install_registry(listobj.registry) + self.install_registry(unicode.registry) + self.install_registry(charseq.registry) def codegen(self): return self._internal_codegen diff --git a/numba_cuda/numba/cuda/tests/nocuda/test_import.py b/numba_cuda/numba/cuda/tests/nocuda/test_import.py index fa6735d61..17e1c3074 100644 --- a/numba_cuda/numba/cuda/tests/nocuda/test_import.py +++ b/numba_cuda/numba/cuda/tests/nocuda/test_import.py @@ -29,6 +29,11 @@ def test_no_impl_import(self): "numba.cuda.cpython.numbers", "numba.cuda.cpython.cmathimpl", "numba.cuda.cpython.mathimpl", + "numba.cuda.cpython.slicing", + "numba.cuda.cpython.iterators", + "numba.cuda.cpython.listobj", + "numba.cuda.cpython.unicode", + "numba.cuda.cpython.charseq", "numba.core.optional", "numba.misc.gdb_hook", "numba.misc.literal", diff --git a/numba_cuda/numba/cuda/tests/support.py b/numba_cuda/numba/cuda/tests/support.py index c85597d53..fba6ce352 100644 --- a/numba_cuda/numba/cuda/tests/support.py +++ b/numba_cuda/numba/cuda/tests/support.py @@ -28,9 +28,9 @@ from numba.core.extending import ( typeof_impl, register_model, - unbox, NativeValue, ) +from numba.cuda.core.pythonapi import unbox from numba.core.datamodel.models import OpaqueModel from numba.cuda.np import numpy_support
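To make the intent of the test_no_impl_import additions above concrete, here is a small, hypothetical check (not part of the patch) that runs a bare `import numba.cuda` in a fresh interpreter and verifies that none of the lazily-registered numba.cuda.cpython.* lowering modules are imported eagerly. The module names are taken from the diff; the helper script itself is illustrative only.

import subprocess
import sys

# Lowering modules that should only be imported when the CUDA target
# context loads its additional registries (names copied from the diff).
LAZY_MODULES = [
    "numba.cuda.cpython.slicing",
    "numba.cuda.cpython.iterators",
    "numba.cuda.cpython.listobj",
    "numba.cuda.cpython.unicode",
    "numba.cuda.cpython.charseq",
]

# Run the import in a separate interpreter so this process's state does not
# interfere, then report which of the lazy modules ended up in sys.modules.
code = (
    "import sys, numba.cuda; "
    f"print([m for m in {LAZY_MODULES!r} if m in sys.modules])"
)
out = subprocess.check_output([sys.executable, "-c", code], text=True)
assert out.strip() == "[]", f"unexpectedly imported eagerly: {out}"

These modules are only expected to show up in sys.modules once the CUDA target context runs load_additional_registries() (for example, when a kernel is first compiled), which is the behaviour the test addition locks in.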