diff --git a/tests/integration/test_dbapi_integration.py b/tests/integration/test_dbapi_integration.py index a533d671..760d800e 100644 --- a/tests/integration/test_dbapi_integration.py +++ b/tests/integration/test_dbapi_integration.py @@ -532,6 +532,25 @@ def test_null_date_with_time_zone(trino_connection): assert rows[0][0] is None +@pytest.mark.parametrize( + "binary_input", + [ + bytearray("a", "utf-8"), + bytearray("a", "ascii"), + bytearray(b'\x00\x00\x00\x00'), + bytearray(4), + bytearray([1, 2, 3]), + ], +) +def test_binary_query_param(trino_connection, binary_input): + cur = trino_connection.cursor(experimental_python_types=True) + + cur.execute("SELECT ?", params=(binary_input,)) + rows = cur.fetchall() + + assert rows[0][0] == binary_input + + def test_array_query_param(trino_connection): cur = trino_connection.cursor() diff --git a/tests/integration/test_types_integration.py b/tests/integration/test_types_integration.py index c5e38efa..37201a1d 100644 --- a/tests/integration/test_types_integration.py +++ b/tests/integration/test_types_integration.py @@ -121,10 +121,9 @@ def test_char(trino_connection): def test_varbinary(trino_connection): SqlTest(trino_connection) \ - .add_field(sql="X'65683F'", python='ZWg/') \ - .add_field(sql="X''", python='') \ - .add_field(sql="CAST('' AS VARBINARY)", python='') \ - .add_field(sql="from_utf8(CAST('😂😂😂😂😂😂' AS VARBINARY))", python='😂😂😂😂😂😂') \ + .add_field(sql="X'65683F'", python=b'eh?') \ + .add_field(sql="X''", python=b'') \ + .add_field(sql="CAST('' AS VARBINARY)", python=b'') \ .add_field(sql="CAST(null AS VARBINARY)", python=None) \ .execute() diff --git a/trino/client.py b/trino/client.py index f74539b0..8bd4baf7 100644 --- a/trino/client.py +++ b/trino/client.py @@ -35,6 +35,7 @@ from __future__ import annotations import abc +import base64 import copy import functools import os @@ -1053,6 +1054,13 @@ def map(self, value) -> Optional[datetime]: ).round_to(self.precision).to_python_type() +class BinaryValueMapper(ValueMapper[bytes]): + def map(self, value) -> Optional[bytes]: + if value is None: + return None + return base64.b64decode(value.encode("utf8")) + + class ArrayValueMapper(ValueMapper[List[Optional[Any]]]): def __init__(self, mapper: ValueMapper[Any]): self.mapper = mapper @@ -1138,6 +1146,8 @@ def _create_value_mapper(self, column) -> ValueMapper: return TimeValueMapper(self._get_precision(column)) elif col_type == 'date': return DateValueMapper() + elif col_type == 'varbinary': + return BinaryValueMapper() else: return NoOpValueMapper() diff --git a/trino/dbapi.py b/trino/dbapi.py index 7a734df5..66ba2e1f 100644 --- a/trino/dbapi.py +++ b/trino/dbapi.py @@ -17,6 +17,7 @@ Fetch methods returns rows as a list of lists on purpose to let the caller decide to convert then to a list of tuples. """ +import binascii import datetime import math import uuid @@ -398,6 +399,9 @@ def _format_prepared_param(self, param): if isinstance(param, Decimal): return "DECIMAL '%s'" % param + if isinstance(param, (bytes, bytearray)): + return "X'%s'" % binascii.hexlify(param).decode("utf-8") + raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param)) def _deallocate_prepared_statement(self, statement_name: str) -> None: