Skip to content

Commit 920f673

Browse files
committed
refactor: add agg_ops.QcutOp to the sqlglot compiler
1 parent f7fd2d2 commit 920f673

File tree

6 files changed

+121
-8
lines changed

6 files changed

+121
-8
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,43 @@ def _(
400400
return apply_window_if_present(expr, window)
401401

402402

403+
@UNARY_OP_REGISTRATION.register(agg_ops.QcutOp)
404+
def _(
405+
op: agg_ops.QcutOp,
406+
column: typed_expr.TypedExpr,
407+
window: typing.Optional[window_spec.WindowSpec] = None,
408+
) -> sge.Expression:
409+
percent_ranks_order_by = sge.Ordered(this=column.expr, desc=False)
410+
percent_ranks = apply_window_if_present(
411+
sge.func("PERCENT_RANK"),
412+
window,
413+
include_framing_clauses=False,
414+
order_by_override=[percent_ranks_order_by],
415+
)
416+
if isinstance(op.quantiles, int):
417+
scaled_rank = percent_ranks * sge.convert(op.quantiles)
418+
# Calculate the 0-based bucket index.
419+
bucket_index = sge.func("CEIL", scaled_rank) - sge.convert(1)
420+
safe_bucket_index = sge.func("GREATEST", bucket_index, 0)
421+
422+
return sge.If(
423+
this=sge.Is(this=column.expr, expression=sge.Null()),
424+
true=sge.Null(),
425+
false=sge.Cast(this=safe_bucket_index, to="INT64"),
426+
)
427+
else:
428+
case = sge.Case()
429+
first_quantile = sge.convert(op.quantiles[0])
430+
case = case.when(
431+
sge.LT(this=percent_ranks, expression=first_quantile), sge.Null()
432+
)
433+
for bucket_n in range(len(op.quantiles) - 1):
434+
quantile = sge.convert(op.quantiles[bucket_n + 1])
435+
bucket = sge.convert(bucket_n)
436+
case = case.when(sge.LTE(this=percent_ranks, expression=quantile), bucket)
437+
return case.else_(sge.Null())
438+
439+
403440
@UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp)
404441
def _(
405442
op: agg_ops.QuantileOp,

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def apply_window_if_present(
2626
value: sge.Expression,
2727
window: typing.Optional[window_spec.WindowSpec] = None,
2828
include_framing_clauses: bool = True,
29+
order_by_override: typing.Optional[typing.List[sge.Ordered]] = None,
2930
) -> sge.Expression:
3031
if window is None:
3132
return value
@@ -44,7 +45,11 @@ def apply_window_if_present(
4445
else:
4546
order_by = get_window_order_by(window.ordering)
4647

47-
order = sge.Order(expressions=order_by) if order_by else None
48+
order = None
49+
if order_by_override is not None and len(order_by_override) > 0:
50+
order = sge.Order(expressions=order_by_override)
51+
elif order_by:
52+
order = sge.Order(expressions=order_by)
4853

4954
group_by = (
5055
[

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,12 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
637637

638638

639639
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
640-
sqlglot_type = sgt.from_bigframes_dtype(dtype)
640+
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
641+
if sqlglot_type is None:
642+
if value is not None:
643+
raise ValueError("Cannot infer SQLGlot type from None dtype.")
644+
return sge.Null()
645+
641646
if value is None:
642647
return _cast(sge.Null(), sqlglot_type)
643648
elif dtype == dtypes.BYTES_DTYPE:

bigframes/core/reshape/tile.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
import pandas as pd
2323

2424
import bigframes
25-
import bigframes.constants
26-
import bigframes.core.expression as ex
2725
import bigframes.core.ordering as order
2826
import bigframes.core.utils as utils
2927
import bigframes.core.window_spec as window_specs
@@ -165,7 +163,6 @@ def qcut(
165163
f"Only duplicates='drop' is supported in BigQuery DataFrames so far. {constants.FEEDBACK_LINK}"
166164
)
167165
block = x._block
168-
label = block.col_id_to_label[x._value_column]
169166
block, nullity_id = block.apply_unary_op(x._value_column, ops.notnull_op)
170167
block, result = block.apply_window_op(
171168
x._value_column,
@@ -175,9 +172,6 @@ def qcut(
175172
ordering=(order.ascending_over(x._value_column),),
176173
),
177174
)
178-
block, result = block.project_expr(
179-
ops.where_op.as_expr(result, nullity_id, ex.const(None)), label=label
180-
)
181175
return bigframes.series.Series(block.select_column(result))
182176

183177

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col`,
4+
`rowindex`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
NOT `int64_col` IS NULL AS `bfcol_4`
10+
FROM `bfcte_0`
11+
), `bfcte_2` AS (
12+
SELECT
13+
*,
14+
IF(
15+
`int64_col` IS NULL,
16+
NULL,
17+
CAST(GREATEST(
18+
CEIL(PERCENT_RANK() OVER (PARTITION BY `bfcol_4` ORDER BY `int64_col` ASC) * 4) - 1,
19+
0
20+
) AS INT64)
21+
) AS `bfcol_5`
22+
FROM `bfcte_1`
23+
), `bfcte_3` AS (
24+
SELECT
25+
*,
26+
NOT `int64_col` IS NULL AS `bfcol_9`
27+
FROM `bfcte_2`
28+
), `bfcte_4` AS (
29+
SELECT
30+
*,
31+
CASE
32+
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) < 0
33+
THEN NULL
34+
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 0.25
35+
THEN 0
36+
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 0.5
37+
THEN 1
38+
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 0.75
39+
THEN 2
40+
WHEN PERCENT_RANK() OVER (PARTITION BY `bfcol_9` ORDER BY `int64_col` ASC) <= 1
41+
THEN 3
42+
ELSE NULL
43+
END AS `bfcol_10`
44+
FROM `bfcte_3`
45+
)
46+
SELECT
47+
`rowindex`,
48+
`int64_col`,
49+
`bfcol_5` AS `qcut_w_int`,
50+
`bfcol_10` AS `qcut_w_list`
51+
FROM `bfcte_4`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,27 @@ def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot):
435435
snapshot.assert_match(sql_window, "window_out.sql")
436436

437437

438+
def test_qcut(scalar_types_df: bpd.DataFrame, snapshot):
439+
if sys.version_info < (3, 12):
440+
pytest.skip(
441+
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
442+
)
443+
444+
col_name = "int64_col"
445+
bf = scalar_types_df[[col_name]]
446+
bf["qcut_w_int"] = bpd.qcut(bf[col_name], q=4, labels=False, duplicates="drop")
447+
448+
q_list = tuple([0, 0.25, 0.5, 0.75, 1])
449+
bf["qcut_w_list"] = bpd.qcut(
450+
scalar_types_df[col_name],
451+
q=q_list,
452+
labels=False,
453+
duplicates="drop",
454+
)
455+
456+
snapshot.assert_match(bf.sql, "out.sql")
457+
458+
438459
def test_quantile(scalar_types_df: bpd.DataFrame, snapshot):
439460
col_name = "int64_col"
440461
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)