Skip to content

Commit 111131c

Browse files
committed
Further updated the type integration tests
1 parent eba160b commit 111131c

File tree

2 files changed

+21
-83
lines changed

2 files changed

+21
-83
lines changed

tests/integration/test_types_integration.py

Lines changed: 3 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import math
22
import pytest
33
from decimal import Decimal
4-
import datetime
5-
from datetime import timezone, timedelta
64
import trino
75

86

@@ -91,6 +89,7 @@ def test_decimal(trino_connection):
9189
.add_field(sql="CAST('234.123456789123456789' AS DECIMAL(18,4))", python=Decimal('234.1235')) \
9290
.add_field(sql="CAST('10.3' AS DECIMAL(38,1))", python=Decimal('10.3')) \
9391
.add_field(sql="CAST('0.123456789123456789' AS DECIMAL(18,2))", python=Decimal('0.12')) \
92+
.add_field(sql="CAST('0.3123' AS DECIMAL(38,38))", python=Decimal('0.3123')) \
9493
.execute()
9594

9695

@@ -117,7 +116,9 @@ def test_char(trino_connection):
117116
def test_varbinary(trino_connection):
118117
SqlTest(trino_connection) \
119118
.add_field(sql="X'65683F'", python='ZWg/') \
119+
.add_field(sql="X'0001020304050607080DF9367AA7000000'", python='AAECAwQFBgcIDfk2eqcAAAA=') \
120120
.add_field(sql="CAST('' AS VARBINARY)", python='') \
121+
.add_field(sql="from_utf8(CAST('😂😂😂😂😂😂' AS VARBINARY))", python='😂😂😂😂😂😂') \
121122
.add_field(sql="CAST(null AS VARBINARY)", python=None) \
122123
.execute()
123124

@@ -130,80 +131,6 @@ def test_json(trino_connection):
130131
.execute()
131132

132133

