Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def test_datetime_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp"
assert cur.description[0][1] == "timestamp(6)"
Comment thread
lpoulain marked this conversation as resolved.


def test_datetime_with_utc_time_zone_query_param(trino_connection):
Expand All @@ -295,7 +295,7 @@ def test_datetime_with_utc_time_zone_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp with time zone"
assert cur.description[0][1] == "timestamp(6) with time zone"


def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
Expand All @@ -309,19 +309,19 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection):
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp with time zone"
assert cur.description[0][1] == "timestamp(6) with time zone"


def test_datetime_with_named_time_zone_query_param(trino_connection):
cur = trino_connection.cursor(experimental_python_types=True)

params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles'))
params = pytz.timezone('America/Los_Angeles').localize(datetime(2020, 1, 1, 16, 43, 22, 320000))

cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()

assert rows[0][0] == params
assert cur.description[0][1] == "timestamp with time zone"
assert cur.description[0][1] == "timestamp(6) with time zone"


def test_datetime_with_trailing_zeros(trino_connection):
Expand Down Expand Up @@ -371,7 +371,7 @@ def test_doubled_datetimes(trino_connection):
cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()

assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))
assert rows[0][0] == pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0))

cur = trino_connection.cursor(experimental_python_types=True)

Expand All @@ -380,7 +380,7 @@ def test_doubled_datetimes(trino_connection):
cur.execute("SELECT ?", params=(params,))
rows = cur.fetchall()

assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern'))
assert rows[0][0] == pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0))


def test_date_query_param(trino_connection):
Expand Down
156 changes: 156 additions & 0 deletions tests/integration/test_types_integration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import math
import datetime
from datetime import timedelta

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Full datetime module already imported above

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

import pytest
import pytz
from decimal import Decimal
import trino

Expand Down Expand Up @@ -200,6 +203,159 @@ def test_digest(trino_connection):
.execute()


def test_date(trino_connection):
SqlTest(trino_connection) \
.add_field(sql="CAST(null AS DATE)", python=None) \
.add_field(sql="DATE '2001-08-22'", python=datetime.date(2001, 8, 22)) \
.add_field(sql="DATE '0001-01-01'", python=datetime.date(1, 1, 1)) \
.add_field(sql="DATE '1582-10-04'", python=datetime.date(1582, 10, 4)) \
.add_field(sql="DATE '1582-10-05'", python=datetime.date(1582, 10, 5)) \
.add_field(sql="DATE '1582-10-14'", python=datetime.date(1582, 10, 14)) \
.execute()


def test_time(trino_connection):
time_0 = datetime.time(1, 23, 45)
time_3 = datetime.time(1, 23, 45, 123000)
time_6 = datetime.time(1, 23, 45, 123456)
Comment thread
hashhar marked this conversation as resolved.
Outdated
time_round = datetime.time(1, 23, 45, 123457)

SqlTest(trino_connection) \
.add_field(sql="CAST(null AS TIME)", python=None) \
.add_field(sql="CAST(null AS TIME(0))", python=None) \
.add_field(sql="CAST(null AS TIME(3))", python=None) \
.add_field(sql="CAST(null AS TIME(6))", python=None) \
.add_field(sql="CAST(null AS TIME(9))", python=None) \
.add_field(sql="CAST(null AS TIME(12))", python=None) \
.add_field(sql="CAST('01:23:45' AS TIME(0))", python=time_0) \
.add_field(sql="TIME '01:23:45.123'", python=time_3) \
.add_field(sql="CAST('01:23:45.123' AS TIME(3))", python=time_3) \
.add_field(sql="CAST('01:23:45.123456' AS TIME(6))", python=time_6) \
.add_field(sql="CAST('01:23:45.123456789' AS TIME(9))", python=time_round) \
.add_field(sql="CAST('01:23:45.123456789123' AS TIME(12))", python=time_round) \
.execute()


def test_time_with_timezone(trino_connection):
query_time_with_timezone(trino_connection, '-08:00')
query_time_with_timezone(trino_connection, '+08:00')
query_time_with_timezone(trino_connection, '+05:30')
Comment thread
hashhar marked this conversation as resolved.


def query_time_with_timezone(trino_connection, tz_str):
tz = datetime.datetime.strptime('+00:00', "%z").tzinfo

