diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py
index afc6c713..1186c671 100644
--- a/tests/integration/test_dbapi_integration.py
+++ b/tests/integration/test_dbapi_integration.py
@@ -21,6 +21,7 @@
 import trino
 from tests.integration.conftest import trino_version
 from trino import constants
+from trino.dbapi import DescribeOutput
 from trino.exceptions import NotSupportedError, TrinoQueryError, TrinoUserError
 from trino.transaction import IsolationLevel
 
@@ -1136,3 +1137,69 @@ def test_connection_without_timezone(run_trino):
     assert session_tz == localzone or \
         (session_tz == "UTC" and localzone == "Etc/UTC") \
         # Workaround for difference between Trino timezone and tzlocal for UTC
+
+
+def test_describe(run_trino):
+    _, host, port = run_trino
+
+    trino_connection = trino.dbapi.Connection(
+        host=host, port=port, user="test", catalog="tpch",
+    )
+    cur = trino_connection.cursor()
+
+    result = cur.describe("SELECT 1, DECIMAL '1.0' as a")
+
+    assert result == [
+        DescribeOutput(name='_col0', catalog='', schema='', table='', type='integer', type_size=4, aliased=False),
+        DescribeOutput(name='a', catalog='', schema='', table='', type='decimal(2,1)', type_size=8, aliased=True)
+    ]
+
+
+def test_describe_table_query(run_trino):
+    _, host, port = run_trino
+
+    trino_connection = trino.dbapi.Connection(
+        host=host, port=port, user="test", catalog="tpch",
+    )
+    cur = trino_connection.cursor()
+
+    result = cur.describe("SELECT * from tpch.tiny.nation")
+
+    assert result == [
+        DescribeOutput(
+            name='nationkey',
+            catalog='tpch',
+            schema='tiny',
+            table='nation',
+            type='bigint',
+            type_size=8,
+            aliased=False,
+        ),
+        DescribeOutput(
+            name='name',
+            catalog='tpch',
+            schema='tiny',
+            table='nation',
+            type='varchar(25)',
+            type_size=0,
+            aliased=False,
+        ),
+        DescribeOutput(
+            name='regionkey',
+            catalog='tpch',
+            schema='tiny',
+            table='nation',
+            type='bigint',
+            type_size=8,
+            aliased=False,
+        ),
+        DescribeOutput(
+            name='comment',
+            catalog='tpch',
+            schema='tiny',
+            table='nation',
+            type='varchar(152)',
+            type_size=0,
+            aliased=False,
+        )
+    ]
diff --git a/trino/dbapi.py b/trino/dbapi.py
index f50f5291..1d75235a 100644
--- a/trino/dbapi.py
+++ b/trino/dbapi.py
@@ -21,7 +21,7 @@
 import math
 import uuid
 from decimal import Decimal
-from typing import Any, Dict, List, Optional  # NOQA for mypy types
+from typing import Any, Dict, List, NamedTuple, Optional  # NOQA for mypy types
 
 import trino.client
 import trino.exceptions
@@ -222,6 +222,20 @@ def cursor(self, experimental_python_types: bool = None):
         )
 
 
+class DescribeOutput(NamedTuple):
+    name: str
+    catalog: str
+    schema: str
+    table: str
+    type: str
+    type_size: int
+    aliased: bool
+
+    @classmethod
+    def from_row(cls, row: List[Any]):
+        return cls(*row)
+
+
 class Cursor(object):
     """Database cursor.
 
@@ -519,6 +533,28 @@ def fetchmany(self, size=None) -> List[List[Any]]:
 
         return result
 
+    def describe(self, sql: str) -> List[DescribeOutput]:
+        """
+        List the output columns of a SQL statement, including the column name (or alias), catalog, schema, table, type,
+        type size in bytes, and a boolean indicating if the column is aliased.
+
+        :param sql: SQL statement
+        """
+        statement_name = self._generate_unique_statement_name()
+        self._prepare_statement(sql, statement_name)
+        try:
+            sql = f"DESCRIBE OUTPUT {statement_name}"
+            self._query = trino.client.TrinoQuery(
+                self._request,
+                sql=sql,
+                experimental_python_types=self._experimental_pyton_types,
+            )
+            result = self._query.execute()
+        finally:
+            self._deallocate_prepared_statement(statement_name)
+
+        return list(map(lambda x: DescribeOutput.from_row(x), result))
+
     def genall(self):
         return self._query.result
 