Skip to content

Commit

Permalink
fix map serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
mayty committed Jan 19, 2025
1 parent e2327e8 commit 8ebf9b3
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
46 changes: 32 additions & 14 deletions clickhouse_driver/util/escape.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import date, datetime, time
from enum import Enum
from functools import wraps
from functools import wraps, partial
from uuid import UUID

from pytz import timezone
Expand Down Expand Up @@ -31,7 +31,7 @@ def escape_datetime(item, context):
else:
format = '%Y-%m-%d %H:%M:%S'

return "'%s'" % item.strftime(format)
return f"'{item.strftime(format)}'"


def maybe_enquote_for_server(f):
Expand All @@ -49,18 +49,18 @@ def wrapper(*args, **kwargs):
if is_str and not isinstance(item, (list, tuple)):
if rv[0] == "'":
if nested:
return "\\'%s\\'" % rv[1:-1]
return f"\\'{rv[1:-1]}\\'"
return rv
if nested:
return "\\'%s\\'" % rv
return "'%s'" % rv
return f"\\'{rv}\\'"
return f"'{rv}'"

if kwargs.get('for_iterable'):
return '%s' % rv
return str(rv)

if nested:
return "\\'%s\\'" % rv
return "'%s'" % rv
return f"\\'{rv!s}\\'"
return f"'{rv!s}'"

return wrapper

Expand All @@ -76,19 +76,19 @@ def escape_param(
return escape_datetime(item, context)

elif isinstance(item, date):
return "'%s'" % item.strftime('%Y-%m-%d')
return f"'{item.strftime('%Y-%m-%d')}'"

elif isinstance(item, time):
return "'%s'" % item.strftime('%H:%M:%S')
return f"'{item.strftime('%H:%M:%S')}'"

elif isinstance(item, str):
# We need double escaping for server-side parameters.
if for_server:
item = ''.join(escape_chars_map.get(c, c) for c in item)
return "'%s'" % ''.join(escape_chars_map.get(c, c) for c in item)
return f"'{''.join(escape_chars_map.get(c, c) for c in item)}'"

elif isinstance(item, list):
return "[%s]" % ', '.join(
serialized_array = ', '.join(
str(
escape_param(
x,
Expand All @@ -99,9 +99,10 @@ def escape_param(
)
) for x in item
)
return f'[{serialized_array}]'

elif isinstance(item, tuple):
return "(%s)" % ', '.join(
serialized_tuple = ', '.join(
str(
escape_param(
x,
Expand All @@ -113,11 +114,28 @@ def escape_param(
) for x in item
)

return f'({serialized_tuple})'

elif isinstance(item, dict):
serializer = partial(
escape_param,
context=context,
for_server=for_server,
for_iterable=True,
nested=True,
)

serialized_dict = ', '.join(
f'{serializer(key)!s}: {serializer(value)!s}'
for key, value in item.items()
)
return f'{{{serialized_dict}}}'

elif isinstance(item, Enum):
return escape_param(item.value, context, for_server=for_server)

elif isinstance(item, UUID):
return "'%s'" % str(item)
return f"'{item!s}'"

else:
return item
Expand Down
6 changes: 5 additions & 1 deletion tests/test_substitution.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,10 +418,14 @@ def test_nested_tuple(self):
'(Int32, Tuple(Float64, String), Array(Tuple(String, Int32)))'
)

def test_map(self):
def test_map__int(self):
x = {1: 2, 3: 4}
self._test_type_serialization(x, '^Map$', '(UInt32, UInt32)')

def test_map__string(self):
x = {'1': '34', '2': '45'}
self._test_type_serialization(x, '^Map$', '(String, String)')

@unittest.skip('Duplicate keys not supported')
def test_map__duplicate_keys(self):
pass
Expand Down

0 comments on commit 8ebf9b3

Please sign in to comment.