diff --git a/bigframes/core/compile/sqlglot/aggregations/op_registration.py b/bigframes/core/compile/sqlglot/aggregations/op_registration.py index eb02b8bd50..a26429f27e 100644 --- a/bigframes/core/compile/sqlglot/aggregations/op_registration.py +++ b/bigframes/core/compile/sqlglot/aggregations/op_registration.py @@ -52,5 +52,5 @@ def arg_checker(*args, **kwargs): def __getitem__(self, op: str | agg_ops.WindowOp) -> CompilationFunc: key = op if isinstance(op, type) else type(op) if str(key) not in self._registered_ops: - raise ValueError(f"{key} is already not registered") + raise ValueError(f"{key} is not registered") return self._registered_ops[str(key)] diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index d157f07df2..d0d887588c 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -239,6 +239,20 @@ def _( return apply_window_if_present(sge.func("MIN", column.expr), window) +@UNARY_OP_REGISTRATION.register(agg_ops.PopVarOp) +def _( + op: agg_ops.PopVarOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + expr = column.expr + if column.dtype == dtypes.BOOL_DTYPE: + expr = sge.Cast(this=expr, to="INT64") + + expr = sge.func("VAR_POP", expr) + return apply_window_if_present(expr, window) + + @UNARY_OP_REGISTRATION.register(agg_ops.QuantileOp) def _( op: agg_ops.QuantileOp, @@ -278,6 +292,22 @@ def _( return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window) +@UNARY_OP_REGISTRATION.register(agg_ops.StdOp) +def _( + op: agg_ops.StdOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + expr = column.expr + if column.dtype == dtypes.BOOL_DTYPE: + expr = sge.Cast(this=expr, to="INT64") + + expr = sge.func("STDDEV", expr) + if op.should_floor_result or column.dtype == dtypes.TIMEDELTA_DTYPE: + expr = sge.Cast(this=sge.func("FLOOR", expr), to="INT64") + return apply_window_if_present(expr, window) + + @UNARY_OP_REGISTRATION.register(agg_ops.ShiftOp) def _( op: agg_ops.ShiftOp, @@ -331,3 +361,17 @@ def _( expression=shifted, unit=sge.Identifier(this="MICROSECOND"), ) + + +@UNARY_OP_REGISTRATION.register(agg_ops.VarOp) +def _( + op: agg_ops.VarOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + expr = column.expr + if column.dtype == dtypes.BOOL_DTYPE: + expr = sge.Cast(this=expr, to="INT64") + + expr = sge.func("VAR_SAMP", expr) + return apply_window_if_present(expr, window) diff --git a/tests/system/small/engines/test_aggregation.py b/tests/system/small/engines/test_aggregation.py index 3e6d4843de..4ed826d2ae 100644 --- a/tests/system/small/engines/test_aggregation.py +++ b/tests/system/small/engines/test_aggregation.py @@ -111,7 +111,7 @@ def test_engines_unary_aggregates( assert_equivalence_execution(node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) @pytest.mark.parametrize( "op", [agg_ops.std_op, agg_ops.var_op, agg_ops.PopVarOp()], diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_pop_var/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_pop_var/out.sql new file mode 100644 index 0000000000..de422382d1 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_pop_var/out.sql @@ -0,0 +1,15 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + VAR_POP(`bfcol_1`) AS `bfcol_4`, + VAR_POP(CAST(`bfcol_0` AS INT64)) AS `bfcol_5` + FROM `bfcte_0` +) +SELECT + `bfcol_4` AS `int64_col`, + `bfcol_5` AS `bool_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_pop_var/window_out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_pop_var/window_out.sql new file mode 100644 index 0000000000..fa04dad64e --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_pop_var/window_out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE WHEN `bfcol_0` IS NULL THEN NULL ELSE VAR_POP(`bfcol_0`) OVER () END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_std/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_std/out.sql new file mode 100644 index 0000000000..9bfa6288c3 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_std/out.sql @@ -0,0 +1,27 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1`, + `duration_col` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_1` AS `bfcol_6`, + `bfcol_0` AS `bfcol_7`, + `bfcol_2` AS `bfcol_8` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + STDDEV(`bfcol_6`) AS `bfcol_12`, + STDDEV(CAST(`bfcol_7` AS INT64)) AS `bfcol_13`, + CAST(FLOOR(STDDEV(`bfcol_8`)) AS INT64) AS `bfcol_14`, + CAST(FLOOR(STDDEV(`bfcol_6`)) AS INT64) AS `bfcol_15` + FROM `bfcte_1` +) +SELECT + `bfcol_12` AS `int64_col`, + `bfcol_13` AS `bool_col`, + `bfcol_14` AS `duration_col`, + `bfcol_15` AS `int64_col_w_floor` +FROM `bfcte_2` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_std/window_out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_std/window_out.sql new file mode 100644 index 0000000000..e4e4ff0684 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_std/window_out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE WHEN `bfcol_0` IS NULL THEN NULL ELSE STDDEV(`bfcol_0`) OVER () END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_var/out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_var/out.sql new file mode 100644 index 0000000000..59ccd59e8f --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_var/out.sql @@ -0,0 +1,15 @@ +WITH `bfcte_0` AS ( + SELECT + `bool_col` AS `bfcol_0`, + `int64_col` AS `bfcol_1` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + VARIANCE(`bfcol_1`) AS `bfcol_4`, + VARIANCE(CAST(`bfcol_0` AS INT64)) AS `bfcol_5` + FROM `bfcte_0` +) +SELECT + `bfcol_4` AS `int64_col`, + `bfcol_5` AS `bool_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_var/window_out.sql b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_var/window_out.sql new file mode 100644 index 0000000000..a65104215b --- /dev/null +++ b/tests/unit/core/compile/sqlglot/aggregations/snapshots/test_unary_compiler/test_var/window_out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE WHEN `bfcol_0` IS NULL THEN NULL ELSE VARIANCE(`bfcol_0`) OVER () END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `agg_int64` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py index 3d7e4287ac..478368393a 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py @@ -370,6 +370,28 @@ def test_min(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql_window_partition, "window_partition_out.sql") +def test_pop_var(scalar_types_df: bpd.DataFrame, snapshot): + col_names = ["int64_col", "bool_col"] + bf_df = scalar_types_df[col_names] + + agg_ops_map = { + "int64_col": agg_ops.PopVarOp().as_expr("int64_col"), + "bool_col": agg_ops.PopVarOp().as_expr("bool_col"), + } + sql = _apply_unary_agg_ops( + bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys()) + ) + snapshot.assert_match(sql, "out.sql") + + # Window tests + col_name = "int64_col" + bf_df_int = scalar_types_df[[col_name]] + agg_expr = agg_ops.PopVarOp().as_expr(col_name) + window = window_spec.WindowSpec(ordering=(ordering.descending_over(col_name),)) + sql_window = _apply_unary_window_op(bf_df_int, agg_expr, window, "agg_int64") + snapshot.assert_match(sql_window, "window_out.sql") + + def test_quantile(scalar_types_df: bpd.DataFrame, snapshot): col_name = "int64_col" bf_df = scalar_types_df[[col_name]] @@ -428,6 +450,40 @@ def test_shift(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(noop_sql, "noop.sql") +def test_std(scalar_types_df: bpd.DataFrame, snapshot): + col_names = ["int64_col", "bool_col", "duration_col"] + bf_df = scalar_types_df[col_names] + bf_df["duration_col"] = bpd.to_timedelta(bf_df["duration_col"], unit="us") + + # The `to_timedelta` creates a new mapping for the column id. + col_names.insert(0, "rowindex") + name2id = { + col_name: col_id + for col_name, col_id in zip(col_names, bf_df._block.expr.column_ids) + } + + agg_ops_map = { + "int64_col": agg_ops.StdOp().as_expr(name2id["int64_col"]), + "bool_col": agg_ops.StdOp().as_expr(name2id["bool_col"]), + "duration_col": agg_ops.StdOp().as_expr(name2id["duration_col"]), + "int64_col_w_floor": agg_ops.StdOp(should_floor_result=True).as_expr( + name2id["int64_col"] + ), + } + sql = _apply_unary_agg_ops( + bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys()) + ) + snapshot.assert_match(sql, "out.sql") + + # Window tests + col_name = "int64_col" + bf_df_int = scalar_types_df[[col_name]] + agg_expr = agg_ops.StdOp().as_expr(col_name) + window = window_spec.WindowSpec(ordering=(ordering.descending_over(col_name),)) + sql_window = _apply_unary_window_op(bf_df_int, agg_expr, window, "agg_int64") + snapshot.assert_match(sql_window, "window_out.sql") + + def test_sum(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col", "bool_col"]] agg_ops_map = { @@ -468,3 +524,25 @@ def test_time_series_diff(scalar_types_df: bpd.DataFrame, snapshot): ) sql = _apply_unary_window_op(bf_df, op, window, "diff_time") snapshot.assert_match(sql, "out.sql") + + +def test_var(scalar_types_df: bpd.DataFrame, snapshot): + col_names = ["int64_col", "bool_col"] + bf_df = scalar_types_df[col_names] + + agg_ops_map = { + "int64_col": agg_ops.VarOp().as_expr("int64_col"), + "bool_col": agg_ops.VarOp().as_expr("bool_col"), + } + sql = _apply_unary_agg_ops( + bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys()) + ) + snapshot.assert_match(sql, "out.sql") + + # Window tests + col_name = "int64_col" + bf_df_int = scalar_types_df[[col_name]] + agg_expr = agg_ops.VarOp().as_expr(col_name) + window = window_spec.WindowSpec(ordering=(ordering.descending_over(col_name),)) + sql_window = _apply_unary_window_op(bf_df_int, agg_expr, window, "agg_int64") + snapshot.assert_match(sql_window, "window_out.sql")