133-
def test_datetime(trino_connection):
134-
the_tz = datetime.timezone(datetime.timedelta(days=-1, seconds=57600))
135-
136-
SqlTest(trino_connection) \
137-
.add_field(sql="DATE '2001-08-22'", python=datetime.date(2001, 8, 22)) \
138-
.add_field(sql="DATE '02001-08-22'", python=datetime.date(2001, 8, 22)) \
139-
.add_field(sql="CAST(null AS DATE)", python=None) \
140-
.add_field(sql="TIME '01:23:45.123'", python=datetime.time(1, 23, 45, 123000)) \
141-
.add_field(sql="CAST(null AS TIME)", python=None) \
142-
.add_field(sql="CAST('01:23:45' AS TIME(0))", python=datetime.time(1, 23, 45)) \
143-
.add_field(sql="CAST(null AS TIME(0))", python=None) \
144-
.add_field(sql="CAST('01:23:45.123' AS TIME(3))", python=datetime.time(1, 23, 45, 123000)) \
145-
.add_field(sql="CAST(null AS TIME(3))", python=None) \
146-
.add_field(sql="CAST('01:23:45.123456' AS TIME(6))", python=datetime.time(1, 23, 45, 123000)) \
147-
.add_field(sql="CAST(null AS TIME(6))", python=None) \
148-
.add_field(sql="CAST('01:23:45.123456789' AS TIME(9))", python=datetime.time(1, 23, 45, 123000)) \
149-
.add_field(sql="CAST(null AS TIME(9))", python=None) \
150-
.add_field(sql="CAST('01:23:45.123456789123' AS TIME(12))", python=datetime.time(1, 23, 45, 123000)) \
151-
.add_field(sql="CAST(null AS TIME(12))", python=None) \
152-
.add_field(sql="TIME '01:23:45.123 -08:00'",
153-
python=datetime.time(1, 23, 45, 123000).replace(tzinfo=timezone(-timedelta(hours=8)))) \
154-
.add_field(sql="CAST(null AS TIME WITH TIME ZONE)", python=None) \
155-
.add_field(sql="CAST(null AS TIME(0) WITH TIME ZONE)", python=None) \
156-
.add_field(sql="CAST('01:23:45.123 -08:00' AS TIME(3) WITH TIME ZONE)",
157-
python=datetime.time(1, 23, 45, 123000).replace(tzinfo=timezone(-timedelta(hours=8)))) \
158-
.add_field(sql="CAST(null AS TIME(3) WITH TIME ZONE)", python=None) \
159-
.add_field(sql="CAST('01:23:45.123456 -08:00' AS TIME(6) WITH TIME ZONE)",
160-
python=datetime.time(1, 23, 45, 123000).replace(tzinfo=timezone(-timedelta(hours=8)))) \
161-
.add_field(sql="CAST(null AS TIME(6) WITH TIME ZONE)", python=None) \
162-
.add_field(sql="CAST('01:23:45.123456789 -08:00' AS TIME(9) WITH TIME ZONE)",
163-
python=datetime.time(1, 23, 45, 123000).replace(tzinfo=timezone(-timedelta(hours=8)))) \
164-
.add_field(sql="CAST(null AS TIME(9) WITH TIME ZONE)", python=None) \
165-
.add_field(sql="CAST('01:23:45.123456789123 -08:00' AS TIME(12) WITH TIME ZONE)",
166-
python=datetime.time(1, 23, 45, 123000).replace(tzinfo=timezone(-timedelta(hours=8)))) \
167-
.add_field(sql="CAST(null AS TIME(12) WITH TIME ZONE)", python=None) \
168-
.add_field(sql="TIMESTAMP '2001-08-22 01:23:45.123'",
169-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, 123000)) \
170-
.add_field(sql="CAST(null AS TIMESTAMP)", python=None) \
171-
.add_field(sql="CAST('2001-08-22 01:23:45' AS TIMESTAMP(0))",
172-
python=datetime.datetime(2001, 8, 22, 1, 23, 45)) \
173-
.add_field(sql="CAST(null AS TIMESTAMP(0))", python=None) \
174-
.add_field(sql="CAST('2001-08-22 01:23:45.123' AS TIMESTAMP(3))",
175-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, 123000)) \
176-
.add_field(sql="CAST(null AS TIMESTAMP(3))", python=None) \
177-
.add_field(sql="CAST('2001-08-22 01:23:45.123456' AS TIMESTAMP(6))",
178-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, 123000)) \
179-
.add_field(sql="CAST(null AS TIMESTAMP(6))", python=None) \
180-
.add_field(sql="CAST('2001-08-22 01:23:45.123456789' AS TIMESTAMP(9))",
181-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, 123000)) \
182-
.add_field(sql="CAST(null AS TIMESTAMP(9))", python=None) \
183-
.add_field(sql="CAST('2001-08-22 01:23:45.123456789123' AS TIMESTAMP(12))",
184-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, 123000)) \
185-
.add_field(sql="CAST(null AS TIMESTAMP(12))", python=None) \
186-
.add_field(sql="TIMESTAMP '2001-08-22 01:23:45.123 -08:00'",
187-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, 123000, tzinfo=the_tz)) \
188-
.add_field(sql="CAST(null AS TIMESTAMP WITH TIME ZONE)", python=None) \
189-
.add_field(sql="CAST('2001-08-22 01:23:45 -08:00' AS TIMESTAMP(0) WITH TIME ZONE)",
190-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, tzinfo=the_tz)) \
191-
.add_field(sql="CAST(null AS TIMESTAMP(0) WITH TIME ZONE)", python=None) \
192-
.add_field(sql="CAST('2001-08-22 01:23:45.123 -08:00' AS TIMESTAMP(3) WITH TIME ZONE)",
193-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, 123000, tzinfo=the_tz)) \
194-
.add_field(sql="CAST(null AS TIMESTAMP(3) WITH TIME ZONE)", python=None) \
195-
.add_field(sql="CAST('2001-08-22 01:23:45.123456 -08:00' AS TIMESTAMP(6) WITH TIME ZONE)",
196-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, 123000, tzinfo=the_tz)) \
197-
.add_field(sql="CAST(null AS TIMESTAMP(6) WITH TIME ZONE)", python=None) \
198-
.add_field(sql="CAST('2001-08-22 01:23:45.123456789 -08:00' AS TIMESTAMP(9) WITH TIME ZONE)",
199-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, 123000, tzinfo=the_tz)) \
200-
.add_field(sql="CAST(null AS TIMESTAMP(9) WITH TIME ZONE)", python=None) \
201-
.add_field(sql="CAST('2001-08-22 01:23:45.123456789123 -08:00' AS TIMESTAMP(12) WITH TIME ZONE)",
202-
python=datetime.datetime(2001, 8, 22, 1, 23, 45, 123000, tzinfo=the_tz)) \
203-
.add_field(sql="CAST(null AS TIMESTAMP(12) WITH TIME ZONE)", python=None) \
204-
.execute()
205-
206-
207134
def test_interval(trino_connection):
208135
SqlTest(trino_connection) \
209136
.add_field(sql="CAST(null AS INTERVAL YEAR TO MONTH)", python=None) \

