Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,3 +295,27 @@ def test_json_column(trino_connection, json_object):
assert rows[0] == (1, json_object)
finally:
metadata.drop_all(engine)


@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
def test_get_table_comment(trino_connection):
engine, conn = trino_connection

if not engine.dialect.has_schema(engine, "test"):
engine.execute(sqla.schema.CreateSchema("test"))
metadata = sqla.MetaData()

try:
sqla.Table(
'table_with_id',
metadata,
sqla.Column('id', sqla.Integer),
schema="test",
# comment="This is a comment" TODO: Support comment creation through sqlalchemy api
)
metadata.create_all(engine)
insp = sqla.inspect(engine)
actual = insp.get_table_comment(table_name='table_with_id', schema="test")
assert actual['text'] is None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have a test for some actual comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great, but it seems to not work to add a comment through sqlalchemy. Let's create a bug for that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create an issue + add a TODO. Goal is to make sure if someone other than you works on that issue then they can easily find places in code which need updating otherwise they may end up adding duplicated tests.

finally:
metadata.drop_all(engine)
12 changes: 9 additions & 3 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# limitations under the License.
import json
from textwrap import dedent
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from urllib.parse import unquote_plus

from sqlalchemy import exc, sql
from sqlalchemy.engine import Engine
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
from sqlalchemy.engine.url import URL
Expand Down Expand Up @@ -340,12 +341,17 @@ def _get_server_version_info(self, connection: Connection) -> Any:
logger.debug(f"Failed to get server version: {e.orig.message}")
return None

def _raw_connection(self, connection: Union[Engine, Connection]) -> trino_dbapi.Connection:
if isinstance(connection, Engine):
return connection.raw_connection()
return connection.connection

def _get_default_catalog_name(self, connection: Connection) -> Optional[str]:
dbapi_connection: trino_dbapi.Connection = connection.connection
dbapi_connection: trino_dbapi.Connection = self._raw_connection(connection)
return dbapi_connection.catalog

def _get_default_schema_name(self, connection: Connection) -> Optional[str]:
dbapi_connection: trino_dbapi.Connection = connection.connection
dbapi_connection: trino_dbapi.Connection = self._raw_connection(connection)
return dbapi_connection.schema

def do_execute(
Expand Down