Skip to content

Commit 1bc8ac0

Browse files
add support for operator.eq and operator.ne for TextEncodingNone
1 parent c67f46f commit 1bc8ac0

File tree

5 files changed

+146
-16
lines changed

5 files changed

+146
-16
lines changed

rbc/heavyai/buffer.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ class OmnisciBufferType(typesystem.Type):
5151
def pass_by_value(self):
5252
return False
5353

54+
@property
55+
def numba_pointer_type(self):
56+
return BufferPointer
57+
5458
@classmethod
5559
def preprocess_args(cls, args):
5660
assert len(args) == 1, args
@@ -78,11 +82,11 @@ def tonumba(self, bool_is_int8=None):
7882
*extra_members
7983
)
8084
buffer_type._params['NumbaType'] = BufferType
81-
buffer_type._params['NumbaPointerType'] = BufferPointer
85+
buffer_type._params['NumbaPointerType'] = self.numba_pointer_type
8286
numba_type = buffer_type.tonumba(bool_is_int8=True)
8387
if self.pass_by_value:
8488
return numba_type
85-
return BufferPointer(numba_type)
89+
return self.numba_pointer_type(numba_type)
8690

8791

8892
class BufferType(types.Type):

rbc/heavyai/pipeline.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def define_pipelines(self):
158158
# namely the "nopython" pipeline
159159
pm = DefaultPassBuilder.define_nopython_pipeline(self.state)
160160
# Add the new pass to run after IRProcessing
161-
# pm.add_pass_after(AutoFreeBuffers, IRProcessing)
161+
pm.add_pass_after(AutoFreeBuffers, IRProcessing)
162162
pm.add_pass_after(CheckRaiseStmts, IRProcessing)
163163
pm.add_pass_after(DTypeComparison, ReconstructSSA)
164164
# finalize

rbc/heavyai/text_encoding_none.py

+84-8
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,28 @@
11
'''Omnisci TextEncodingNone type that corresponds to Omnisci type TEXT ENCODED NONE.
22
'''
33

4-
__all__ = ['TextEncodingNonePointer', 'TextEncodingNone', 'OmnisciTextEncodingNoneType']
4+
__all__ = ['TextEncodingNonePointer', 'TextEncodingNone', 'HeavyDBTextEncodingNoneType']
55

6+
import operator
67
from rbc import typesystem
8+
from rbc.targetinfo import TargetInfo
9+
from rbc.errors import RequireLiteralValue
710
from .buffer import (
811
BufferPointer, Buffer, OmnisciBufferType,
912
omnisci_buffer_constructor)
10-
from numba.core import types, extending
13+
from numba.core import types, extending, cgutils
14+
from llvmlite import ir
15+
from typing import Union
1116

1217

13-
class OmnisciTextEncodingNoneType(OmnisciBufferType):
18+
class HeavyDBTextEncodingNoneType(OmnisciBufferType):
1419
"""Omnisci TextEncodingNone type for RBC typesystem.
1520
"""
1621

22+
@property
23+
def numba_pointer_type(self):
24+
return TextEncodingNonePointer
25+
1726
@classmethod
1827
def preprocess_args(cls, args):
1928
element_type = typesystem.Type.fromstring('char8')
@@ -32,7 +41,8 @@ def match(self, other):
3241
return 2
3342

3443

35-
TextEncodingNonePointer = BufferPointer
class TextEncodingNonePointer(BufferPointer):
    """Numba pointer type used for TextEncodingNone buffers.

    It adds nothing to BufferPointer; being a distinct subclass lets
    type-based dispatch (``isinstance`` checks on argument types) single
    out TextEncodingNone arguments from other buffer pointers.
    """
3646

3747