hours_shift = int(tz_str[:3])
minutes_shift = int(tz_str[4:])
delta = timedelta(hours=hours_shift, minutes=minutes_shift)

time_0 = (datetime.datetime(2, 1, 1, 11, 23, 45, 0) - delta).time().replace(tzinfo=tz)
time_3 = (datetime.datetime(2, 1, 1, 11, 23, 45, 123000) - delta).time().replace(tzinfo=tz)
time_6 = (datetime.datetime(2, 1, 1, 11, 23, 45, 123456) - delta).time().replace(tzinfo=tz)
time_round = (datetime.datetime(2, 1, 1, 11, 23, 45, 123457) - delta).time().replace(tzinfo=tz)

SqlTest(trino_connection) \
.add_field(sql="CAST(null AS TIME WITH TIME ZONE)", python=None) \
.add_field(sql="CAST(null AS TIME(0) WITH TIME ZONE)", python=None) \
.add_field(sql="CAST(null AS TIME(3) WITH TIME ZONE)", python=None) \
.add_field(sql="CAST(null AS TIME(6) WITH TIME ZONE)", python=None) \
.add_field(sql="CAST(null AS TIME(9) WITH TIME ZONE)", python=None) \
.add_field(sql="CAST(null AS TIME(12) WITH TIME ZONE)", python=None) \
.add_field(sql="CAST('11:23:45 %s' AS TIME(0) WITH TIME ZONE)" % (tz_str), python=time_0) \
.add_field(sql="TIME '11:23:45.123 %s'" % (tz_str), python=time_3) \
.add_field(sql="CAST('11:23:45.123 %s' AS TIME(3) WITH TIME ZONE)" % (tz_str), python=time_3) \
.add_field(sql="CAST('11:23:45.123456 %s' AS TIME(6) WITH TIME ZONE)" % (tz_str), python=time_6) \
.add_field(sql="CAST('11:23:45.123456789 %s' AS TIME(9) WITH TIME ZONE)" % (tz_str), python=time_round) \
.add_field(sql="CAST('11:23:45.123456789123 %s' AS TIME(12) WITH TIME ZONE)" % (tz_str), python=time_round) \
.execute()


def test_timestamp(trino_connection):
timestamp_0 = datetime.datetime(2001, 8, 22, 1, 23, 45, 0)
timestamp_3 = datetime.datetime(2001, 8, 22, 1, 23, 45, 123000)
timestamp_6 = datetime.datetime(2001, 8, 22, 1, 23, 45, 123456)
Comment thread
lpoulain marked this conversation as resolved.
Outdated
timestamp_round = datetime.datetime(2001, 8, 22, 1, 23, 45, 123457)
timestamp_ce = datetime.datetime(1, 1, 1, 1, 23, 45, 123000)
timestamp_julian = datetime.datetime(1582, 10, 4, 1, 23, 45, 123000)
timestamp_during_switch = datetime.datetime(1582, 10, 5, 1, 23, 45, 123000)
timestamp_gregorian = datetime.datetime(1582, 10, 14, 1, 23, 45, 123000)

SqlTest(trino_connection) \
.add_field(sql="CAST(null AS TIMESTAMP)", python=None) \
.add_field(sql="CAST(null AS TIMESTAMP(0))", python=None) \
.add_field(sql="CAST(null AS TIMESTAMP(3))", python=None) \
.add_field(sql="CAST(null AS TIMESTAMP(6))", python=None) \
.add_field(sql="CAST(null AS TIMESTAMP(9))", python=None) \
.add_field(sql="CAST(null AS TIMESTAMP(12))", python=None) \
.add_field(sql="CAST('2001-08-22 01:23:45' AS TIMESTAMP(0))", python=timestamp_0) \
.add_field(sql="TIMESTAMP '2001-08-22 01:23:45.123'", python=timestamp_3) \
.add_field(sql="TIMESTAMP '0001-01-01 01:23:45.123'", python=timestamp_ce) \
.add_field(sql="TIMESTAMP '1582-10-04 01:23:45.123'", python=timestamp_julian) \
.add_field(sql="TIMESTAMP '1582-10-05 01:23:45.123'", python=timestamp_during_switch) \
.add_field(sql="TIMESTAMP '1582-10-14 01:23:45.123'", python=timestamp_gregorian) \
.add_field(sql="CAST('2001-08-22 01:23:45.123' AS TIMESTAMP(3))", python=timestamp_3) \
.add_field(sql="CAST('2001-08-22 01:23:45.123456' AS TIMESTAMP(6))", python=timestamp_6) \
.add_field(sql="CAST('2001-08-22 01:23:45.123456111' AS TIMESTAMP(9))", python=timestamp_6) \
.add_field(sql="CAST('2001-08-22 01:23:45.123456789' AS TIMESTAMP(9))", python=timestamp_round) \
.add_field(sql="CAST('2001-08-22 01:23:45.123456111111' AS TIMESTAMP(12))", python=timestamp_6) \
.add_field(sql="CAST('2001-08-22 01:23:45.123456789123' AS TIMESTAMP(12))", python=timestamp_round) \
.execute()


def test_timestamp_with_timezone(trino_connection):

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Squash "Refactor some test function names for consistency" since it changes something introduced in the first commit itself.

query_timestamp_with_timezone(trino_connection, '-08:00')
query_timestamp_with_timezone(trino_connection, '+08:00')
query_timestamp_with_timezone(trino_connection, '+05:30')
query_timestamp_with_timezone(trino_connection, 'US/Eastern')
query_timestamp_with_timezone(trino_connection, 'Asia/Kolkata')
query_timestamp_with_timezone(trino_connection, 'GMT')


def query_timestamp_with_timezone(trino_connection, tz_str):
if tz_str.startswith('+') or tz_str.startswith('-'):
hours_shift = int(tz_str[:3])
minutes_shift = int(tz_str[4:])
else:
tz = pytz.timezone(tz_str)
offset = tz.utcoffset(datetime.datetime.now())
offset_seconds = offset.total_seconds()
hours_shift = int(offset_seconds / 3600)
minutes_shift = offset_seconds % 3600 / 60

tz = pytz.timezone('Etc/GMT')
delta = timedelta(hours=hours_shift, minutes=minutes_shift)

timestamp_0 = tz.localize(datetime.datetime(2001, 8, 22, 11, 23, 45, 0)) - delta
timestamp_3 = tz.localize(datetime.datetime(2001, 8, 22, 11, 23, 45, 123000)) - delta
timestamp_6 = tz.localize(datetime.datetime(2001, 8, 22, 11, 23, 45, 123456)) - delta
timestamp_round = tz.localize(datetime.datetime(2001, 8, 22, 11, 23, 45, 123457)) - delta

SqlTest(trino_connection) \
.add_field(sql="CAST(null AS TIMESTAMP WITH TIME ZONE)", python=None) \
.add_field(sql="CAST(null AS TIMESTAMP(0) WITH TIME ZONE)", python=None) \
.add_field(sql="CAST(null AS TIMESTAMP(3) WITH TIME ZONE)", python=None) \
.add_field(sql="CAST(null AS TIMESTAMP(6) WITH TIME ZONE)", python=None) \
.add_field(sql="CAST(null AS TIMESTAMP(9) WITH TIME ZONE)", python=None) \
.add_field(sql="CAST(null AS TIMESTAMP(12) WITH TIME ZONE)", python=None) \
.add_field(sql="CAST('2001-08-22 11:23:45 %s' AS TIMESTAMP(0) WITH TIME ZONE)" % (tz_str),
python=timestamp_0) \
.add_field(sql="TIMESTAMP '2001-08-22 11:23:45.123 %s'" % (tz_str),
python=timestamp_3) \
.add_field(sql="CAST('2001-08-22 11:23:45.123 %s' AS TIMESTAMP(3) WITH TIME ZONE)" % (tz_str),
python=timestamp_3) \
.add_field(sql="CAST('2001-08-22 11:23:45.123456 %s' AS TIMESTAMP(6) WITH TIME ZONE)" % (tz_str),
python=timestamp_6) \
.add_field(sql="CAST('2001-08-22 11:23:45.123456111 %s' AS TIMESTAMP(9) WITH TIME ZONE)" % (tz_str),
python=timestamp_6) \
.add_field(sql="CAST('2001-08-22 11:23:45.123456789 %s' AS TIMESTAMP(9) WITH TIME ZONE)" % (tz_str),
python=timestamp_round) \
.add_field(sql="CAST('2001-08-22 11:23:45.123456111111 %s' AS TIMESTAMP(12) WITH TIME ZONE)" % (tz_str),
python=timestamp_6) \
.add_field(sql="CAST('2001-08-22 11:23:45.123456789123 %s' AS TIMESTAMP(12) WITH TIME ZONE)" % (tz_str),
python=timestamp_round) \
.execute()


class SqlTest:
def __init__(self, trino_connection):
self.cur = trino_connection.cursor(experimental_python_types=True)
Expand Down
94 changes: 65 additions & 29 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,14 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
if self.cancelled:
raise exceptions.TrinoUserError("Query has been cancelled", self.query_id)

response = self._request.post(self._sql, additional_http_headers)
if self._experimental_python_types:
http_headers = {constants.HEADER_CLIENT_CAPABILITIES: 'PARAMETRIC_DATETIME'}

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another naive question - how come this doesn't require any other changes in how the client parses results from the server?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the current behavior, without setting this header?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The server just starts sending back the precision in the type signatures and the values have additional precision to match.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The client capabilities doesn't seem to be documented anywhere, but I found this comment in the server: https://github.com/trinodb/trino/blob/master/core/trino-main/src/main/java/io/trino/server/protocol/QueryResultRows.java#L355
So looks like it was rounded before in the server, and now it'll get truncated in the client, but still with higher precision (6 vs 3). So it's an improvement.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the other Trino clients I've seen (JDBC, ODBC, Go) are all passing X-Trino-Client-Capabilityes: PARAMETRIC_DATETIME to Trino when executing a query. The Python library is the only one that doesn't. Unfortunately, Python datetimes only support a 6-digit precision max.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nineinchnick as far as how come no other change is required on how the client parses the results from the server, it's because in my last PR I made sure that the code worked with either type of result.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trino-go-client doesn't pass that header either: trinodb/trino-go-client#13

if additional_http_headers:
http_headers.update(additional_http_headers)
else:
http_headers = additional_http_headers

response = self._request.post(self._sql, http_headers)
status = self._request.process(response)
self._info_uri = status.info_uri
self.query_id = status.id
Expand Down Expand Up @@ -907,66 +914,95 @@ def _double_map_func(self):
else float(val)

def _timestamp_map_func(self, column, col_type):
datetime_default_size = 20 # size of 'YYYY-MM-DD HH:MM:SS.' (the datetime string up to the milliseconds)
datetime_default_size = len('YYYY-MM-DD HH:MM:SS.')
pattern = "%Y-%m-%d %H:%M:%S"
ms_size, ms_to_trim = self._get_number_of_digits(column)
if ms_size > 0:
millis_length, millis_div = self._get_number_of_millis_digits(column)
if millis_length > 0:
pattern += ".%f"

dt_size = datetime_default_size + ms_size - ms_to_trim
dt_tz_offset = datetime_default_size + ms_size
timestamp_length = datetime_default_size
dt_tz_offset = datetime_default_size + millis_length
if 'with time zone' in col_type:

if ms_to_trim > 0:
if millis_div > 1:
return lambda val: \
[datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern + ' %z')
[datetime.strptime(val[:timestamp_length]
+ str(round(int(val[timestamp_length:dt_tz_offset]) / millis_div))
+ val[dt_tz_offset:], pattern + ' %z')
if tz.startswith('+') or tz.startswith('-')
else datetime.strptime(dt[:dt_size] + dt[dt_tz_offset:], pattern)
.replace(tzinfo=pytz.timezone(tz))
else pytz.timezone(tz).localize(datetime.strptime(dt[:timestamp_length]
+ str(round(int(val[timestamp_length:dt_tz_offset])
/ millis_div))
+ dt[dt_tz_offset:], pattern))
for dt, tz in [val.rsplit(' ', 1)]][0]
else:
return lambda val: [datetime.strptime(val, pattern + ' %z')
if tz.startswith('+') or tz.startswith('-')
else datetime.strptime(dt, pattern).replace(tzinfo=pytz.timezone(tz))
else pytz.timezone(tz).localize(datetime.strptime(dt, pattern))
for dt, tz in [val.rsplit(' ', 1)]][0]

