Skip to content

Commit

Permalink
feat(clickhouse): add array zip impl
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed May 17, 2023
1 parent 6c00cbc commit efba835
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 41 deletions.
11 changes: 11 additions & 0 deletions ibis/backends/base/sql/alchemy/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ class ArrayType(sat.UserDefinedType):
def __init__(self, value_type: sat.TypeEngine):
self.value_type = sat.to_instance(value_type)

def result_processor(self, dialect, coltype):
    """Return a converter for values of an ``Array(...)`` column.

    Parameters
    ----------
    dialect
        SQLAlchemy dialect, forwarded to the element type's processor.
    coltype
        The raw column type string reported by the database,
        e.g. ``"Array(Int64)"``.

    Returns
    -------
    A callable mapping a raw array value to a ``list`` of converted
    elements (``None`` passes through unchanged), or ``None`` when
    ``coltype`` is not an array type.  NOTE: the previous ``-> None``
    annotation was wrong -- this method usually returns a callable.
    """
    # Non-array column types need no post-processing.
    if not coltype.lower().startswith("array"):
        return None

    # Strip the "array(" prefix and trailing ")" to recover the element
    # type string, then ask the element type for its own processor,
    # falling back to the identity when it needs no conversion.
    inner_processor = (
        self.value_type.result_processor(dialect, coltype[len("array(") : -1])
        or toolz.identity
    )

    # NULL arrays stay None; otherwise convert each element.
    return lambda v: v if v is None else list(map(inner_processor, v))


@compiles(ArrayType, "default")
def compiles_array(element, compiler, **kw):
Expand Down
12 changes: 12 additions & 0 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import calendar
import contextlib
import functools
from functools import partial
from operator import add, mul, sub
Expand Down Expand Up @@ -1386,3 +1387,14 @@ def _array_remove(op, **kw):
@translate_val.register(ops.ArrayUnion)
def _array_union(op, **kw):
return translate_val(ops.ArrayDistinct(ops.ArrayConcat(op.left, op.right)), **kw)


@translate_val.register(ops.Zip)
def _array_zip(op: ops.Zip, **kw: Any) -> str:
    """Render an ibis ``Zip`` operation as a ClickHouse ``arrayZip`` call."""

    def render(node):
        # Translated values may be sqlglot expressions (which need explicit
        # rendering via .sql) or already-formatted strings.
        translated = translate_val(node, **kw)
        try:
            return translated.sql(dialect="clickhouse")
        except AttributeError:
            return translated

    rendered_args = ", ".join(render(node) for node in op.arg)
    return f"arrayZip({rendered_args})"
1 change: 1 addition & 0 deletions ibis/backends/datafusion/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class TestConf(BackendTest, RoundAwayFromZero):
native_bool = False
supports_structs = False
supports_json = False
supports_arrays = False

@staticmethod
def connect(data_directory: Path):
Expand Down
42 changes: 30 additions & 12 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,17 +704,35 @@ def test_unnest_struct(con):
tm.assert_series_equal(result, expected)


def test_zip6(con):
    # Zipping an array column with six other array arguments should still
    # execute and produce rows.
    table = con.tables.array_types
    zipped = table.x.zip(table.z, table.z, table.z, table.z, table.z, table.z)
    frame = zipped.execute()
    assert not frame.empty

# NOTE(review): this function body looks truncated in the diff view -- only
# the fixture binding is visible; the zip assertions presumably follow.
# Confirm against the full file before relying on this excerpt.
@pytest.mark.never(
    ["impala", "mssql", "mysql", "sqlite"],
    raises=com.OperationNotDefinedError,
    reason="no array support",
)
@pytest.mark.notyet(["bigquery"], raises=com.OperationNotDefinedError)
@pytest.mark.notimpl(
    [
        "dask",
        "datafusion",
        "druid",
        "duckdb",
        "oracle",
        "pandas",
        "polars",
        "postgres",
        "snowflake",
    ],
    raises=com.OperationNotDefinedError,
)
def test_zip(backend):
    t = backend.array_types

