Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tests/integration/test_dbapi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,25 @@ def test_null_date_with_time_zone(trino_connection):
assert rows[0][0] is None


@pytest.mark.parametrize(
"binary_input",
[
bytearray("a", "utf-8"),
bytearray("a", "ascii"),
bytearray(b'\x00\x00\x00\x00'),
bytearray(4),
bytearray([1, 2, 3]),
],
)
def test_binary_query_param(trino_connection, binary_input):
cur = trino_connection.cursor(experimental_python_types=True)

cur.execute("SELECT ?", params=(binary_input,))
rows = cur.fetchall()

assert rows[0][0] == binary_input


def test_array_query_param(trino_connection):
cur = trino_connection.cursor()

Expand Down
7 changes: 3 additions & 4 deletions tests/integration/test_types_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,9 @@ def test_char(trino_connection):

def test_varbinary(trino_connection):
SqlTest(trino_connection) \
.add_field(sql="X'65683F'", python='ZWg/') \
.add_field(sql="X''", python='') \
.add_field(sql="CAST('' AS VARBINARY)", python='') \
.add_field(sql="from_utf8(CAST('😂😂😂😂😂😂' AS VARBINARY))", python='😂😂😂😂😂😂') \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you removed it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because from_utf8 converts into VARCHAR so it is not relevant here.

.add_field(sql="X'65683F'", python=b'eh?') \
.add_field(sql="X''", python=b'') \
.add_field(sql="CAST('' AS VARBINARY)", python=b'') \
.add_field(sql="CAST(null AS VARBINARY)", python=None) \
.execute()

Expand Down
10 changes: 10 additions & 0 deletions trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from __future__ import annotations

import abc
import base64
import copy
import functools
import os
Expand Down Expand Up @@ -1053,6 +1054,13 @@ def map(self, value) -> Optional[datetime]:
).round_to(self.precision).to_python_type()


class BinaryValueMapper(ValueMapper[bytes]):
def map(self, value) -> Optional[bytes]:
if value is None:
return None
return base64.b64decode(value.encode("utf8"))


class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
def __init__(self, mapper: ValueMapper[Any]):
self.mapper = mapper
Expand Down Expand Up @@ -1138,6 +1146,8 @@ def _create_value_mapper(self, column) -> ValueMapper:
return TimeValueMapper(self._get_precision(column))
elif col_type == 'date':
return DateValueMapper()
elif col_type == 'varbinary':
return BinaryValueMapper()
else:
return NoOpValueMapper()

Expand Down
4 changes: 4 additions & 0 deletions trino/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Fetch methods returns rows as a list of lists on purpose to let the caller
decide to convert then to a list of tuples.
"""
import binascii
import datetime
import math
import uuid
Expand Down Expand Up @@ -398,6 +399,9 @@ def _format_prepared_param(self, param):
if isinstance(param, Decimal):
return "DECIMAL '%s'" % param

if isinstance(param, (bytes, bytearray)):
return "X'%s'" % binascii.hexlify(param).decode("utf-8")

raise trino.exceptions.NotSupportedError("Query parameter of type '%s' is not supported." % type(param))

def _deallocate_prepared_statement(self, statement_name: str) -> None:
Expand Down