if ms_to_trim > 0:
return lambda val: datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern)
if millis_div > 1:
return lambda val: datetime.strptime(val[:timestamp_length]
+ str(round(int(val[timestamp_length:dt_tz_offset]) / millis_div))
+ val[dt_tz_offset:], pattern)
else:
return lambda val: datetime.strptime(val, pattern)

def _time_map_func(self, column, col_type):
datetime_default_size = 9 # size of 'HH:MM:SS.'
Comment thread
lpoulain marked this conversation as resolved.
pattern = "%H:%M:%S"
ms_size, ms_to_trim = self._get_number_of_digits(column)
if ms_size > 0:
millis_length, millis_div = self._get_number_of_millis_digits(column)
if millis_length > 0:
pattern += ".%f"

time_size = 9 + ms_size - ms_to_trim
time_size = datetime_default_size + millis_length

if 'with time zone' in col_type:
return lambda val: self._get_time_with_timezome(val, time_size, pattern)
if millis_div > 1:
return lambda val: self._get_time_with_timezone_round_ms(val,
datetime_default_size,
millis_div,
pattern)
else:
return lambda val: self._get_time_with_timezone(val, time_size, pattern)
else:
if millis_div > 1:
return lambda val: datetime.strptime(val[:datetime_default_size]
+ str(round(int(val[datetime_default_size:]) / millis_div)),
pattern).time()
else:
return lambda val: datetime.strptime(val[:time_size], pattern).time()

def _get_time_with_timezone(self, value, time_size, pattern):
matches = re.match(r'^(?P<time>.*)(?P<sign>[\+\-])(?P<hours>\d{2}):(?P<minutes>\d{2})$', value)
assert matches is not None
assert len(matches.groups()) == 4
if matches.group('sign') == '-':
tz = -timedelta(hours=int(matches.group('hours')), minutes=int(matches.group('minutes')))
else:
return lambda val: datetime.strptime(val[:time_size], pattern).time()
tz = timedelta(hours=int(matches.group('hours')), minutes=int(matches.group('minutes')))
return datetime.strptime(matches.group('time')[:time_size], pattern).time().replace(tzinfo=timezone(tz))

def _get_time_with_timezome(self, value, time_size, pattern):
matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value)
def _get_time_with_timezone_round_ms(self, value, time_size, ms_div, pattern):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

arguments sounds quite cryptic. Could we name them better?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

matches = re.match(r'^(?P<time>.*)(?P<sign>[\+\-])(?P<hours>\d{2}):(?P<minutes>\d{2})$', value)
assert matches is not None
assert len(matches.groups()) == 4
if matches.group(2) == '-':
tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
if matches.group('sign') == '-':
tz = -timedelta(hours=int(matches.group('hours')), minutes=int(matches.group('minutes')))
else:
tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
return datetime.strptime(matches.group(1)[:time_size], pattern).time().replace(tzinfo=timezone(tz))
tz = timedelta(hours=int(matches.group('hours')), minutes=int(matches.group('minutes')))
time_str = matches.group('time')[:time_size]
millis_str = str(round(int(matches.group('time')[time_size:]) / ms_div))
return datetime.strptime(time_str + millis_str, pattern).time().replace(tzinfo=timezone(tz))

def _get_number_of_digits(self, column):
def _get_number_of_millis_digits(self, column):
Comment thread
lpoulain marked this conversation as resolved.
Outdated
args = column['arguments']
if len(args) == 0:
return 3, 0
return 3, 1
ms_size = column['arguments'][0]['value']
if ms_size == 0:
return -1, 0
return -1, 1
ms_to_trim = ms_size - min(ms_size, 6)
return ms_size, ms_to_trim
return ms_size, 10 ** ms_to_trim


class RowMapper:
Expand Down
2 changes: 2 additions & 0 deletions trino/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,5 @@

HEADER_SET_SCHEMA = "X-Trino-Set-Schema"
HEADER_SET_CATALOG = "X-Trino-Set-Catalog"

HEADER_CLIENT_CAPABILITIES = "X-Trino-Client-Capabilities"