Skip to content

Commit 0cb5217

Browse files
authored
refactor: add agg_ops.QcutOp to the sqlglot compiler (#2269)
Fixes internal issue 445774480 🦕
1 parent f4a7206 commit 0cb5217

File tree

5 files changed

+131
-2
lines changed

5 files changed

+131
-2
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,43 @@ def _(
400400
return apply_window_if_present(expr, window)
401401

402402

403+
@UNARY_OP_REGISTRATION.register(agg_ops.QcutOp)
def _(
    op: agg_ops.QcutOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``QcutOp`` into a SQL expression yielding a 0-based bucket index.

    The bucket is derived from ``PERCENT_RANK()`` ordered ascending by
    ``column``; the window's own ordering is replaced through
    ``order_by_override`` and framing clauses are left out.

    Args:
        op: The qcut operation. ``op.quantiles`` is either an ``int`` (number
            of buckets) or an indexable sequence of quantile edges.
        column: The typed input expression to bucket.
        window: Window passed through to ``apply_window_if_present``.

    Returns:
        For an int ``quantiles``: an ``IF`` that yields NULL when the column is
        NULL, otherwise ``CAST(GREATEST(CEIL(rank * q) - 1, 0) AS INT64)``.
        For a sequence: a ``CASE`` that yields NULL below the first edge, the
        index of the first upper edge the rank is ``<=`` to, and NULL
        otherwise.
    """
    ascending_by_value = sge.Ordered(this=column.expr, desc=False)
    rank_expr = apply_window_if_present(
        sge.func("PERCENT_RANK"),
        window,
        include_framing_clauses=False,
        order_by_override=[ascending_by_value],
    )

    if isinstance(op.quantiles, int):
        ceil_of_scaled = sge.func("CEIL", rank_expr * sge.convert(op.quantiles))
        # CEIL(...) - 1 gives the 0-based bucket; it would be -1 when the
        # scaled rank is 0, so clamp it to 0 with GREATEST.
        zero_based = ceil_of_scaled - sge.convert(1)
        clamped = sge.func("GREATEST", zero_based, 0)

        column_is_null = sge.Is(this=column.expr, expression=sge.Null())
        return sge.If(
            this=column_is_null,
            true=sge.Null(),
            false=sge.Cast(this=clamped, to="INT64"),
        )

    # Explicit quantile edges: ranks strictly below the first edge get NULL.
    lowest_edge = sge.convert(op.quantiles[0])
    buckets = sge.Case().when(
        sge.LT(this=rank_expr, expression=lowest_edge), sge.Null()
    )
    # Each remaining edge closes one bucket; the first matching WHEN wins.
    for bucket_id, upper_edge in enumerate(op.quantiles[1:]):
        within_bucket = sge.LTE(this=rank_expr, expression=sge.convert(upper_edge))
        buckets = buckets.when(within_bucket, sge.convert(bucket_id))
    return buckets.else_(sge.Null())
438+
439+
403440
@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
404441
def _(
405442
op: agg_ops.QuantileOp,

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def apply_window_if_present(
2626
value: sge.Expression,
2727
window: typing.Optional[window_spec.WindowSpec] = None,
2828
include_framing_clauses: bool = True,
29+
order_by_override: typing.Optional[typing.List[sge.Ordered]] = None,
2930
) -> sge.Expression:
3031
if window is None:
3132
return value
@@ -44,7 +45,11 @@ def apply_window_if_present(
4445
else:
4546
order_by = get_window_order_by(window.ordering)
4647

47-
order = sge.Order(expressions=order_by) if order_by else None
48+
order = None
49+
if order_by_override is not None and len(order_by_override) > 0:
50+
order = sge.Order(expressions=order_by_override)
51+
elif order_by:
52+
order = sge.Order(expressions=order_by)
4853

4954
group_by = (
5055
[

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,12 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
637637

638638

639639
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
640-
sqlglot_type = sgt.from_bigframes_dtype(dtype)
640+
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
641+
if sqlglot_type is None:
642+
if value is not None:
643+
raise ValueError("Cannot infer SQLGlot type from None dtype.")
644+
return sge.Null()
645+
641646
if value is None:
642647
return _cast(sge.Null(), sqlglot_type)
643648
elif dtype == dtypes.BYTES_DTYPE:
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`,
4+
`rowindex`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
NOT `int64_col` IS NULL AS `bfcol_4`
10+
FROM `bfcte_0`
11+
), `bfcte_2` AS (
12+
SELECT
13+
*,
14+
IF(
15+
`int64_col` IS NULL,
16+
NULL,
17+
CAST(GREATEST(
18+
CEIL(PERCENT_RANK() OVER (PARTITION BY `bfcol_4` ORDER BY `int64_col` ASC) * 4) - 1,
19+
0
20+
) AS INT64)
21+
) AS `bfcol_5`
22+
FROM `bfcte_1`
23+
), `bfcte_3` AS (
24+
SELECT
25+
*,
26+
IF(`bfcol_4`, `bfcol_5`, NULL) AS `bfcol_6`
27+
FROM `bfcte_2`
28+
), `bfcte_4` AS (
29+
SELECT
30+
*,
31+
NOT `int64_col` IS NULL AS `bfcol_10`
32+
FROM `bfcte_3`
33+
), `bfcte_5` AS (
34+
SELECT
35+
*,
36+
CASE
37+
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) < 0
38+
THEN NULL
39+
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.25
40+
THEN 0
41+
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.5
42+
THEN 1
43+
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 0.75
44+
THEN 2
45+
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_10` ORDER BY `int64_col` ASC) <= 1
46+
THEN 3
47+
ELSE NULL
48+
END AS `bfcol_11`
49+
FROM `bfcte_4`
50+
), `bfcte_6` AS (
51+
SELECT
52+
*,
53+
IF(`bfcol_10`, `bfcol_11`, NULL) AS `bfcol_12`
54+
FROM `bfcte_5`
55+
)
56+
SELECT
57+
`rowindex`,
58+
`int64_col`,
59+
`bfcol_6` AS `qcut_w_int`,
60+
`bfcol_12` AS `qcut_w_list`
61+
FROM `bfcte_6`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,27 @@ def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
434434
snapshot.assert_match(sql_window, "window_out.sql")
435435

436436

437+
def test_qcut(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL produced by ``bpd.qcut`` for an int ``q`` and for explicit edges."""
    if sys.version_info < (3, 12):
        pytest.skip(
            "Skipping test due to inconsistent SQL formatting on Python < 3.12.",
        )

    source_col = "int64_col"
    frame = scalar_types_df[[source_col]]

    # Integer form: q is the number of buckets.
    frame["qcut_w_int"] = bpd.qcut(
        frame[source_col], q=4, labels=False, duplicates="drop"
    )

    # Sequence form: q lists the quantile edges explicitly.
    quantile_edges = (0, 0.25, 0.5, 0.75, 1)
    frame["qcut_w_list"] = bpd.qcut(
        scalar_types_df[source_col],
        q=quantile_edges,
        labels=False,
        duplicates="drop",
    )

    snapshot.assert_match(frame.sql, "out.sql")
456+
457+
437458
def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
438459
col_name = "int64_col"
439460
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)