diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 3a1d231c..362ced94 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -295,3 +295,27 @@ def test_json_column(trino_connection, json_object): assert rows[0] == (1, json_object) finally: metadata.drop_all(engine) + + +@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True) +def test_get_table_comment(trino_connection): + engine, conn = trino_connection + + if not engine.dialect.has_schema(engine, "test"): + engine.execute(sqla.schema.CreateSchema("test")) + metadata = sqla.MetaData() + + try: + sqla.Table( + 'table_with_id', + metadata, + sqla.Column('id', sqla.Integer), + schema="test", + # comment="This is a comment" TODO: Support comment creation through sqlalchemy api + ) + metadata.create_all(engine) + insp = sqla.inspect(engine) + actual = insp.get_table_comment(table_name='table_with_id', schema="test") + assert actual['text'] is None + finally: + metadata.drop_all(engine) diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 0bd7938b..1c19fdf9 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -11,10 +11,11 @@ # limitations under the License. import json from textwrap import dedent -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union from urllib.parse import unquote_plus from sqlalchemy import exc, sql +from sqlalchemy.engine import Engine from sqlalchemy.engine.base import Connection from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext from sqlalchemy.engine.url import URL @@ -340,12 +341,17 @@ def _get_server_version_info(self, connection: Connection) -> Any: logger.debug(f"Failed to get server version: {e.orig.message}") return None + def _raw_connection(self, connection: Union[Engine, Connection]) -> trino_dbapi.Connection: + if isinstance(connection, Engine): + return connection.raw_connection() + return connection.connection + def _get_default_catalog_name(self, connection: Connection) -> Optional[str]: - dbapi_connection: trino_dbapi.Connection = connection.connection + dbapi_connection: trino_dbapi.Connection = self._raw_connection(connection) return dbapi_connection.catalog def _get_default_schema_name(self, connection: Connection) -> Optional[str]: - dbapi_connection: trino_dbapi.Connection = connection.connection + dbapi_connection: trino_dbapi.Connection = self._raw_connection(connection) return dbapi_connection.schema def do_execute(