diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 56a66ae0..033cb61f 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -9,7 +9,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License +import math import uuid +from decimal import Decimal import pytest import sqlalchemy as sqla @@ -17,7 +19,7 @@ from tests.integration.conftest import trino_version from tests.unit.conftest import sqlalchemy_version -from trino.sqlalchemy.datatype import JSON +from trino.sqlalchemy.datatype import JSON, MAP @pytest.fixture @@ -476,6 +478,56 @@ def test_json_column_operations(trino_connection): metadata.drop_all(engine) +@pytest.mark.skipif( + sqlalchemy_version() < "1.4", + reason="columns argument to select() must be a Python list or other iterable" +) +@pytest.mark.parametrize( + 'trino_connection,map_object,sqla_type', + [ + ('memory', None, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.Integer)), + ('memory', {}, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.Integer)), + ('memory', {True: False, False: True}, MAP(sqla.sql.sqltypes.Boolean, sqla.sql.sqltypes.Boolean)), + ('memory', {1: 1, 2: None}, MAP(sqla.sql.sqltypes.Integer, sqla.sql.sqltypes.Integer)), + ('memory', {1.4: 1.4, math.inf: math.inf}, MAP(sqla.sql.sqltypes.Float, sqla.sql.sqltypes.Float)), + ('memory', {1.4: 1.4, math.inf: math.inf}, MAP(sqla.sql.sqltypes.REAL, sqla.sql.sqltypes.REAL)), + ('memory', + {Decimal("1.2"): Decimal("1.2")}, + MAP(sqla.sql.sqltypes.DECIMAL(2, 1), sqla.sql.sqltypes.DECIMAL(2, 1))), + ('memory', {"hello": "world"}, MAP(sqla.sql.sqltypes.String, sqla.sql.sqltypes.String)), + ('memory', {"a ": "a", "null": "n"}, MAP(sqla.sql.sqltypes.CHAR(4), sqla.sql.sqltypes.CHAR(1))), + ('memory', {b'': b'eh?', b'\x00': None}, MAP(sqla.sql.sqltypes.BINARY, sqla.sql.sqltypes.BINARY)), + ], + indirect=['trino_connection'] +) +def test_map_column(trino_connection, map_object, sqla_type): + engine, conn = trino_connection + + if not engine.dialect.has_schema(conn, "test"): + with engine.begin() as connection: + connection.execute(sqla.schema.CreateSchema("test")) + metadata = sqla.MetaData() + + try: + table_with_map = sqla.Table( + 'table_with_map', + metadata, + sqla.Column('id', sqla.Integer), + sqla.Column('map_column', sqla_type), + schema="test" + ) + metadata.create_all(engine) + ins = table_with_map.insert() + conn.execute(ins, {"id": 1, "map_column": map_object}) + query = sqla.select(table_with_map) + result = conn.execute(query) + rows = result.fetchall() + assert len(rows) == 1 + assert rows[0] == (1, map_object) + finally: + metadata.drop_all(engine) + + @pytest.mark.parametrize('trino_connection', ['system'], indirect=True) def test_get_catalog_names(trino_connection): engine, conn = trino_connection diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index 6612be64..5f83d984 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -246,6 +246,13 @@ def visit_TIME(self, type_, **kw): def visit_JSON(self, type_, **kw): return 'JSON' + def visit_MAP(self, type_, **kw): + # the key and value types themselves need to be processed otherwise sqltypes.MAP(Float, Float) will get + # rendered as MAP(FLOAT, FLOAT) instead of MAP(REAL, REAL) or MAP(DOUBLE, DOUBLE) + key_type = self.process(type_.key_type, **kw) + value_type = self.process(type_.value_type, **kw) + return f'MAP({key_type}, {value_type})' + class TrinoIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS