From 09d1e2f768a935d29aa607f3536fc93e12fc82dc Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Sun, 25 Dec 2022 08:34:10 -0500 Subject: [PATCH] feat(bigquery): implement array repeat --- ibis/backends/bigquery/registry.py | 15 ++++++++++++--- ibis/backends/tests/test_array.py | 12 ++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/ibis/backends/bigquery/registry.py b/ibis/backends/bigquery/registry.py index 0e016ed004f9..f48d8d93ad72 100644 --- a/ibis/backends/bigquery/registry.py +++ b/ibis/backends/bigquery/registry.py @@ -477,6 +477,17 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str: return translate +def _array_repeat(t, op): + start = step = 1 + times = t.translate(op.times) + arg = t.translate(op.arg) + array_length = f"ARRAY_LENGTH({arg})" + stop = f"GREATEST({times}, 0) * {array_length}" + idx = f"COALESCE(NULLIF(MOD(i, {array_length}), 0), {array_length})" + series = f"GENERATE_ARRAY({start}, {stop}, {step})" + return f"ARRAY(SELECT {arg}[SAFE_ORDINAL({idx})] FROM UNNEST({series}) AS i)" + + OPERATION_REGISTRY = { **operation_registry, # Literal @@ -546,12 +557,10 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str: ops.ArrayConcat: _array_concat, ops.ArrayIndex: _array_index, ops.ArrayLength: unary("ARRAY_LENGTH"), + ops.ArrayRepeat: _array_repeat, ops.HLLCardinality: reduction("APPROX_COUNT_DISTINCT"), ops.Log: _log, ops.Log2: _log2, - # BigQuery doesn't have these operations built in. - # ops.ArrayRepeat: _array_repeat, - # ops.ArraySlice: _array_slice, ops.Arbitrary: _arbitrary, # Geospatial Columnar ops.GeoUnaryUnion: unary("ST_UNION_AGG"), diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index ab612d9c025c..cb1ac192e2e9 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -46,6 +46,18 @@ def test_array_scalar(con): assert np.array_equal(result, expected) +@pytest.mark.notimpl(["impala", "snowflake", "polars", "datafusion"]) +def test_array_repeat(con): + expr = ibis.array([1.0, 2.0]) * 2 + + result = con.execute(expr.name("tmp")) + expected = np.array([1.0, 2.0, 1.0, 2.0]) + + # This does not check whether `result` is an np.array or a list, + # because it varies across backends and backend configurations + assert np.array_equal(result, expected) + + # Issues #2370 @pytest.mark.notimpl(["impala", "datafusion", "snowflake"]) def test_array_concat(con):