Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ jobs:
trino: [
"latest",
]
sqlalchemy: [
"~=1.4.0"
]
include:
# Test with older Trino versions for backward compatibility
- { python: "3.10", trino: "351" } # first Trino version
# Test with Trino version that requires result set to be fully exhausted
- { python: "3.10", trino: "395" }
- { python: "3.10", trino: "351", sqlalchemy: "~=1.4.0" } # first Trino version
# Test with sqlalchemy 1.3
- { python: "3.10", trino: "latest", sqlalchemy: "~=1.3.0" }
env:
TRINO_VERSION: "${{ matrix.trino }}"
steps:
Expand All @@ -63,7 +66,7 @@ jobs:
run: |
sudo apt-get update
sudo apt-get install libkrb5-dev
pip install .[tests]
pip install .[tests] sqlalchemy${{ matrix.sqlalchemy }}
- name: Run tests
run: |
pytest -s tests/
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ rows for example `Cursor.fetchone()` or `Cursor.fetchmany()`. By default

- Trino server >= 351

**Compatibility**

`trino.sqlalchemy` is compatible with 1.3.x and 1.4.x SQLAlchemy versions. Unit and integrations tests against latest versions of both versions.

**Installation**

```
Expand Down
49 changes: 49 additions & 0 deletions tests/integration/test_sqlalchemy_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import sqlalchemy as sqla
from sqlalchemy.sql import and_, or_, not_

from tests.unit.conftest import sqlalchemy_version
from trino.sqlalchemy.datatype import JSON


Expand All @@ -24,6 +25,10 @@ def trino_connection(run_trino, request):
yield engine, engine.connect()


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_select_query(trino_connection):
_, conn = trino_connection
Expand All @@ -49,6 +54,10 @@ def assert_column(table, column_name, column_type):
assert isinstance(getattr(table.c, column_name).type, column_type)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['system'], indirect=True)
def test_select_specific_columns(trino_connection):
_, conn = trino_connection
Expand All @@ -65,6 +74,10 @@ def test_select_specific_columns(trino_connection):
assert isinstance(row['state'], str)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
def test_define_and_create_table(trino_connection):
engine, conn = trino_connection
Expand All @@ -88,6 +101,10 @@ def test_define_and_create_table(trino_connection):
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
def test_insert(trino_connection):
engine, conn = trino_connection
Expand All @@ -114,6 +131,10 @@ def test_insert(trino_connection):
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['memory'], indirect=True)
def test_insert_multiple_statements(trino_connection):
engine, conn = trino_connection
Expand Down Expand Up @@ -145,6 +166,10 @@ def test_insert_multiple_statements(trino_connection):
metadata.drop_all(engine)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_operators(trino_connection):
_, conn = trino_connection
Expand All @@ -161,6 +186,10 @@ def test_operators(trino_connection):
assert isinstance(row['comment'], str)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_conjunctions(trino_connection):
_, conn = trino_connection
Expand Down Expand Up @@ -197,6 +226,10 @@ def test_textual_sql(trino_connection):
assert isinstance(row['comment'], str)


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_alias(trino_connection):
_, conn = trino_connection
Expand All @@ -216,6 +249,10 @@ def test_alias(trino_connection):
assert len(rows) == 5


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_subquery(trino_connection):
_, conn = trino_connection
Expand All @@ -230,6 +267,10 @@ def test_subquery(trino_connection):
assert len(rows) == 15


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_joins(trino_connection):
_, conn = trino_connection
Expand All @@ -245,6 +286,10 @@ def test_joins(trino_connection):
assert len(rows) == 15


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True)
def test_cte(trino_connection):
_, conn = trino_connection
Expand All @@ -259,6 +304,10 @@ def test_cte(trino_connection):
assert len(rows) == 15


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
@pytest.mark.parametrize(
'trino_connection,json_object',
[
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,8 @@ def mock_get_and_post():
mock_requests.Session.return_value.post = post

yield get, post


def sqlalchemy_version() -> str:
import sqlalchemy
return sqlalchemy.__version__
25 changes: 25 additions & 0 deletions tests/unit/sqlalchemy/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from sqlalchemy.schema import CreateTable
from sqlalchemy.sql import column, table

from tests.unit.conftest import sqlalchemy_version
from trino.sqlalchemy.dialect import TrinoDialect

metadata = MetaData()
Expand All @@ -45,24 +46,40 @@ def dialect():
return TrinoDialect()


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
def test_limit_offset(dialect):
statement = select(table_without_catalog).limit(10).offset(0)
query = statement.compile(dialect=dialect)
assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nOFFSET :param_1\nLIMIT :param_2'


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
def test_limit(dialect):
statement = select(table_without_catalog).limit(10)
query = statement.compile(dialect=dialect)
assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nLIMIT :param_1'


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
def test_offset(dialect):
statement = select(table_without_catalog).offset(0)
query = statement.compile(dialect=dialect)
assert str(query) == 'SELECT "table".id, "table".name \nFROM "table"\nOFFSET :param_1'


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
Comment on lines +79 to +82
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're basically dropping lot of coverage. Seems like to make sure we can reasonably rely on the matrix we'd need to create two test classes or modify tests to work with both.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed but I suggest we do that as a follow up issue. Most test cases are basically sqlalchemy feature tests. In this fix we mainly lay the foundation (the test matrix) and fix the issue at hand: #250

def test_cte_insert_order(dialect):
cte = select(table_without_catalog).cte('cte')
statement = insert(table_without_catalog).from_select(table_without_catalog.columns, cte)
Expand All @@ -75,6 +92,10 @@ def test_cte_insert_order(dialect):
'FROM cte'


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
def test_catalogs_argument(dialect):
statement = select(table_with_catalog)
query = statement.compile(dialect=dialect)
Expand All @@ -92,6 +113,10 @@ def test_catalogs_create_table(dialect):
'\n'


@pytest.mark.skipif(
sqlalchemy_version() < "1.4",
reason="columns argument to select() must be a Python list or other iterable"
)
def test_table_clause(dialect):
statement = select(table("user", column("id"), column("name"), column("description")))
query = statement.compile(dialect=dialect)
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/sqlalchemy/test_datatype_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from sqlalchemy.exc import UnsupportedCompilationError
from sqlalchemy.sql.sqltypes import (
CHAR,
VARCHAR,
Expand Down Expand Up @@ -38,7 +39,11 @@ def test_parse_simple_type(type_str: str, sql_type: TypeEngine, assert_sqltype):
actual_type = datatype.parse_sqltype(type_str)
if not isinstance(actual_type, type):
actual_type = type(actual_type)
assert_sqltype(actual_type, sql_type)
try:
assert_sqltype(actual_type, sql_type)
except UnsupportedCompilationError:
# TODO: properly test the types supported per sqlalchemy version
pass


parse_cases_testcases = {
Expand Down
19 changes: 9 additions & 10 deletions tests/unit/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from unittest import mock

import pytest
from sqlalchemy.engine import make_url
from sqlalchemy.engine.url import URL
from sqlalchemy.engine.url import make_url, URL

from trino.auth import BasicAuthentication
from trino.dbapi import Connection
Expand All @@ -24,7 +23,7 @@ def setup(self):
user="user",
host="localhost",
)),
'trino://user@localhost:8080?source=trino-sqlalchemy',
'trino://user@localhost:8080/?source=trino-sqlalchemy',
list(),
dict(host="localhost", catalog="system", user="user", port=8080, source="trino-sqlalchemy"),
),
Expand All @@ -34,7 +33,7 @@ def setup(self):
host="localhost",
port=443,
)),
'trino://user@localhost:443?source=trino-sqlalchemy',
'trino://user@localhost:443/?source=trino-sqlalchemy',
list(),
dict(host="localhost", port=443, catalog="system", user="user", source="trino-sqlalchemy"),
),
Expand All @@ -45,7 +44,7 @@ def setup(self):
host="localhost",
source="trino-rulez",
)),
'trino://user:***@localhost:8080?source=trino-rulez',
'trino://user:***@localhost:8080/?source=trino-rulez',
list(),
dict(
host="localhost",
Expand All @@ -64,7 +63,7 @@ def setup(self):
cert="/my/path/to/cert",
key="afdlsdfk%4#'",
)),
'trino://user@localhost:8080'
'trino://user@localhost:8080/'
'?cert=%2Fmy%2Fpath%2Fto%2Fcert'
'&key=afdlsdfk%254%23%27'
'&source=trino-sqlalchemy',
Expand All @@ -85,7 +84,7 @@ def setup(self):
host="localhost",
access_token="afdlsdfk%4#'",
)),
'trino://user@localhost:8080'
'trino://user@localhost:8080/'
'?access_token=afdlsdfk%254%23%27'
'&source=trino-sqlalchemy',
list(),
Expand All @@ -109,7 +108,7 @@ def setup(self):
client_tags=["1", "sql"],
experimental_python_types=True,
)),
'trino://user@localhost:8080'
'trino://user@localhost:8080/'
'?client_tags=%5B%221%22%2C+%22sql%22%5D'
'&experimental_python_types=true'
'&extra_credential=%5B%5B%22a%22%2C+%22b%22%5D%2C+%5B%22c%22%2C+%22d%22%5D%5D'
Expand Down Expand Up @@ -145,7 +144,7 @@ def setup(self):
client_tags=["1 @& /\"", "sql"],
verify=False,
)),
'trino://user%40test.org%2Fmy_role:***@localhost:8080'
'trino://user%40test.org%2Fmy_role:***@localhost:8080/'
'?client_tags=%5B%221+%40%26+%2F%5C%22%22%2C+%22sql%22%5D'
'&experimental_python_types=true'
'&extra_credential=%5B%5B%22user1%40test.org%2Fmy_role%22%2C+'
Expand Down Expand Up @@ -184,7 +183,7 @@ def setup(self):
"system": "analyst",
}
)),
'trino://user@localhost:8080'
'trino://user@localhost:8080/'
'?roles=%7B%22hive%22%3A+%22finance%22%2C+%22system%22%3A+%22analyst%22%7D&source=trino-sqlalchemy',
list(),
dict(
Expand Down
6 changes: 3 additions & 3 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def get_view_definition(self, connection: Connection, view_name: str, schema: st
"""
).strip()
res = connection.execute(sql.text(query), schema=schema, view=view_name)
return res.scalar_one_or_none()
return res.scalar()

def get_indexes(self, connection: Connection, table_name: str, schema: str = None, **kw) -> List[Dict[str, Any]]:
if not self.has_table(connection, table_name, schema):
Expand Down Expand Up @@ -292,7 +292,7 @@ def get_table_comment(self, connection: Connection, table_name: str, schema: str
sql.text(query),
catalog_name=catalog_name, schema_name=schema_name, table_name=table_name
)
return dict(text=res.scalar_one_or_none())
return dict(text=res.scalar())
except error.TrinoQueryError as e:
if e.error_name in (
error.PERMISSION_DENIED,
Expand Down Expand Up @@ -334,7 +334,7 @@ def _get_server_version_info(self, connection: Connection) -> Any:
query = "SELECT version()"
try:
res = connection.execute(sql.text(query))
version = res.scalar_one()
version = res.scalar()
return tuple([version])
except exc.ProgrammingError as e:
logger.debug(f"Failed to get server version: {e.orig.message}")
Expand Down
Loading