|
| 1 | +'''HeavyDB TextEncodingNone type that corresponds to HeavyDB type TEXT ENCODED NONE. |
| 2 | +''' |
| 3 | + |
| 4 | +__all__ = ['TextEncodingNonePointer', 'TextEncodingNone', 'HeavyDBTextEncodingNoneType'] |
| 5 | + |
| 6 | +import operator |
| 7 | +from rbc import typesystem |
| 8 | +from rbc.targetinfo import TargetInfo |
| 9 | +from rbc.errors import RequireLiteralValue |
| 10 | +from .buffer import ( |
| 11 | + BufferPointer, Buffer, HeavyDBBufferType, |
| 12 | + heavydb_buffer_constructor) |
| 13 | +from numba.core import types, extending, cgutils |
| 14 | +from llvmlite import ir |
| 15 | +from typing import Union |
| 16 | + |
| 17 | + |
| 18 | +class HeavyDBTextEncodingNoneType(HeavyDBBufferType): |
| 19 | + """HeavyDB TextEncodingNone type for RBC typesystem. |
| 20 | + """ |
| 21 | + |
| 22 | + @property |
| 23 | + def numba_pointer_type(self): |
| 24 | + return TextEncodingNonePointer |
| 25 | + |
| 26 | + @classmethod |
| 27 | + def preprocess_args(cls, args): |
| 28 | + element_type = typesystem.Type.fromstring('char8') |
| 29 | + return ((element_type,),) |
| 30 | + |
| 31 | + @property |
| 32 | + def buffer_extra_members(self): |
| 33 | + return ('bool is_null',) |
| 34 | + |
| 35 | + def match(self, other): |
| 36 | + if type(self) is type(other): |
| 37 | + return self[0] == other[0] |
| 38 | + if other.is_pointer and other[0].is_char and other[0].bits == 8: |
| 39 | + return 1 |
| 40 | + if other.is_string: |
| 41 | + return 2 |
| 42 | + |
| 43 | + |
| 44 | +class TextEncodingNonePointer(BufferPointer): |
| 45 | + pass |
| 46 | + |
| 47 | + |
| 48 | +class TextEncodingNone(Buffer): |
| 49 | + '''HeavyDB TextEncodingNone type that corresponds to HeavyDB type TEXT ENCODED NONE. |
| 50 | +
|
| 51 | + HeavyDB TextEncodingNone represents the following structure: |
| 52 | +
|
| 53 | + .. code-block:: c |
| 54 | +
|
| 55 | + struct TextEncodingNone { |
| 56 | + char* ptr; |
| 57 | + size_t sz; // when non-negative, TextEncodingNone has fixed width. |
| 58 | + int8_t is_null; |
| 59 | + } |
| 60 | +
|
| 61 | +
|
| 62 | + .. code-block:: python |
| 63 | +
|
| 64 | + from rbc.heavydb import TextEncodingNone |
| 65 | +
|
| 66 | + @heavydb('TextEncodingNone(int32, int32)') |
| 67 | + def make_abc(first, n): |
| 68 | + r = TextEncodingNone(n) |
| 69 | + for i in range(n): |
| 70 | + r[i] = first + i |
| 71 | + return r |
| 72 | +
|
| 73 | +
|
| 74 | + .. code-block:: python |
| 75 | +
|
| 76 | + from rbc.heavydb import TextEncodingNone |
| 77 | + @heavydb('TextEncodingNone()') |
| 78 | + def make_text(): |
| 79 | + return TextEncodingNone('some text here') |
| 80 | +
|
| 81 | + ''' |
| 82 | + |
| 83 | + def __init__(self, size: Union[int, str]): |
| 84 | + pass |
| 85 | + |
| 86 | + |
| 87 | +@extending.overload(operator.eq) |
| 88 | +def text_encoding_none_eq(a, b): |
| 89 | + if isinstance(a, TextEncodingNonePointer) and isinstance(b, TextEncodingNonePointer): |
| 90 | + |
| 91 | + def impl(a, b): |
| 92 | + if len(a) != len(b): |
| 93 | + return False |
| 94 | + for i in range(0, len(a)): |
| 95 | + if a[i] != b[i]: |
| 96 | + return False |
| 97 | + return True |
| 98 | + return impl |
| 99 | + elif isinstance(a, TextEncodingNonePointer) and isinstance(b, types.StringLiteral): |
| 100 | + lv = b.literal_value |
| 101 | + sz = len(lv) |
| 102 | + |
| 103 | + def impl(a, b): |
| 104 | + if len(a) != sz: |
| 105 | + return False |
| 106 | + t = TextEncodingNone(lv) |
| 107 | + return a == t |
| 108 | + return impl |
| 109 | + |
| 110 | + |
| 111 | +@extending.overload(operator.ne) |
| 112 | +def text_encoding_none_ne(a, b): |
| 113 | + if isinstance(a, TextEncodingNonePointer): |
| 114 | + if isinstance(b, (TextEncodingNonePointer, types.StringLiteral)): |
| 115 | + def impl(a, b): |
| 116 | + return not(a == b) |
| 117 | + return impl |
| 118 | + |
| 119 | + |
| 120 | +@extending.lower_builtin(TextEncodingNone, types.Integer) |
| 121 | +def heavydb_text_encoding_none_constructor(context, builder, sig, args): |
| 122 | + return heavydb_buffer_constructor(context, builder, sig, args) |
| 123 | + |
| 124 | + |
| 125 | +@extending.lower_builtin(TextEncodingNone, types.StringLiteral) |
| 126 | +def heavydb_text_encoding_none_constructor_literal(context, builder, sig, args): |
| 127 | + int64_t = ir.IntType(64) |
| 128 | + int8_t_ptr = ir.IntType(8).as_pointer() |
| 129 | + |
| 130 | + literal_value = sig.args[0].literal_value |
| 131 | + sz = int64_t(len(literal_value)) |
| 132 | + |
| 133 | + # arr = {ptr, size, is_null}* |
| 134 | + arr = heavydb_buffer_constructor(context, builder, sig.return_type(types.int64), [sz]) |
| 135 | + ptr = builder.extract_value(builder.load(arr), [0]) |
| 136 | + |
| 137 | + msg_bytes = literal_value.encode('utf-8') |
| 138 | + msg_const = cgutils.make_bytearray(msg_bytes) |
| 139 | + msg_global_var = cgutils.global_constant(builder.module, f"Text({literal_value})", msg_const) |
| 140 | + msg_ptr = builder.bitcast(msg_global_var, int8_t_ptr) |
| 141 | + sizeof_char = TargetInfo().sizeof('char') |
| 142 | + cgutils.raw_memcpy(builder, ptr, msg_ptr, sz, sizeof_char) |
| 143 | + return arr |
| 144 | + |
| 145 | + |
| 146 | +@extending.type_callable(TextEncodingNone) |
| 147 | +def type_heavydb_text_encoding_none(context): |
| 148 | + def typer(arg): |
| 149 | + if isinstance(arg, types.UnicodeType): |
| 150 | + raise RequireLiteralValue() |
| 151 | + if isinstance(arg, (types.Integer, types.StringLiteral)): |
| 152 | + return typesystem.Type.fromobject('TextEncodingNone').tonumba() |
| 153 | + return typer |
0 commit comments