trino/client.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -770,19 +770,29 @@ def decorated(*args, **kwargs):
770770
return wrapper
771771

772772

773+
class NoOpRowMapper:
774+
"""
775+
No-op RowMapper which does not perform any transformation
776+
"""
777+
778+
def map(self, rows):
779+
return rows
780+
781+
773782
class RowMapperFactory:
774783
"""
775784
Given the 'columns' result from Trino, generate a list of
776785
lambda functions (one for each column) which will process a data value
777786
and returns a RowMapper instance which will process rows of data
778787
"""
788+
no_op_row_mapper = NoOpRowMapper()
779789

780790
def create(self, columns, experimental_python_types):
781791
assert columns is not None
782792

783793
if experimental_python_types:
784794
return RowMapper([self._col_func(column['typeSignature']) for column in columns])
785-
return RowMapper()
795+
return RowMapperFactory.no_op_row_mapper
786796

787797
def _col_func(self, column):
788798
col_type = column['rawType']
@@ -807,12 +817,12 @@ def _col_func(self, column):
807817
return lambda val: val
808818

809819
def _array_map_func(self, column):
810-
elt_mapping_func = self._col_func(column['arguments'][0]['value'])
811-
return lambda values: [elt_mapping_func(value) for value in values]
820+
element_mapping_func = self._col_func(column['arguments'][0]['value'])
821+
return lambda values: [element_mapping_func(value) for value in values]
812822

813823
def _row_map_func(self, column):
814-
elt_mapping_funcs = [self._col_func(arg['value']['typeSignature']) for arg in column['arguments']]
815-
return lambda values: tuple(elt_mapping_funcs[idx](value) for idx, value in enumerate(values))
824+
element_mapping_func = [self._col_func(arg['value']['typeSignature']) for arg in column['arguments']]
825+
return lambda values: tuple(element_mapping_func[idx](value) for idx, value in enumerate(values))
816826

817827
def _map_map_func(self, column):
818828
key_mapping_func = self._col_func(column['arguments'][0]['value'])
@@ -826,13 +836,14 @@ def _double_map_func(self):
826836
else float(val)
827837

828838
def _timestamp_map_func(self, column, col_type):
839+
datetime_default_size = 20
829840
pattern = "%Y-%m-%d %H:%M:%S"
830841
ms_size, ms_to_trim = self._get_number_of_digits(column)
831842
if ms_size > 0:
832843
pattern += ".%f"
833844

834-
dt_size = 20 + ms_size - ms_to_trim
835-
dt_tz_offset = 20 + ms_size
845+
dt_size = datetime_default_size + ms_size - ms_to_trim
846+
dt_tz_offset = datetime_default_size + ms_size
836847
if 'with time zone' in col_type:
837848

838849
if ms_to_trim > 0:

0 commit comments

Comments
 (0)