diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index e660b586..3cbf5213 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -283,7 +283,7 @@ def test_datetime_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == params - assert cur.description[0][1] == "timestamp" + assert cur.description[0][1] == "timestamp(6)" def test_datetime_with_utc_time_zone_query_param(trino_connection): @@ -295,7 +295,7 @@ def test_datetime_with_utc_time_zone_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == params - assert cur.description[0][1] == "timestamp with time zone" + assert cur.description[0][1] == "timestamp(6) with time zone" def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection): @@ -309,19 +309,19 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == params - assert cur.description[0][1] == "timestamp with time zone" + assert cur.description[0][1] == "timestamp(6) with time zone" def test_datetime_with_named_time_zone_query_param(trino_connection): cur = trino_connection.cursor(experimental_python_types=True) - params = datetime(2020, 1, 1, 16, 43, 22, 320000, tzinfo=pytz.timezone('America/Los_Angeles')) + params = pytz.timezone('America/Los_Angeles').localize(datetime(2020, 1, 1, 16, 43, 22, 320000)) cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() assert rows[0][0] == params - assert cur.description[0][1] == "timestamp with time zone" + assert cur.description[0][1] == "timestamp(6) with time zone" def test_datetime_with_trailing_zeros(trino_connection): @@ -371,7 +371,7 @@ def test_doubled_datetimes(trino_connection): cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() - assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern')) + assert rows[0][0] == pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0)) cur = trino_connection.cursor(experimental_python_types=True) @@ -380,7 +380,7 @@ def test_doubled_datetimes(trino_connection): cur.execute("SELECT ?", params=(params,)) rows = cur.fetchall() - assert rows[0][0] == datetime(2002, 10, 27, 1, 30, 0, tzinfo=pytz.timezone('US/Eastern')) + assert rows[0][0] == pytz.timezone('US/Eastern').localize(datetime(2002, 10, 27, 1, 30, 0)) def test_date_query_param(trino_connection): diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index 5749a820..47e6bc5c 100644 --- a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -1,5 +1,7 @@ import math +from datetime import timedelta, datetime, date, time import pytest +import pytz from decimal import Decimal import trino @@ -200,6 +202,159 @@ def test_digest(trino_connection): .execute() +def test_date(trino_connection): + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS DATE)", python=None) \ + .add_field(sql="DATE '2001-08-22'", python=date(2001, 8, 22)) \ + .add_field(sql="DATE '0001-01-01'", python=date(1, 1, 1)) \ + .add_field(sql="DATE '1582-10-04'", python=date(1582, 10, 4)) \ + .add_field(sql="DATE '1582-10-05'", python=date(1582, 10, 5)) \ + .add_field(sql="DATE '1582-10-14'", python=date(1582, 10, 14)) \ + .execute() + + +def test_time(trino_connection): + time_0 = time(1, 23, 45) + time_3 = time(1, 23, 45, 123000) + time_6 = time(1, 23, 45, 123456) + time_round = time(1, 23, 45, 123457) + + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS TIME)", python=None) \ + .add_field(sql="CAST(null AS TIME(0))", python=None) \ + .add_field(sql="CAST(null AS TIME(3))", python=None) \ + .add_field(sql="CAST(null AS TIME(6))", python=None) \ + .add_field(sql="CAST(null AS TIME(9))", python=None) \ + .add_field(sql="CAST(null AS TIME(12))", python=None) \ + .add_field(sql="CAST('01:23:45' AS TIME(0))", python=time_0) \ + .add_field(sql="TIME '01:23:45.123'", python=time_3) \ + .add_field(sql="CAST('01:23:45.123' AS TIME(3))", python=time_3) \ + .add_field(sql="CAST('01:23:45.123456' AS TIME(6))", python=time_6) \ + .add_field(sql="CAST('01:23:45.123456789' AS TIME(9))", python=time_round) \ + .add_field(sql="CAST('01:23:45.123456789123' AS TIME(12))", python=time_round) \ + .execute() + + +def test_time_with_timezone(trino_connection): + query_time_with_timezone(trino_connection, '-08:00') + query_time_with_timezone(trino_connection, '+08:00') + query_time_with_timezone(trino_connection, '+05:30') + + +def query_time_with_timezone(trino_connection, tz_str): + tz = datetime.strptime('+00:00', "%z").tzinfo + + hours_shift = int(tz_str[:3]) + minutes_shift = int(tz_str[4:]) + delta = timedelta(hours=hours_shift, minutes=minutes_shift) + + time_0 = (datetime(2, 1, 1, 11, 23, 45, 0) - delta).time().replace(tzinfo=tz) + time_3 = (datetime(2, 1, 1, 11, 23, 45, 123000) - delta).time().replace(tzinfo=tz) + time_6 = (datetime(2, 1, 1, 11, 23, 45, 123456) - delta).time().replace(tzinfo=tz) + time_round = (datetime(2, 1, 1, 11, 23, 45, 123457) - delta).time().replace(tzinfo=tz) + + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS TIME WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST(null AS TIME(0) WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST(null AS TIME(3) WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST(null AS TIME(6) WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST(null AS TIME(9) WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST(null AS TIME(12) WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST('11:23:45 %s' AS TIME(0) WITH TIME ZONE)" % (tz_str), python=time_0) \ + .add_field(sql="TIME '11:23:45.123 %s'" % (tz_str), python=time_3) \ + .add_field(sql="CAST('11:23:45.123 %s' AS TIME(3) WITH TIME ZONE)" % (tz_str), python=time_3) \ + .add_field(sql="CAST('11:23:45.123456 %s' AS TIME(6) WITH TIME ZONE)" % (tz_str), python=time_6) \ + .add_field(sql="CAST('11:23:45.123456789 %s' AS TIME(9) WITH TIME ZONE)" % (tz_str), python=time_round) \ + .add_field(sql="CAST('11:23:45.123456789123 %s' AS TIME(12) WITH TIME ZONE)" % (tz_str), python=time_round) \ + .execute() + + +def test_timestamp(trino_connection): + timestamp_0 = datetime(2001, 8, 22, 1, 23, 45, 0) + timestamp_3 = datetime(2001, 8, 22, 1, 23, 45, 123000) + timestamp_6 = datetime(2001, 8, 22, 1, 23, 45, 123456) + timestamp_round = datetime(2001, 8, 22, 1, 23, 45, 123457) + timestamp_ce = datetime(1, 1, 1, 1, 23, 45, 123000) + timestamp_julian = datetime(1582, 10, 4, 1, 23, 45, 123000) + timestamp_during_switch = datetime(1582, 10, 5, 1, 23, 45, 123000) + timestamp_gregorian = datetime(1582, 10, 14, 1, 23, 45, 123000) + + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS TIMESTAMP)", python=None) \ + .add_field(sql="CAST(null AS TIMESTAMP(0))", python=None) \ + .add_field(sql="CAST(null AS TIMESTAMP(3))", python=None) \ + .add_field(sql="CAST(null AS TIMESTAMP(6))", python=None) \ + .add_field(sql="CAST(null AS TIMESTAMP(9))", python=None) \ + .add_field(sql="CAST(null AS TIMESTAMP(12))", python=None) \ + .add_field(sql="CAST('2001-08-22 01:23:45' AS TIMESTAMP(0))", python=timestamp_0) \ + .add_field(sql="TIMESTAMP '2001-08-22 01:23:45.123'", python=timestamp_3) \ + .add_field(sql="TIMESTAMP '0001-01-01 01:23:45.123'", python=timestamp_ce) \ + .add_field(sql="TIMESTAMP '1582-10-04 01:23:45.123'", python=timestamp_julian) \ + .add_field(sql="TIMESTAMP '1582-10-05 01:23:45.123'", python=timestamp_during_switch) \ + .add_field(sql="TIMESTAMP '1582-10-14 01:23:45.123'", python=timestamp_gregorian) \ + .add_field(sql="CAST('2001-08-22 01:23:45.123' AS TIMESTAMP(3))", python=timestamp_3) \ + .add_field(sql="CAST('2001-08-22 01:23:45.123456' AS TIMESTAMP(6))", python=timestamp_6) \ + .add_field(sql="CAST('2001-08-22 01:23:45.123456111' AS TIMESTAMP(9))", python=timestamp_6) \ + .add_field(sql="CAST('2001-08-22 01:23:45.123456789' AS TIMESTAMP(9))", python=timestamp_round) \ + .add_field(sql="CAST('2001-08-22 01:23:45.123456111111' AS TIMESTAMP(12))", python=timestamp_6) \ + .add_field(sql="CAST('2001-08-22 01:23:45.123456789123' AS TIMESTAMP(12))", python=timestamp_round) \ + .execute() + + +def test_timestamp_with_timezone(trino_connection): + query_timestamp_with_timezone(trino_connection, '-08:00') + query_timestamp_with_timezone(trino_connection, '+08:00') + query_timestamp_with_timezone(trino_connection, '+05:30') + query_timestamp_with_timezone(trino_connection, 'US/Eastern') + query_timestamp_with_timezone(trino_connection, 'Asia/Kolkata') + query_timestamp_with_timezone(trino_connection, 'GMT') + + +def query_timestamp_with_timezone(trino_connection, tz_str): + if tz_str.startswith('+') or tz_str.startswith('-'): + hours_shift = int(tz_str[:3]) + minutes_shift = int(tz_str[4:]) + else: + tz = pytz.timezone(tz_str) + offset = tz.utcoffset(datetime.now()) + offset_seconds = offset.total_seconds() + hours_shift = int(offset_seconds / 3600) + minutes_shift = offset_seconds % 3600 / 60 + + tz = pytz.timezone('Etc/GMT') + delta = timedelta(hours=hours_shift, minutes=minutes_shift) + + timestamp_0 = tz.localize(datetime(2001, 8, 22, 11, 23, 45, 0)) - delta + timestamp_3 = tz.localize(datetime(2001, 8, 22, 11, 23, 45, 123000)) - delta + timestamp_6 = tz.localize(datetime(2001, 8, 22, 11, 23, 45, 123456)) - delta + timestamp_round = tz.localize(datetime(2001, 8, 22, 11, 23, 45, 123457)) - delta + + SqlTest(trino_connection) \ + .add_field(sql="CAST(null AS TIMESTAMP WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST(null AS TIMESTAMP(0) WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST(null AS TIMESTAMP(3) WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST(null AS TIMESTAMP(6) WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST(null AS TIMESTAMP(9) WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST(null AS TIMESTAMP(12) WITH TIME ZONE)", python=None) \ + .add_field(sql="CAST('2001-08-22 11:23:45 %s' AS TIMESTAMP(0) WITH TIME ZONE)" % (tz_str), + python=timestamp_0) \ + .add_field(sql="TIMESTAMP '2001-08-22 11:23:45.123 %s'" % (tz_str), + python=timestamp_3) \ + .add_field(sql="CAST('2001-08-22 11:23:45.123 %s' AS TIMESTAMP(3) WITH TIME ZONE)" % (tz_str), + python=timestamp_3) \ + .add_field(sql="CAST('2001-08-22 11:23:45.123456 %s' AS TIMESTAMP(6) WITH TIME ZONE)" % (tz_str), + python=timestamp_6) \ + .add_field(sql="CAST('2001-08-22 11:23:45.123456111 %s' AS TIMESTAMP(9) WITH TIME ZONE)" % (tz_str), + python=timestamp_6) \ + .add_field(sql="CAST('2001-08-22 11:23:45.123456789 %s' AS TIMESTAMP(9) WITH TIME ZONE)" % (tz_str), + python=timestamp_round) \ + .add_field(sql="CAST('2001-08-22 11:23:45.123456111111 %s' AS TIMESTAMP(12) WITH TIME ZONE)" % (tz_str), + python=timestamp_6) \ + .add_field(sql="CAST('2001-08-22 11:23:45.123456789123 %s' AS TIMESTAMP(12) WITH TIME ZONE)" % (tz_str), + python=timestamp_round) \ + .execute() + + class SqlTest: def __init__(self, trino_connection): self.cur = trino_connection.cursor(experimental_python_types=True) diff --git a/trino/client.py b/trino/client.py index 0073b3e3..220e1b70 100644 --- a/trino/client.py +++ b/trino/client.py @@ -41,9 +41,10 @@ import threading import time import urllib.parse -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, time as tim from decimal import Decimal from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Callable import pytz import requests @@ -740,7 +741,14 @@ def execute(self, additional_http_headers=None) -> TrinoResult: if self.cancelled: raise exceptions.TrinoUserError("Query has been cancelled", self.query_id) - response = self._request.post(self._sql, additional_http_headers) + if self._experimental_python_types: + http_headers = {constants.HEADER_CLIENT_CAPABILITIES: 'PARAMETRIC_DATETIME'} + if additional_http_headers: + http_headers.update(additional_http_headers) + else: + http_headers = additional_http_headers + + response = self._request.post(self._sql, http_headers) status = self._request.process(response) self._info_uri = status.info_uri self.query_id = status.id @@ -858,7 +866,7 @@ class RowMapperFactory: """ no_op_row_mapper = NoOpRowMapper() - def create(self, columns, experimental_python_types): + def create(self, columns: List[Dict[str, Any]], experimental_python_types: bool) -> Any: assert columns is not None if experimental_python_types: @@ -879,9 +887,9 @@ def _col_func(self, column): elif col_type.startswith('double') or col_type.startswith('real'): return self._double_map_func() elif col_type.startswith('timestamp'): - return self._timestamp_map_func(column, col_type) + return TimestampValueMapperFactory().create(column, col_type) elif col_type.startswith('time'): - return self._time_map_func(column, col_type) + return TimeValueMapperFactory().create(column, col_type) elif col_type == 'date': return lambda val: datetime.strptime(val, '%Y-%m-%d').date() else: @@ -906,67 +914,143 @@ def _double_map_func(self): else NAN if val == 'NaN' \ else float(val) - def _timestamp_map_func(self, column, col_type): - datetime_default_size = 20 # size of 'YYYY-MM-DD HH:MM:SS.' (the datetime string up to the milliseconds) - pattern = "%Y-%m-%d %H:%M:%S" - ms_size, ms_to_trim = self._get_number_of_digits(column) - if ms_size > 0: - pattern += ".%f" - dt_size = datetime_default_size + ms_size - ms_to_trim - dt_tz_offset = datetime_default_size + ms_size - if 'with time zone' in col_type: +class AbstractTemporalValueMapperFactory: + def _get_number_of_millis_digits(self, column: Dict[str, Any]) -> int: + args = column['arguments'] + if len(args) == 0: + return 3 + ms_size = column['arguments'][0]['value'] + if ms_size == 0: + return -1 + return ms_size - if ms_to_trim > 0: - return lambda val: \ - [datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern + ' %z') - if tz.startswith('+') or tz.startswith('-') - else datetime.strptime(dt[:dt_size] + dt[dt_tz_offset:], pattern) - .replace(tzinfo=pytz.timezone(tz)) - for dt, tz in [val.rsplit(' ', 1)]][0] - else: - return lambda val: [datetime.strptime(val, pattern + ' %z') - if tz.startswith('+') or tz.startswith('-') - else datetime.strptime(dt, pattern).replace(tzinfo=pytz.timezone(tz)) - for dt, tz in [val.rsplit(' ', 1)]][0] + def _get_number_of_millis_digits_to_trim(self, column: Dict[str, Any], number_millis_digits: int) -> int: + args = column['arguments'] + if len(args) == 0: + return 1 + return (10 ** (number_millis_digits - min(number_millis_digits, 6))) - if ms_to_trim > 0: - return lambda val: datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern) - else: - return lambda val: datetime.strptime(val, pattern) - def _time_map_func(self, column, col_type): - pattern = "%H:%M:%S" - ms_size, ms_to_trim = self._get_number_of_digits(column) - if ms_size > 0: - pattern += ".%f" +class TimeValueMapperFactory(AbstractTemporalValueMapperFactory): + def create(self, column: Dict[str, Any], col_type: str) -> Callable[[Any], tim]: + datetime_default_size = 9 # size of 'HH:MM:SS.' + time_format = "%H:%M:%S" + millis_length = self._get_number_of_millis_digits(column) + millis_to_trim_div = self._get_number_of_millis_digits_to_trim(column, millis_length) + + if millis_length > 0: + time_format += ".%f" - time_size = 9 + ms_size - ms_to_trim + time_size = datetime_default_size + millis_length if 'with time zone' in col_type: - return lambda val: self._get_time_with_timezome(val, time_size, pattern) + if millis_to_trim_div > 1: + return self._map_time_timezone_trim_millis_digits(datetime_default_size, millis_to_trim_div, + time_format) + else: + return self._map_time_timezone(time_size, time_format) else: - return lambda val: datetime.strptime(val[:time_size], pattern).time() + if millis_to_trim_div > 1: + return self._map_time_trim_millis_digits(datetime_default_size, millis_to_trim_div, time_format) + else: + return self._map_time(time_size, time_format) + + def _map_time_timezone_trim_millis_digits(self, datetime_default_size: int, millis_to_trim_div: int, + time_format: str) -> Callable[[Any], tim]: + return lambda val: self._get_time_with_timezone_round_ms(val, + datetime_default_size, + millis_to_trim_div, + time_format) - def _get_time_with_timezome(self, value, time_size, pattern): - matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value) + def _map_time_timezone(self, time_size: int, time_format: str) -> Callable[[Any], tim]: + return lambda val: self._get_time_with_timezone(val, time_size, time_format) + + def _map_time_trim_millis_digits(self, datetime_default_size: int, millis_to_trim_div: int, + time_format: str) -> Callable[[Any], tim]: + return lambda val: datetime.strptime(val[:datetime_default_size] + + str(round(int(val[datetime_default_size:]) / millis_to_trim_div)), + time_format).time() + + def _map_time(self, time_size: int, time_format: str) -> Callable[[Any], tim]: + return lambda val: datetime.strptime(val[:time_size], time_format).time() + + def _get_time_with_timezone(self, value: str, time_size: int, pattern: str) -> tim: + matches = re.match(r'^(?P