def test_zip2(con):
    """Zipping arrays should preserve the per-row element count and yield rows."""
    table = con.tables.array_types

    # Zipping a column with itself keeps the same number of elements per row.
    x_values = table.x.execute()
    self_zipped = table.x.zip(table.x).execute()
    assert len(x_values[0]) == len(self_zipped[0])

    # Zipping two distinct array columns still produces a non-empty result.
    pair_frame = table.x.zip(table.z).execute()
    assert not pair_frame.empty

    # Zipping with six additional arguments also preserves the element count.
    x_values = table.x.execute()
    many_zipped = table.x.zip(
        table.x, table.x, table.x, table.x, table.x, table.x
    ).execute()
    assert len(x_values[0]) == len(many_zipped[0])
16 changes: 1 addition & 15 deletions ibis/backends/trino/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,24 +85,10 @@ def parse(text: str, default_decimal_parameters=(18, 3)) -> dt.DataType:
parsy.seq(LPAREN.then(ty).skip(COMMA), ty.skip(RPAREN)).combine(dt.Map)
)

def make_tuple(pairs):
    """Build a struct type from parsed field entries.

    Two-element entries are taken as ``(name, type)`` pairs; anything else
    keeps the parsed value as the type and gets a positional ``f<i>`` name.
    """
    fields = []
    for index, entry in enumerate(pairs, start=1):
        if len(entry) == 2:
            label, entry = entry
        else:
            # Unnamed field: synthesize a positional name.
            label = f"f{index:d}"
        fields.append((label, entry))
    return dt.Struct.from_tuples(fields)

struct = (
spaceless_string("row")
.then(LPAREN)
.then(
parsy.seq(parsy.alt(parsy.seq(spaceless(FIELD), ty), ty))
.sep_by(COMMA)
.map(make_tuple)
)
.then(parsy.seq(spaceless(FIELD), ty).sep_by(COMMA).map(dt.Struct.from_tuples))
.skip(RPAREN)
)

Expand Down
25 changes: 11 additions & 14 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,20 +233,18 @@ def _first_last(t, op, *, offset: Literal[-1, 1]):
return sa.func.element_at(t._reduction(sa.func.array_agg, op), offset)


_MAX_ZIP_ARGUMENTS = 5


def _zip(t, op):
# more than one chunk means more than 5 arguments to zip, which trino
# doesn't support
#
# help trino out by reducing in chunks of 5 using zip_with
max_zip_arguments = 5
chunks = (
(len(chunk), sa.func.zip(*chunk) if len(chunk) > 1 else chunk[0])
for chunk in toolz.partition_all(_MAX_ZIP_ARGUMENTS, map(t.translate, op.arg))
for chunk in toolz.partition_all(max_zip_arguments, map(t.translate, op.arg))
)

# more than one chunk means more than 5 arguments to zip, which trino
# doesn't support
#
# help trino out by reducing chunks of 5
def reducer(left, right):
def combine_zipped(left, right):
left_n, left_chunk = left
lhs = (
", ".join(f"x[{i:d}]" for i in range(1, left_n + 1)) if left_n > 1 else "x"
Expand All @@ -259,13 +257,12 @@ def reducer(left, right):
else "y"
)

row_str = f"ROW({lhs}, {rhs})"
chunk = sa.func.zip_with(
left_chunk, right_chunk, sa.literal_column(f"(x, y) -> {row_str}")
zipped_chunk = sa.func.zip_with(
left_chunk, right_chunk, sa.literal_column(f"(x, y) -> ROW({lhs}, {rhs})")
)
return left_n + right_n, chunk
return left_n + right_n, zipped_chunk

all_n, chunk = reduce(reducer, chunks)
all_n, chunk = reduce(combine_zipped, chunks)

dtype = op.output_dtype

Expand Down

0 comments on commit efba835

Please sign in to comment.