diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index 201921d2..a76dfb29 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -119,6 +119,7 @@ def test_none_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] is None + assert_cursor_description(cur, trino_type="unknown") def test_string_query_param(trino_connection): @@ -128,6 +129,7 @@ def test_string_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == "six'" + assert_cursor_description(cur, trino_type="varchar(4)", size=4) def test_execute_many(trino_connection): @@ -241,10 +243,11 @@ def test_legacy_primitive_types_with_connection_and_cursor( def test_decimal_query_param(trino_connection): cur = trino_connection.cursor() - cur.execute("SELECT ?", params=(Decimal('0.142857'),)) + cur.execute("SELECT ?", params=(Decimal('1112.142857'),)) rows = cur.fetchall() - assert rows[0][0] == Decimal('0.142857') + assert rows[0][0] == Decimal('1112.142857') + assert_cursor_description(cur, trino_type="decimal(10, 6)", precision=10, scale=6) def test_null_decimal(trino_connection): @@ -254,6 +257,7 @@ def test_null_decimal(trino_connection): rows = cur.fetchall() assert rows[0][0] is None + assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0) def test_biggest_decimal(trino_connection): @@ -264,6 +268,7 @@ def test_biggest_decimal(trino_connection): rows = cur.fetchall() assert rows[0][0] == params + assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0) def test_smallest_decimal(trino_connection): @@ -274,6 +279,7 @@ def test_smallest_decimal(trino_connection): rows = cur.fetchall() assert rows[0][0] == params + assert_cursor_description(cur, trino_type="decimal(38, 0)", precision=38, scale=0) def test_highest_precision_decimal(trino_connection): @@ -284,6 +290,7 @@ def test_highest_precision_decimal(trino_connection): rows = cur.fetchall() assert rows[0][0] == params + assert_cursor_description(cur, trino_type="decimal(38, 38)", precision=38, scale=38) def test_datetime_query_param(trino_connection): @@ -295,7 +302,7 @@ def test_datetime_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == params - assert cur.description[0][1] == "timestamp(6)" + assert_cursor_description(cur, trino_type="timestamp(6)", precision=6) def test_datetime_with_utc_time_zone_query_param(trino_connection): @@ -307,7 +314,7 @@ def test_datetime_with_utc_time_zone_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == params - assert cur.description[0][1] == "timestamp(6) with time zone" + assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6) def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection): @@ -321,7 +328,7 @@ def test_datetime_with_numeric_offset_time_zone_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == params - assert cur.description[0][1] == "timestamp(6) with time zone" + assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6) def test_datetime_with_named_time_zone_query_param(trino_connection): @@ -333,7 +340,7 @@ def test_datetime_with_named_time_zone_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == params - assert cur.description[0][1] == "timestamp(6) with time zone" + assert_cursor_description(cur, trino_type="timestamp(6) with time zone", precision=6) def test_datetime_with_trailing_zeros(trino_connection): @@ -343,6 +350,7 @@ def test_datetime_with_trailing_zeros(trino_connection): rows = cur.fetchall() assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321000", "%Y-%m-%d %H:%M:%S.%f") + assert_cursor_description(cur, trino_type="timestamp(6)", precision=6) def test_null_datetime_with_time_zone(trino_connection): @@ -352,6 +360,7 @@ def test_null_datetime_with_time_zone(trino_connection): rows = cur.fetchall() assert rows[0][0] is None + assert_cursor_description(cur, trino_type="timestamp(3) with time zone", precision=3) def test_datetime_with_time_zone_numeric_offset(trino_connection): @@ -361,6 +370,7 @@ def test_datetime_with_time_zone_numeric_offset(trino_connection): rows = cur.fetchall() assert rows[0][0] == datetime.strptime("2001-08-22 03:04:05.321 -08:00", "%Y-%m-%d %H:%M:%S.%f %z") + assert_cursor_description(cur, trino_type="timestamp(3) with time zone", precision=3) def test_datetimes_with_time_zone_in_dst_gap_query_param(trino_connection): @@ -404,6 +414,7 @@ def test_date_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == params + assert_cursor_description(cur, trino_type="date") def test_null_date(trino_connection): @@ -413,6 +424,7 @@ def test_null_date(trino_connection): rows = cur.fetchall() assert rows[0][0] is None + assert_cursor_description(cur, trino_type="date") def test_unsupported_python_dates(trino_connection): @@ -462,6 +474,16 @@ def test_supported_special_dates_query_param(trino_connection): assert rows[0][0] == params +def test_char(trino_connection): + cur = trino_connection.cursor() + + cur.execute("SELECT CHAR 'trino'") + rows = cur.fetchall() + + assert rows[0][0] == 'trino' + assert_cursor_description(cur, trino_type="char(5)", size=5) + + def test_time_query_param(trino_connection): cur = trino_connection.cursor() @@ -471,7 +493,7 @@ def test_time_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == params - assert cur.description[0][1] == "time(6)" + assert_cursor_description(cur, trino_type="time(6)", precision=6) def test_time_with_named_time_zone_query_param(trino_connection): @@ -501,7 +523,7 @@ def test_time(trino_connection): rows = cur.fetchall() assert rows[0][0] == time(1, 2, 3, 456000) - assert cur.description[0][1] == "time(3)" + assert_cursor_description(cur, trino_type="time(3)", precision=3) def test_null_time(trino_connection): @@ -511,6 +533,7 @@ def test_null_time(trino_connection): rows = cur.fetchall() assert rows[0][0] is None + assert_cursor_description(cur, trino_type="time(3)", precision=3) def test_time_with_time_zone_negative_offset(trino_connection): @@ -522,7 +545,7 @@ def test_time_with_time_zone_negative_offset(trino_connection): tz = timezone(-timedelta(hours=8, minutes=0)) assert rows[0][0] == time(1, 2, 3, 456000, tzinfo=tz) - assert cur.description[0][1] == "time(3) with time zone" + assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3) def test_time_with_time_zone_positive_offset(trino_connection): @@ -534,7 +557,7 @@ def test_time_with_time_zone_positive_offset(trino_connection): tz = timezone(timedelta(hours=8, minutes=0)) assert rows[0][0] == time(1, 2, 3, 456000, tzinfo=tz) - assert cur.description[0][1] == "time(3) with time zone" + assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3) def test_null_date_with_time_zone(trino_connection): @@ -544,6 +567,7 @@ def test_null_date_with_time_zone(trino_connection): rows = cur.fetchall() assert rows[0][0] is None + assert_cursor_description(cur, trino_type="time(3) with time zone", precision=3) @pytest.mark.parametrize( @@ -717,7 +741,7 @@ def test_float_query_param(trino_connection): cur.execute("SELECT ?", params=(1.1,)) rows = cur.fetchall() - assert cur.description[0][1] == "double" + assert_cursor_description(cur, trino_type="double") assert rows[0][0] == 1.1 @@ -726,7 +750,7 @@ def test_float_nan_query_param(trino_connection): cur.execute("SELECT ?", params=(float("nan"),)) rows = cur.fetchall() - assert cur.description[0][1] == "double" + assert_cursor_description(cur, trino_type="double") assert isinstance(rows[0][0], float) assert math.isnan(rows[0][0]) @@ -736,6 +760,7 @@ def test_float_inf_query_param(trino_connection): cur.execute("SELECT ?", params=(float("inf"),)) rows = cur.fetchall() + assert_cursor_description(cur, trino_type="double") assert rows[0][0] == float("inf") cur.execute("SELECT ?", params=(float("-inf"),)) @@ -750,13 +775,13 @@ def test_int_query_param(trino_connection): rows = cur.fetchall() assert rows[0][0] == 3 - assert cur.description[0][1] == "integer" + assert_cursor_description(cur, trino_type="integer") cur.execute("SELECT ?", params=(9223372036854775807,)) rows = cur.fetchall() assert rows[0][0] == 9223372036854775807 - assert cur.description[0][1] == "bigint" + assert_cursor_description(cur, trino_type="bigint") @pytest.mark.parametrize('params', [ @@ -1234,3 +1259,12 @@ def test_describe_table_query(run_trino): aliased=False, ) ] + + +def assert_cursor_description(cur, trino_type, size=None, precision=None, scale=None): + assert cur.description[0][1] == trino_type + assert cur.description[0][2] is None + assert cur.description[0][3] is size + assert cur.description[0][4] is precision + assert cur.description[0][5] is scale + assert cur.description[0][6] is None diff --git a/trino/constants.py b/trino/constants.py index 6813bd28..c9527a3b 100644 --- a/trino/constants.py +++ b/trino/constants.py @@ -53,3 +53,7 @@ HEADER_SET_CATALOG = "X-Trino-Set-Catalog" HEADER_CLIENT_CAPABILITIES = "X-Trino-Client-Capabilities" + +LENGTH_TYPES = ["char", "varchar"] +PRECISION_TYPES = ["time", "time with time zone", "timestamp", "timestamp with time zone", "decimal"] +SCALE_TYPES = ["decimal"] diff --git a/trino/dbapi.py b/trino/dbapi.py index ac8d2893..54744a38 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -28,6 +28,7 @@ import trino.exceptions import trino.logging from trino import constants +from trino.constants import LENGTH_TYPES, PRECISION_TYPES, SCALE_TYPES from trino.exceptions import ( DatabaseError, DataError, @@ -237,6 +238,31 @@ def from_row(cls, row: List[Any]): return cls(*row) +class ColumnDescription(NamedTuple): + name: str + type_code: int + display_size: int + internal_size: int + precision: int + scale: int + null_ok: bool + + @classmethod + def from_column(cls, column: Dict[str, Any]): + type_signature = column["typeSignature"] + raw_type = type_signature["rawType"] + arguments = type_signature["arguments"] + return cls( + column["name"], # name + column["type"], # type_code + None, # display_size + arguments[0]["value"] if raw_type in LENGTH_TYPES else None, # internal_size + arguments[0]["value"] if raw_type in PRECISION_TYPES else None, # precision + arguments[1]["value"] if raw_type in SCALE_TYPES else None, # scale + None # null_ok + ) + + class Cursor(object): """Database cursor. @@ -278,14 +304,13 @@ def update_type(self): return None @property - def description(self): + def description(self) -> List[ColumnDescription]: if self._query.columns is None: return None # [ (name, type_code, display_size, internal_size, precision, scale, null_ok) ] return [ - (col["name"], col["type"], None, None, None, None, None) - for col in self._query.columns + ColumnDescription.from_column(col) for col in self._query.columns ] @property