Skip to content

Commit

Permalink
feat(clickhouse): implement builtin agg functions
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Sep 14, 2023
1 parent 3b27e23 commit eea679a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
16 changes: 14 additions & 2 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,18 @@ def _count_distinct_star(op: ops.CountDistinctStar, **kw: Any) -> str:

@translate_val.register(ops.ScalarUDF)
def _scalar_udf(op, **kw) -> str:
    """Compile a scalar UDF node into a SQL function-call string.

    Parameters
    ----------
    op
        The scalar UDF operation; every entry of ``op.args`` is compiled
        with ``translate_val``.
    **kw
        Forwarded unchanged to ``translate_val`` for each argument.

    Returns
    -------
    str
        ``<full name>(<arg>, <arg>, ...)``.
    """
    args = ", ".join(translate_val(arg, **kw) for arg in op.args)
    # NOTE(review): `__full_name__` is presumably the namespace-qualified
    # function name, so it is not rebuilt here from `__udf_namespace__` and
    # `__func_name__` -- confirm against the UDF node definition.
    return f"{op.__full_name__}({args})"


@translate_val.register(ops.AggUDF)
def _agg_udf(op, **kw) -> str:
    """Compile an aggregate UDF node into a SQL function-call string.

    Parameters
    ----------
    op
        The aggregate UDF operation. Its ``where`` argument, when set, is
        not passed positionally with the other arguments; it is appended as
        the last argument and the function name gets an ``If`` suffix.
    **kw
        Forwarded unchanged to ``translate_val`` for each argument.

    Returns
    -------
    str
        ``<name>(<arg>, ...)`` or ``<name>If(<arg>, ..., <where>)``.
    """
    # Every declared argument except the filter, in declaration order.
    operands = [
        value
        for argname, value in zip(op.argnames, op.args)
        if argname != "where"
    ]

    func_name = op.__full_name__

    where = op.where
    if where is not None:
        # Filtered aggregation: `<func>If(args..., condition)`.
        func_name = func_name + "If"
        operands.append(where)

    rendered = ", ".join(
        _sql(translate_val(operand, **kw)) for operand in operands
    )
    return f"{func_name}({rendered})"
27 changes: 26 additions & 1 deletion ibis/backends/clickhouse/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,33 @@ def arrayJaccardIndex(a: dt.Array[dt.int64], b: dt.Array[dt.int64]) -> float:
array_jaccard_index_no_input_types,
],
)
def test_builtin_scalar_udf(con, func):
    """Execute a builtin scalar UDF on two literal arrays and check the value.

    ``func`` is supplied by the parametrization above this definition.
    """
    expr = func([1, 2], [2, 3])
    result = con.execute(expr)
    # The arrays share one element (2) out of three distinct ones, so the
    # expected result is one third (presumably a Jaccard index -- the
    # parametrized functions are defined outside this view).
    expected = 1.0 / 3.0
    assert result == expected


# Signature-only declaration of a builtin aggregate named `entropy`; the
# body is a placeholder. The argument is left unannotated and the result is
# declared as `float`.
# NOTE(review): presumably resolves to the backend's own `entropy`
# aggregate -- confirm against the UDF decorator's documentation.
@udf.agg.builtin
def entropy(a) -> float:
    ...


# Signature-only declaration of a builtin aggregate; the body is a
# placeholder. `name="sumKahan"` gives the function name explicitly, so the
# Python identifier `sum_kahan` can differ from it.
# NOTE(review): presumably `sumKahan` is the backend-side function name --
# confirm against the UDF decorator's documentation.
@udf.agg.builtin(name="sumKahan")
def sum_kahan(a: float) -> float:
    ...


@pytest.mark.parametrize("func", [entropy, sum_kahan])
def test_builtin_agg_udf(con, func):
    """Check a builtin aggregate UDF against the same call written as raw SQL.

    The aggregate is executed through ``con.execute`` on an in-memory table,
    then compared with a hand-written ``SELECT`` over the same data passed
    as an external table.
    """
    values = list(map(float, range(10)))
    t = ibis.memtable(pd.DataFrame({"a": values}))

    expr = func(t.a)
    result = con.execute(expr)

    # Reference value: call the function by name in raw SQL over the same
    # data, registered under a freshly generated table name.
    table_name = gen_name("agg_udf_test")
    query = f"SELECT {expr.op().__func_name__}(a) FROM {table_name}"
    external_tables = con._normalize_external_tables({table_name: t.op()})
    expected = con.con.query_df(query, external_data=external_tables).squeeze()

    assert result == expected

0 comments on commit eea679a

Please sign in to comment.