Skip to content

Commit 482bd0e

Browse files
authored
refactor: add agg_ops.ProductOp and NuniqueOp to the sqlglot compiler (#2289)
Fixes internal issue 445774480 🦕
1 parent ed79146 commit 482bd0e

File tree

5 files changed

+146
-0
lines changed

5 files changed

+146
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,17 @@ def _(
386386
return apply_window_if_present(sge.func("MIN", column.expr), window)
387387

388388

389+
@UNARY_OP_REGISTRATION.register(agg_ops.NuniqueOp)
def _(
    op: agg_ops.NuniqueOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``NuniqueOp`` into ``COUNT(DISTINCT <column>)``.

    Args:
        op: The nunique aggregation op being compiled (not consulted here).
        column: Typed expression whose ``expr`` is the value to count.
        window: Optional window spec, forwarded to ``apply_window_if_present``
            together with the aggregate.

    Returns:
        Whatever ``apply_window_if_present`` yields for the COUNT(DISTINCT ...)
        aggregate and ``window``.
    """
    distinct_values = sge.Distinct(expressions=[column.expr])
    distinct_count = sge.func("COUNT", distinct_values)
    return apply_window_if_present(distinct_count, window)
398+
399+
389400
@UNARY_OP_REGISTRATION.register(agg_ops.PopVarOp)
390401
def _(
391402
op: agg_ops.PopVarOp,
@@ -400,6 +411,58 @@ def _(
400411
return apply_window_if_present(expr, window)
401412

402413

414+
@UNARY_OP_REGISTRATION.register(agg_ops.ProductOp)
def _(
    op: agg_ops.ProductOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``ProductOp`` as ``EXP(SUM(LN(ABS(x))))`` with sign/zero fix-ups.

    There is no product aggregate in SQL, so the magnitude is rebuilt from a
    sum of natural logs, the sign from the parity of the count of negative
    inputs, and any zero input forces the whole result to 0.

    Args:
        op: The product aggregation op being compiled (not consulted here).
        column: Typed expression whose ``expr`` is the value to multiply.
        window: Optional window spec; each of the three aggregates (log sum,
            negative count, zero detection) is passed through
            ``apply_window_if_present`` with it.

    Returns:
        A CASE expression: 0 when any input equals 0, otherwise the
        exponentiated log sum multiplied by -1 or 1.
    """
    # Taking the log of zero is illegal SQL, so zero inputs are detected first
    # and short-circuited both per row (below) and for the final result.
    zero_input = sge.EQ(this=column.expr, expression=sge.convert(0))

    # Magnitude: sum of logs, then exponentiate. The log and the power must use
    # the same base; this implementation uses the natural log (LN / EXP).
    zero_guard = sge.Case().when(zero_input, sge.convert(0))
    abs_log = sge.func("LN", sge.func("ABS", column.expr))
    log_term = zero_guard.else_(abs_log)
    summed_logs = apply_window_if_present(sge.func("SUM", log_term), window)
    abs_product = sge.func("EXP", summed_logs)

    # Sign: logs discard it, so count the negative inputs and check parity.
    sign_below_zero = sge.LT(
        this=sge.func("SIGN", column.expr), expression=sge.convert(0)
    )
    negative_flag = (
        sge.Case().when(sign_below_zero, sge.convert(1)).else_(sge.convert(0))
    )
    negative_total = apply_window_if_present(
        sge.func("SUM", negative_flag), window
    )
    # 1 if the result should be negative, otherwise 0.
    negative_parity = sge.Mod(this=negative_total, expression=sge.convert(2))

    has_zero = apply_window_if_present(
        sge.func("LOGICAL_OR", zero_input), window
    )

    sign_factor = sge.If(
        this=sge.EQ(this=negative_parity, expression=sge.convert(1)),
        true=sge.convert(-1),
        false=sge.convert(1),
    )
    signed_product = sge.Mul(this=abs_product, expression=sign_factor)
    return sge.Case().when(has_zero, sge.convert(0)).else_(signed_product)
464+
465+
403466
@UNARY_OP_REGISTRATION.register(agg_ops.QcutOp)
404467
def _(
405468
op: agg_ops.QcutOp,
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
COUNT(DISTINCT `int64_col`) AS `bfcol_1`
8+
FROM `bfcte_0`
9+
)
10+
SELECT
11+
`bfcol_1` AS `int64_col`
12+
FROM `bfcte_1`
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
CASE
8+
WHEN LOGICAL_OR(`int64_col` = 0)
9+
THEN 0
10+
ELSE EXP(SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END)) * IF(MOD(SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END), 2) = 1, -1, 1)
11+
END AS `bfcol_1`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_1` AS `int64_col`
16+
FROM `bfcte_1`
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`,
4+
`string_col`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
CASE
10+
WHEN LOGICAL_OR(`int64_col` = 0) OVER (PARTITION BY `string_col`)
11+
THEN 0
12+
ELSE EXP(
13+
SUM(CASE WHEN `int64_col` = 0 THEN 0 ELSE LN(ABS(`int64_col`)) END) OVER (PARTITION BY `string_col`)
14+
) * IF(
15+
MOD(
16+
SUM(CASE WHEN SIGN(`int64_col`) < 0 THEN 1 ELSE 0 END) OVER (PARTITION BY `string_col`),
17+
2
18+
) = 1,
19+
-1,
20+
1
21+
)
22+
END AS `bfcol_2`
23+
FROM `bfcte_0`
24+
)
25+
SELECT
26+
`bfcol_2` AS `agg_int64`
27+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,15 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot):
412412
snapshot.assert_match(sql_window_partition, "window_partition_out.sql")
413413

414414

415+
def test_nunique(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``NuniqueOp`` over ``int64_col``."""
    target = "int64_col"
    single_col_df = scalar_types_df[[target]]
    nunique_expr = agg_ops.NuniqueOp().as_expr(target)

    compiled_sql = _apply_unary_agg_ops(single_col_df, [nunique_expr], [target])

    snapshot.assert_match(compiled_sql, "out.sql")
422+
423+
415424
def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
416425
col_names = ["int64_col", "bool_col"]
417426
bf_df = scalar_types_df[col_names]
@@ -434,6 +443,25 @@ def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
434443
snapshot.assert_match(sql_window, "window_out.sql")
435444

436445

446+
def test_product(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``ProductOp``, plain and partitioned."""
    target = "int64_col"
    product_expr = agg_ops.ProductOp().as_expr(target)

    # Plain aggregation over the single column.
    single_col_df = scalar_types_df[[target]]
    aggregate_sql = _apply_unary_agg_ops(single_col_df, [product_expr], [target])
    snapshot.assert_match(aggregate_sql, "out.sql")

    # Same op as a window aggregate partitioned by ``string_col``.
    with_key_df = scalar_types_df[[target, "string_col"]]
    by_string_col = window_spec.WindowSpec(
        grouping_keys=(expression.deref("string_col"),),
    )
    partitioned_sql = _apply_unary_window_op(
        with_key_df, product_expr, by_string_col, "agg_int64"
    )
    snapshot.assert_match(partitioned_sql, "window_partition_out.sql")
463+
464+
437465
def test_qcut(scalar_types_df: bpd.DataFrame, snapshot):
438466
if sys.version_info < (3, 12):
439467
pytest.skip(

0 commit comments

Comments
 (0)