3848
class TextEncodingNone(Buffer):
@@ -53,25 +63,91 @@ class TextEncodingNone(Buffer):
5363
5464
from rbc.heavydb import TextEncodingNone
5565
56-
@omnisci('TextEncodingNone(int32, int32)')
66+
@heavydb('TextEncodingNone(int32, int32)')
5767
def make_abc(first, n):
5868
r = TextEncodingNone(n)
5969
for i in range(n):
6070
r[i] = first + i
6171
return r
72+
73+
74+
.. code-block:: python
75+
76+
from rbc.heavydb import TextEncodingNone
77+
@heavydb('TextEncodingNone()')
78+
def make_text():
79+
return TextEncodingNone('some text here')
80+
6281
'''
6382

64-
def __init__(self, size: int):
83+
def __init__(self, size: Union[int, str]):
6584
pass
6685

6786

@extending.overload(operator.eq)
def text_encoding_none_eq(a, b):
    """Overload ``==`` for TextEncodingNone operands.

    Two combinations of argument types are supported:

    * both operands are TextEncodingNone pointers: the result is True when
      the buffers have the same length and identical elements;
    * the left operand is a TextEncodingNone pointer and the right operand
      is a string literal: the literal is materialized as a
      TextEncodingNone and compared as in the first case.

    For any other combination no implementation is returned, which means
    this overload does not apply.
    """
    if not isinstance(a, TextEncodingNonePointer):
        return None

    if isinstance(b, TextEncodingNonePointer):

        def impl(a, b):
            n = len(a)
            if n != len(b):
                return False
            for i in range(n):
                if a[i] != b[i]:
                    return False
            return True
        return impl

    if isinstance(b, types.StringLiteral):
        lv = b.literal_value
        # The buffer elements are char8 and a string literal is stored as
        # its UTF-8 encoding, so the cheap length pre-check must use the
        # encoded byte count rather than the number of code points (the two
        # differ for non-ASCII text).
        # NOTE(review): assumes len(a) counts char8 elements - confirm.
        sz = len(lv.encode('utf-8'))

        def impl(a, b):
            if len(a) != sz:
                return False
            t = TextEncodingNone(lv)
            return a == t
        return impl

    return None
109+
110+
@extending.overload(operator.ne)
def text_encoding_none_ne(a, b):
    """Overload ``!=`` for TextEncodingNone operands.

    Applies when the left operand is a TextEncodingNone pointer and the
    right operand is either a TextEncodingNone pointer or a string literal.
    The result is the logical negation of ``a == b``.
    """
    accepted_rhs = (TextEncodingNonePointer, types.StringLiteral)
    if isinstance(a, TextEncodingNonePointer) and isinstance(b, accepted_rhs):

        def impl(a, b):
            return not (a == b)
        return impl
118+
119+
@extending.lower_builtin(TextEncodingNone, types.Integer)
def omnisci_text_encoding_none_constructor(context, builder, sig, args):
    """Lower ``TextEncodingNone(<integer>)``.

    All arguments are forwarded unchanged to the generic buffer
    constructor, whose result is returned.
    """
    buffer = omnisci_buffer_constructor(context, builder, sig, args)
    return buffer
71123

72124

@extending.lower_builtin(TextEncodingNone, types.StringLiteral)
def omnisci_text_encoding_none_constructor_literal(context, builder, sig, args):
    """Lower ``TextEncodingNone(<string literal>)``.

    The literal is encoded as UTF-8, a buffer of exactly that many bytes
    is created through the generic buffer constructor, and the encoded
    bytes are copied into it from a module-level constant.

    Returns the value produced by the buffer constructor.
    """
    int64_t = ir.IntType(64)
    int8_t_ptr = ir.IntType(8).as_pointer()

    literal_value = sig.args[0].literal_value
    # Size the buffer from the encoded bytes, not from the number of code
    # points: for non-ASCII text the UTF-8 encoding is longer than
    # len(literal_value), and using the latter would truncate the copy.
    msg_bytes = literal_value.encode('utf-8')
    sz = int64_t(len(msg_bytes))

    # arr = {ptr, size, is_null}*
    arr = omnisci_buffer_constructor(context, builder, sig.return_type(types.int64), [sz])
    ptr = builder.extract_value(builder.load(arr), [0])

    # Store the encoded text as a global constant and copy it into the
    # freshly created buffer.
    msg_const = cgutils.make_bytearray(msg_bytes)
    msg_global_var = cgutils.global_constant(
        builder.module, f"Text({literal_value})", msg_const)
    msg_ptr = builder.bitcast(msg_global_var, int8_t_ptr)
    sizeof_char = TargetInfo().sizeof('char')
    cgutils.raw_memcpy(builder, ptr, msg_ptr, sz, sizeof_char)
    return arr
144+
145+
73146
@extending.type_callable(TextEncodingNone)
74147
def type_omnisci_text_encoding_none(context):
75-
def typer(size):
76-
return typesystem.Type.fromobject('TextEncodingNone').tonumba()
148+
def typer(arg):
149+
if isinstance(arg, types.UnicodeType):
150+
raise RequireLiteralValue()
151+
if isinstance(arg, (types.Integer, types.StringLiteral)):
152+
return typesystem.Type.fromobject('TextEncodingNone').tonumba()
77153
return typer

rbc/heavydb.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .thrift import Client as ThriftClient
1515
from . import heavyai
1616
from .heavyai import (
17-
OmnisciArrayType, OmnisciTextEncodingNoneType, OmnisciTextEncodingDictType,
17+
OmnisciArrayType, HeavyDBTextEncodingNoneType, OmnisciTextEncodingDictType,
1818
OmnisciOutputColumnType, OmnisciColumnType,
1919
OmnisciCompilerPipeline, OmnisciCursorType,
2020
BufferMeta, OmnisciColumnListType, OmnisciTableFunctionManagerType)
@@ -398,7 +398,7 @@ def add(a, b):
398398
Constant='int32|sizer=Constant',
399399
PreFlight='int32|sizer=PreFlight',
400400
ColumnList='OmnisciColumnListType',
401-
TextEncodingNone='OmnisciTextEncodingNoneType',
401+
TextEncodingNone='HeavyDBTextEncodingNoneType',
402402
TextEncodingDict='OmnisciTextEncodingDictType',
403403
TableFunctionManager='OmnisciTableFunctionManagerType<>',
404404
UDTF='int32|kind=UDTF'
@@ -894,7 +894,7 @@ def _get_ext_arguments_map(self):
894894
ext_arguments_map['OmnisciOutputColumnListType<%s>' % ptr_type] \
895895
= ext_arguments_map.get('ColumnList<%s>' % T)
896896

897-
ext_arguments_map['OmnisciTextEncodingNoneType<char8>'] = \
897+
ext_arguments_map['HeavyDBTextEncodingNoneType<char8>'] = \
898898
ext_arguments_map.get('TextEncodingNone')
899899

900900
values = list(ext_arguments_map.values())
@@ -1398,7 +1398,7 @@ def format_type(self, typ: typesystem.Type):
13981398
elif isinstance(typ, OmnisciCursorType):
13991399
p = tuple(map(self.format_type, typ[0]))
14001400
typ = typesystem.Type(('Cursor',) + p, **typ._params)
1401-
elif isinstance(typ, OmnisciTextEncodingNoneType):
1401+
elif isinstance(typ, HeavyDBTextEncodingNoneType):
14021402
typ = typ.copy().params(typename='TextEncodingNone')
14031403
use_typename = True
14041404
elif isinstance(typ, OmnisciTextEncodingDictType):
@@ -1444,7 +1444,7 @@ def remote_call(self, func, ftype: typesystem.Type, arguments: tuple, hold=False
14441444

14451445
if isinstance(atype, (OmnisciColumnType, OmnisciColumnListType)):
14461446
args.append(f'CURSOR({a})')
1447-
elif isinstance(atype, OmnisciTextEncodingNoneType):
1447+
elif isinstance(atype, HeavyDBTextEncodingNoneType):
14481448
if isinstance(a, bytes):
14491449
a = repr(a.decode())
14501450
elif isinstance(a, str):

rbc/tests/heavyai/test_text.py

+50
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,53 @@ def myupper(s):
135135

136136
for n, u in result:
137137
assert n.upper() == u
138+
139+
def test_TextEncodingNone_str_constructor(heavydb):
    """A UDF can build a TextEncodingNone directly from a string literal."""
    heavydb.reset()

    @heavydb('TextEncodingNone(int32)')
    def constructor(_):
        return TextEncodingNone('hello world')

    returned_text = constructor(3).execute()
    assert returned_text == 'hello world'
148+
149+
def test_TextEncodingNone_eq(heavydb):
    """``==`` works between two TextEncodingNone values and against a literal."""
    heavydb.reset()

    @heavydb('int32(TextEncodingNone, TextEncodingNone)')
    def eq1(a, b):
        return a == b

    @heavydb('int32(TextEncodingNone)')
    def eq2(a):
        return a == 'world'

    # (operands, expected result of eq1)
    buffer_cases = [
        (('hello', 'hello'), 1),
        (('c', 'c'), 1),
        (('hello', 'h'), 0),
        (('hello', 'hello2'), 0),
    ]
    for operands, expected in buffer_cases:
        assert eq1(*operands).execute() == expected

    assert eq2('world').execute() == 1
168+
169+
def test_TextEncodingNone_ne(heavydb):
    """``!=`` works between two TextEncodingNone values and against a literal."""
    heavydb.reset()

    @heavydb('int32(TextEncodingNone, TextEncodingNone)')
    def ne1(a, b):
        return a != b

    @heavydb('int32(TextEncodingNone)')
    def ne2(a):
        return a != 'world'

    # (operands, expected result of ne1)
    buffer_cases = [
        (('hello', 'hello'), 0),
        (('c', 'c'), 0),
        (('hello', 'h'), 1),
        (('hello', 'hello2'), 1),
    ]
    for operands, expected in buffer_cases:
        assert ne1(*operands).execute() == expected

    assert ne2('world').execute() == 0

0 commit comments

Comments
 (0)