diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 66da0f5e..53fc0af2 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -15,6 +15,7 @@ import sqlalchemy as sqla from sqlalchemy.sql import and_, not_, or_ +from tests.integration.conftest import trino_version from tests.unit.conftest import sqlalchemy_version from trino.sqlalchemy.datatype import JSON @@ -497,3 +498,24 @@ def test_get_view_names_raises(trino_connection): with pytest.raises(sqla.exc.NoSuchTableError): sqla.inspect(engine).get_view_names(None) + + +@pytest.mark.parametrize('trino_connection', ['system'], indirect=True) +@pytest.mark.skipif(trino_version() == '351', reason="version() not supported in older Trino versions") +def test_version_is_lazy(trino_connection): + _, conn = trino_connection + result = conn.execute(sqla.text("SELECT 1")) + result.fetchall() + num_queries = _num_queries_containing_string(conn, "SELECT version()") + assert num_queries == 0 + version_info = conn.dialect.server_version_info + assert isinstance(version_info, tuple) + num_queries = _num_queries_containing_string(conn, "SELECT version()") + assert num_queries == 1 + + +def _num_queries_containing_string(connection, query_string): + statement = sqla.text("select query from system.runtime.queries order by query_id desc offset 1 limit 1") + result = connection.execute(statement) + rows = result.fetchall() + return len(list(filter(lambda rec: query_string in rec[0], rows))) diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index 7bc4603b..d5900119 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -336,15 +336,20 @@ def has_sequence(self, connection: Connection, sequence_name: str, schema: str = """Trino has no support for sequence. Returns False indicate that given sequence does not exists.""" return False - def _get_server_version_info(self, connection: Connection) -> Any: - query = "SELECT version()" - try: - res = connection.execute(sql.text(query)) - version = res.scalar() - return tuple([version]) - except exc.ProgrammingError as e: - logger.debug(f"Failed to get server version: {e.orig.message}") - return None + @classmethod + def _get_server_version_info(cls, connection: Connection) -> Any: + def get_server_version_info(_): + query = "SELECT version()" + try: + res = connection.execute(sql.text(query)) + version = res.scalar() + return tuple([version]) + except exc.ProgrammingError as e: + logger.debug(f"Failed to get server version: {e.orig.message}") + return None + + # Make server_version_info lazy in order to only make HTTP calls if user explicitly requests it. + cls.server_version_info = property(get_server_version_info, lambda instance, value: None) def _raw_connection(self, connection: Union[Engine, Connection]) -> trino_dbapi.Connection: if isinstance(connection, Engine):