From 669ff82cb766169d50228951d719d90068c26184 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 7 Nov 2025 22:51:18 +0000 Subject: [PATCH 1/2] refactor: make sqlglot compile_sql a top-level function --- bigframes/core/compile/sqlglot/__init__.py | 4 +- bigframes/core/compile/sqlglot/compiler.py | 645 +++++++++--------- bigframes/session/direct_gbq_execution.py | 4 +- bigframes/testing/compiler_session.py | 10 +- .../sqlglot/test_compile_random_sample.py | 4 +- 5 files changed, 320 insertions(+), 347 deletions(-) diff --git a/bigframes/core/compile/sqlglot/__init__.py b/bigframes/core/compile/sqlglot/__init__.py index 4ceb4118cd..9e3f123807 100644 --- a/bigframes/core/compile/sqlglot/__init__.py +++ b/bigframes/core/compile/sqlglot/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from bigframes.core.compile.sqlglot.compiler import SQLGlotCompiler +from bigframes.core.compile.sqlglot.compiler import compile_sql import bigframes.core.compile.sqlglot.expressions.ai_ops # noqa: F401 import bigframes.core.compile.sqlglot.expressions.array_ops # noqa: F401 import bigframes.core.compile.sqlglot.expressions.blob_ops # noqa: F401 @@ -29,4 +29,4 @@ import bigframes.core.compile.sqlglot.expressions.struct_ops # noqa: F401 import bigframes.core.compile.sqlglot.expressions.timedelta_ops # noqa: F401 -__all__ = ["SQLGlotCompiler"] +__all__ = ["compile_sql"] diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 47ad8db21b..7dc8d4bec0 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -17,7 +17,6 @@ import functools import typing -from google.cloud import bigquery import sqlglot.expressions as sge from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite @@ -31,371 +30,347 @@ from bigframes.core.rewrite import schema_binding -class SQLGlotCompiler: - """Compiles BigFrame nodes into SQL using SQLGlot.""" - - uid_gen: guid.SequentialUIDGenerator - """Generator for unique identifiers.""" - - def __init__(self): - self.uid_gen = guid.SequentialUIDGenerator() - - def compile( - self, - node: nodes.BigFrameNode, - *, - ordered: bool = True, - limit: typing.Optional[int] = None, - ) -> str: - """Compiles node into sql where rows are sorted with ORDER BY.""" - request = configs.CompileRequest(node, sort_rows=ordered, peek_count=limit) - return self._compile_sql(request).sql - - def compile_raw( - self, - node: nodes.BigFrameNode, - ) -> typing.Tuple[ - str, typing.Sequence[bigquery.SchemaField], bf_ordering.RowOrdering - ]: - """Compiles node into sql that exposes all columns, including hidden - ordering-only columns.""" - request = configs.CompileRequest( - node, sort_rows=False, materialize_all_order_keys=True - ) - result = self._compile_sql(request) - assert result.row_order is not None - return result.sql, result.sql_schema, result.row_order - - def _compile_sql(self, request: configs.CompileRequest) -> configs.CompileResult: - output_names = tuple( - (expression.DerefOp(id), id.sql) for id in request.node.ids - ) - result_node = nodes.ResultNode( - request.node, - output_cols=output_names, - limit=request.peek_count, - ) - if request.sort_rows: - # Can only pullup slice if we are doing ORDER BY in outermost SELECT - # Need to do this before replacing unsupported ops, as that will rewrite slice ops - result_node = rewrite.pull_up_limits(result_node) - result_node = _replace_unsupported_ops(result_node) - # prune before 
pulling up order to avoid unnnecessary row_number() ops
-        result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
-        result_node = rewrite.defer_order(
-            result_node, output_hidden_row_keys=request.materialize_all_order_keys
-        )
-        if request.sort_rows:
-            result_node = typing.cast(
-                nodes.ResultNode, rewrite.column_pruning(result_node)
-            )
-            result_node = self._remap_variables(result_node)
-            result_node = typing.cast(
-                nodes.ResultNode, rewrite.defer_selection(result_node)
-            )
-            sql = self._compile_result_node(result_node)
-            return configs.CompileResult(
-                sql, result_node.schema.to_bigquery(), result_node.order_by
-            )
-
-        ordering: typing.Optional[bf_ordering.RowOrdering] = result_node.order_by
-        result_node = dataclasses.replace(result_node, order_by=None)
-        result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
-
-        result_node = self._remap_variables(result_node)
-        result_node = typing.cast(
-            nodes.ResultNode, rewrite.defer_selection(result_node)
-        )
-        sql = self._compile_result_node(result_node)
-        # Return the ordering iff no extra columns are needed to define the row order
-        if ordering is not None:
-            output_order = (
-                ordering
-                if ordering.referenced_columns.issubset(result_node.ids)
-                else None
-            )
-        assert (not request.materialize_all_order_keys) or (output_order is not None)
-        return configs.CompileResult(
-            sql, result_node.schema.to_bigquery(), output_order
-        )
+def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
+    """Compiles a BigFrameNode according to the request into SQL using SQLGlot."""
+
+    # Generator for unique identifiers.
+    uid_gen = guid.SequentialUIDGenerator()
+    output_names = tuple((expression.DerefOp(id), id.sql) for id in request.node.ids)
+    result_node = nodes.ResultNode(
+        request.node,
+        output_cols=output_names,
+        limit=request.peek_count,
+    )
+    if request.sort_rows:
+        # Can only pull up slice if we are doing ORDER BY in outermost SELECT
+        # Need to do this before replacing unsupported ops, as that will rewrite slice ops
+        result_node = rewrite.pull_up_limits(result_node)
+    result_node = _replace_unsupported_ops(result_node)
+    # prune before pulling up order to avoid unnecessary row_number() ops
+    result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
+    result_node = rewrite.defer_order(
+        result_node, output_hidden_row_keys=request.materialize_all_order_keys
+    )
+    if request.sort_rows:
+        result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
+        result_node = _remap_variables(result_node, uid_gen)
+        result_node = typing.cast(
+            nodes.ResultNode, rewrite.defer_selection(result_node)
+        )
+        sql = _compile_result_node(result_node, uid_gen)
+        return configs.CompileResult(
+            sql, result_node.schema.to_bigquery(), result_node.order_by
+        )
+
+    ordering: typing.Optional[bf_ordering.RowOrdering] = result_node.order_by
+    result_node = dataclasses.replace(result_node, order_by=None)
+    result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
+
+    result_node = _remap_variables(result_node, uid_gen)
+    result_node = typing.cast(nodes.ResultNode, rewrite.defer_selection(result_node))
+    sql = _compile_result_node(result_node, uid_gen)
+    # Return the ordering iff no extra columns are needed to define the row order
+    if ordering is not None:
+        output_order = (
+            ordering if ordering.referenced_columns.issubset(result_node.ids) else None
+        )
-
-    def _remap_variables(self, node: nodes.ResultNode) -> nodes.ResultNode:
-        """Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
-
-        result_node, _ = rewrite.remap_variables(
-            node, map(identifiers.ColumnId, self.uid_gen.get_uid_stream("bfcol_"))
-        )
-        return 
typing.cast(nodes.ResultNode, result_node)
-
-    def _compile_result_node(self, root: nodes.ResultNode) -> str:
-        # Have to bind schema as the final step before compilation.
-        root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
-        selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
-            (name, scalar_compiler.scalar_op_compiler.compile_expression(ref))
-            for ref, name in root.output_cols
-        )
-        sqlglot_ir = self.compile_node(root.child).select(selected_cols)
-
-        if root.order_by is not None:
-            ordering_cols = tuple(
-                sge.Ordered(
-                    this=scalar_compiler.scalar_op_compiler.compile_expression(
-                        ordering.scalar_expression
-                    ),
-                    desc=ordering.direction.is_ascending is False,
-                    nulls_first=ordering.na_last is False,
-                )
-                for ordering in root.order_by.all_ordering_columns
-            )
-            sqlglot_ir = sqlglot_ir.order_by(ordering_cols)
-
-        if root.limit is not None:
-            sqlglot_ir = sqlglot_ir.limit(root.limit)
-
-        return sqlglot_ir.sql
+        assert (not request.materialize_all_order_keys) or (output_order is not None)
+    return configs.CompileResult(sql, result_node.schema.to_bigquery(), output_order)
+
+
+def _remap_variables(
+    node: nodes.ResultNode, uid_gen: guid.SequentialUIDGenerator
+) -> nodes.ResultNode:
+    """Remaps `ColumnId`s in the BFET of a `ResultNode` to produce deterministic UIDs."""
+
+    result_node, _ = rewrite.remap_variables(
+        node, map(identifiers.ColumnId, uid_gen.get_uid_stream("bfcol_"))
+    )
+    return typing.cast(nodes.ResultNode, result_node)
+
+
+def _compile_result_node(
+    root: nodes.ResultNode, uid_gen: guid.SequentialUIDGenerator
+) -> str:
+    # Have to bind schema as the final step before compilation.
+    root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
+    selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
+        (name, scalar_compiler.scalar_op_compiler.compile_expression(ref))
+        for ref, name in root.output_cols
+    )
+    sqlglot_ir = compile_node(root.child, uid_gen).select(selected_cols)
+
+    if root.order_by is not None:
+        ordering_cols = tuple(
+            sge.Ordered(
+                this=scalar_compiler.scalar_op_compiler.compile_expression(
+                    ordering.scalar_expression
+                ),
+                desc=ordering.direction.is_ascending is False,
+                nulls_first=ordering.na_last is False,
+            )
+            for ordering in root.order_by.all_ordering_columns
+        )
+        sqlglot_ir = sqlglot_ir.order_by(ordering_cols)
+
+    if root.limit is not None:
+        sqlglot_ir = sqlglot_ir.limit(root.limit)
+
+    return sqlglot_ir.sql
+
+
+@functools.lru_cache(maxsize=5000)
+def compile_node(
+    node: nodes.BigFrameNode, uid_gen: guid.SequentialUIDGenerator
+) -> ir.SQLGlotIR:
+    """Compiles the given BigFrameNode from the bottom up into SQLGlotIR."""
+    bf_to_sqlglot: dict[nodes.BigFrameNode, ir.SQLGlotIR] = {}
+    child_results: tuple[ir.SQLGlotIR, ...] = ()
+    for current_node in list(node.iter_nodes_topo()):
+        if current_node.child_nodes == ():
+            # For a leaf node, generate a dummy child to pass the UID generator.
+            child_results = tuple([ir.SQLGlotIR(uid_gen=uid_gen)])
+        else:
+            # Child nodes should have been compiled in the reverse topological order.
+            child_results = tuple(
+                bf_to_sqlglot[child] for child in current_node.child_nodes
+            )
+        result = _compile_node(current_node, *child_results)
+        bf_to_sqlglot[current_node] = result
+
+    return bf_to_sqlglot[node]
-
-    @functools.lru_cache(maxsize=5000)
-    def compile_node(self, node: nodes.BigFrameNode) -> ir.SQLGlotIR:
-        """Compiles node into CompileArrayValue. 
Caches result.""" - return node.reduce_up( - lambda node, children: self._compile_node(node, *children) - ) +@functools.singledispatch +def _compile_node( + node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR +) -> ir.SQLGlotIR: + """Defines transformation but isn't cached, always use compile_node instead""" + raise ValueError(f"Can't compile unrecognized node: {node}") - @functools.singledispatchmethod - def _compile_node( - self, node: nodes.BigFrameNode, *compiled_children: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - """Defines transformation but isn't cached, always use compile_node instead""" - raise ValueError(f"Can't compile unrecognized node: {node}") - - @_compile_node.register - def compile_readlocal(self, node: nodes.ReadLocalNode, *args) -> ir.SQLGlotIR: - pa_table = node.local_data_source.data - pa_table = pa_table.select([item.source_id for item in node.scan_list.items]) - pa_table = pa_table.rename_columns( - [item.id.sql for item in node.scan_list.items] - ) - offsets = node.offsets_col.sql if node.offsets_col else None - if offsets: - pa_table = pyarrow_utils.append_offsets(pa_table, offsets) - - return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=self.uid_gen) - - @_compile_node.register - def compile_readtable(self, node: nodes.ReadTableNode, *args): - table = node.source.table - return ir.SQLGlotIR.from_table( - table.project_id, - table.dataset_id, - table.table_id, - col_names=[col.source_id for col in node.scan_list.items], - alias_names=[col.id.sql for col in node.scan_list.items], - uid_gen=self.uid_gen, - ) +@_compile_node.register +def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: + pa_table = node.local_data_source.data + pa_table = pa_table.select([item.source_id for item in node.scan_list.items]) + pa_table = pa_table.rename_columns([item.id.sql for item in node.scan_list.items]) - @_compile_node.register - def compile_selection( - self, node: nodes.SelectionNode, child: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple( - (id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr)) - for expr, id in node.input_output_pairs - ) - return child.select(selected_cols) - - @_compile_node.register - def compile_projection( - self, node: nodes.ProjectionNode, child: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - projected_cols: tuple[tuple[str, sge.Expression], ...] 
= tuple( - (id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr)) - for expr, id in node.assignments - ) - return child.project(projected_cols) - - @_compile_node.register - def compile_filter( - self, node: nodes.FilterNode, child: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - condition = scalar_compiler.scalar_op_compiler.compile_expression( - node.predicate - ) - return child.filter(tuple([condition])) + offsets = node.offsets_col.sql if node.offsets_col else None + if offsets: + pa_table = pyarrow_utils.append_offsets(pa_table, offsets) - @_compile_node.register - def compile_join( - self, node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - conditions = tuple( - ( - typed_expr.TypedExpr( - scalar_compiler.scalar_op_compiler.compile_expression(left), - left.output_type, - ), - typed_expr.TypedExpr( - scalar_compiler.scalar_op_compiler.compile_expression(right), - right.output_type, - ), - ) - for left, right in node.conditions - ) + return ir.SQLGlotIR.from_pyarrow(pa_table, node.schema, uid_gen=child.uid_gen) - return left.join( - right, - join_type=node.type, - conditions=conditions, - joins_nulls=node.joins_nulls, - ) - @_compile_node.register - def compile_isin_join( - self, node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - conditions = ( - typed_expr.TypedExpr( - scalar_compiler.scalar_op_compiler.compile_expression(node.left_col), - node.left_col.output_type, - ), - typed_expr.TypedExpr( - scalar_compiler.scalar_op_compiler.compile_expression(node.right_col), - node.right_col.output_type, - ), - ) +@_compile_node.register +def compile_readtable(node: nodes.ReadTableNode, child: ir.SQLGlotIR): + table = node.source.table + return ir.SQLGlotIR.from_table( + table.project_id, + table.dataset_id, + table.table_id, + col_names=[col.source_id for col in node.scan_list.items], + alias_names=[col.id.sql for col in node.scan_list.items], + uid_gen=child.uid_gen, + ) - return left.isin_join( - right, - indicator_col=node.indicator_col.sql, - conditions=conditions, - joins_nulls=node.joins_nulls, - ) - @_compile_node.register - def compile_concat( - self, node: nodes.ConcatNode, *children: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - output_ids = [id.sql for id in node.output_ids] - return ir.SQLGlotIR.from_union( - [child.expr for child in children], - output_ids=output_ids, - uid_gen=self.uid_gen, - ) +@_compile_node.register +def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: + selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple( + (id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr)) + for expr, id in node.input_output_pairs + ) + return child.select(selected_cols) - @_compile_node.register - def compile_explode( - self, node: nodes.ExplodeNode, child: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - offsets_col = node.offsets_col.sql if (node.offsets_col is not None) else None - columns = tuple(ref.id.sql for ref in node.column_ids) - return child.explode(columns, offsets_col) - - @_compile_node.register - def compile_random_sample( - self, node: nodes.RandomSampleNode, child: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - return child.sample(node.fraction) - - @_compile_node.register - def compile_aggregate( - self, node: nodes.AggregateNode, child: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - # The BigQuery ordered aggregation cannot support for NULL FIRST/LAST, - # so we need to add extra expressions to enforce the null ordering. 
- ordering_cols = windows.get_window_order_by( - node.order_by, override_null_order=True - ) - aggregations: tuple[tuple[str, sge.Expression], ...] = tuple( - ( - id.sql, - aggregate_compiler.compile_aggregate( - agg, order_by=ordering_cols if ordering_cols else () - ), - ) - for agg, id in node.aggregations - ) - by_cols: tuple[sge.Expression, ...] = tuple( - scalar_compiler.scalar_op_compiler.compile_expression(by_col) - for by_col in node.by_column_ids - ) - dropna_cols = [] - if node.dropna: - for key, by_col in zip(node.by_column_ids, by_cols): - if node.child.field_by_id[key.id].nullable: - dropna_cols.append(by_col) +@_compile_node.register +def compile_projection(node: nodes.ProjectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: + projected_cols: tuple[tuple[str, sge.Expression], ...] = tuple( + (id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr)) + for expr, id in node.assignments + ) + return child.project(projected_cols) - return child.aggregate(aggregations, by_cols, tuple(dropna_cols)) - @_compile_node.register - def compile_window( - self, node: nodes.WindowOpNode, child: ir.SQLGlotIR - ) -> ir.SQLGlotIR: - window_spec = node.window_spec - if node.expression.op.order_independent and window_spec.is_unbounded: - # notably percentile_cont does not support ordering clause - window_spec = window_spec.without_order() +@_compile_node.register +def compile_filter(node: nodes.FilterNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR: + condition = scalar_compiler.scalar_op_compiler.compile_expression(node.predicate) + return child.filter(tuple([condition])) - window_op = aggregate_compiler.compile_analytic(node.expression, window_spec) - inputs: tuple[sge.Expression, ...] = tuple( - scalar_compiler.scalar_op_compiler.compile_expression( - expression.DerefOp(column) - ) - for column in node.expression.column_references +@_compile_node.register +def compile_join( + node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR +) -> ir.SQLGlotIR: + conditions = tuple( + ( + typed_expr.TypedExpr( + scalar_compiler.scalar_op_compiler.compile_expression(left), + left.output_type, + ), + typed_expr.TypedExpr( + scalar_compiler.scalar_op_compiler.compile_expression(right), + right.output_type, + ), ) - - clauses: list[tuple[sge.Expression, sge.Expression]] = [] - if node.expression.op.skips_nulls and not node.never_skip_nulls: - for column in inputs: - clauses.append((sge.Is(this=column, expression=sge.Null()), sge.Null())) - - if window_spec.min_periods and len(inputs) > 0: - if node.expression.op.skips_nulls: - # Most operations do not count NULL values towards min_periods - not_null_columns = [ - sge.Not(this=sge.Is(this=column, expression=sge.Null())) - for column in inputs - ] - # All inputs must be non-null for observation to count - if not not_null_columns: - is_observation_expr: sge.Expression = sge.convert(True) - else: - is_observation_expr = not_null_columns[0] - for expr in not_null_columns[1:]: - is_observation_expr = sge.And( - this=is_observation_expr, expression=expr - ) - is_observation = ir._cast(is_observation_expr, "INT64") - observation_count = windows.apply_window_if_present( - sge.func("SUM", is_observation), window_spec - ) + for left, right in node.conditions + ) + + return left.join( + right, + join_type=node.type, + conditions=conditions, + joins_nulls=node.joins_nulls, + ) + + +@_compile_node.register +def compile_isin_join( + node: nodes.InNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR +) -> ir.SQLGlotIR: + conditions = ( + typed_expr.TypedExpr( + 
scalar_compiler.scalar_op_compiler.compile_expression(node.left_col),
+            node.left_col.output_type,
+        ),
+        typed_expr.TypedExpr(
+            scalar_compiler.scalar_op_compiler.compile_expression(node.right_col),
+            node.right_col.output_type,
+        ),
+    )
+
+    return left.isin_join(
+        right,
+        indicator_col=node.indicator_col.sql,
+        conditions=conditions,
+        joins_nulls=node.joins_nulls,
+    )
+
+
+@_compile_node.register
+def compile_concat(node: nodes.ConcatNode, *children: ir.SQLGlotIR) -> ir.SQLGlotIR:
+    assert len(children) >= 1
+    uid_gen = children[0].uid_gen
+
+    output_ids = [id.sql for id in node.output_ids]
+    return ir.SQLGlotIR.from_union(
+        [child.expr for child in children],
+        output_ids=output_ids,
+        uid_gen=uid_gen,
+    )
+
+
+@_compile_node.register
+def compile_explode(node: nodes.ExplodeNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
+    offsets_col = node.offsets_col.sql if (node.offsets_col is not None) else None
+    columns = tuple(ref.id.sql for ref in node.column_ids)
+    return child.explode(columns, offsets_col)
+
+
+@_compile_node.register
+def compile_random_sample(
+    node: nodes.RandomSampleNode, child: ir.SQLGlotIR
+) -> ir.SQLGlotIR:
+    return child.sample(node.fraction)
+
+
+@_compile_node.register
+def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
+    # The BigQuery ordered aggregation does not support NULL FIRST/LAST,
+    # so we need to add extra expressions to enforce the null ordering.
+    ordering_cols = windows.get_window_order_by(node.order_by, override_null_order=True)
+    aggregations: tuple[tuple[str, sge.Expression], ...] = tuple(
+        (
+            id.sql,
+            aggregate_compiler.compile_aggregate(
+                agg, order_by=ordering_cols if ordering_cols else ()
+            ),
+        )
+        for agg, id in node.aggregations
+    )
+    by_cols: tuple[sge.Expression, ...] = tuple(
+        scalar_compiler.scalar_op_compiler.compile_expression(by_col)
+        for by_col in node.by_column_ids
+    )
+
+    dropna_cols = []
+    if node.dropna:
+        for key, by_col in zip(node.by_column_ids, by_cols):
+            if node.child.field_by_id[key.id].nullable:
+                dropna_cols.append(by_col)
+
+    return child.aggregate(aggregations, by_cols, tuple(dropna_cols))
+
+
+@_compile_node.register
+def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
+    window_spec = node.window_spec
+    if node.expression.op.order_independent and window_spec.is_unbounded:
+        # notably percentile_cont does not support ordering clause
+        window_spec = window_spec.without_order()
+
+    window_op = aggregate_compiler.compile_analytic(node.expression, window_spec)
+
+    inputs: tuple[sge.Expression, ...] 
= tuple(
+        scalar_compiler.scalar_op_compiler.compile_expression(
+            expression.DerefOp(column)
+        )
+        for column in node.expression.column_references
+    )
+
+    clauses: list[tuple[sge.Expression, sge.Expression]] = []
+    if node.expression.op.skips_nulls and not node.never_skip_nulls:
+        for column in inputs:
+            clauses.append((sge.Is(this=column, expression=sge.Null()), sge.Null()))
+
+    if window_spec.min_periods and len(inputs) > 0:
+        if node.expression.op.skips_nulls:
+            # Most operations do not count NULL values towards min_periods
+            not_null_columns = [
+                sge.Not(this=sge.Is(this=column, expression=sge.Null()))
+                for column in inputs
+            ]
+            # All inputs must be non-null for observation to count
+            if not not_null_columns:
+                is_observation_expr: sge.Expression = sge.convert(True)
+            else:
+                is_observation_expr = not_null_columns[0]
+                for expr in not_null_columns[1:]:
+                    is_observation_expr = sge.And(
+                        this=is_observation_expr, expression=expr
+                    )
+            is_observation = ir._cast(is_observation_expr, "INT64")
+            observation_count = windows.apply_window_if_present(
+                sge.func("SUM", is_observation), window_spec
+            )
+        else:
+            # Operations like count treat even NULLs as valid observations.
+            # For the sake of min_periods, notnull is just used to convert
+            # null values to non-null (FALSE) values to be counted.
+            is_observation = ir._cast(
+                sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
+                "INT64",
+            )
+            observation_count = windows.apply_window_if_present(
+                sge.func("COUNT", is_observation), window_spec
+            )
+
+        clauses.append(
+            (
+                observation_count < sge.convert(window_spec.min_periods),
+                sge.Null(),
+            )
+        )
+    if clauses:
+        when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
+        window_op = sge.Case(ifs=when_expressions, default=window_op)
+
+    # TODO: check if we can directly window the expression. 
+    return child.window(
+        window_op=window_op,
+        output_column_id=node.output_name.sql,
+    )
 
 
 def _replace_unsupported_ops(node: nodes.BigFrameNode):
diff --git a/bigframes/session/direct_gbq_execution.py b/bigframes/session/direct_gbq_execution.py
index d76a1a7630..748c43e66c 100644
--- a/bigframes/session/direct_gbq_execution.py
+++ b/bigframes/session/direct_gbq_execution.py
@@ -40,9 +40,7 @@ def __init__(
 ):
         self.bqclient = bqclient
         self._compile_fn = (
-            compile.compile_sql
-            if compiler == "ibis"
-            else sqlglot.SQLGlotCompiler()._compile_sql
+            compile.compile_sql if compiler == "ibis" else sqlglot.compile_sql
         )
         self._publisher = publisher
diff --git a/bigframes/testing/compiler_session.py b/bigframes/testing/compiler_session.py
index 289b2600fd..b248f37cfc 100644
--- a/bigframes/testing/compiler_session.py
+++ b/bigframes/testing/compiler_session.py
@@ -16,7 +16,7 @@
 import typing
 
 import bigframes.core
-import bigframes.core.compile.sqlglot as sqlglot
+import bigframes.core.compile as compile
 import bigframes.session.executor
 
 
@@ -24,7 +24,7 @@ class SQLCompilerExecutor(bigframes.session.executor.Executor):
     """Executor for SQL compilation using sqlglot."""
 
-    compiler = sqlglot
+    compiler = compile.sqlglot
 
     def to_sql(
         self,
@@ -38,9 +38,9 @@ def to_sql(
 
         # Compared with BigQueryCachingExecutor, SQLCompilerExecutor skips
         # caching the subtree.
-        return self.compiler.SQLGlotCompiler().compile(
-            array_value.node, ordered=ordered
-        )
+        return self.compiler.compile_sql(
+            compile.CompileRequest(array_value.node, sort_rows=ordered)
+        ).sql
 
     def execute(
         self,
diff --git a/tests/unit/core/compile/sqlglot/test_compile_random_sample.py b/tests/unit/core/compile/sqlglot/test_compile_random_sample.py
index 6e333f0421..486d994f87 100644
--- a/tests/unit/core/compile/sqlglot/test_compile_random_sample.py
+++ b/tests/unit/core/compile/sqlglot/test_compile_random_sample.py
@@ -16,7 +16,7 @@
 
 from bigframes.core import nodes
 import bigframes.core as core
-import bigframes.core.compile.sqlglot as sqlglot
+import bigframes.core.compile as compile
 
 pytest.importorskip("pytest_snapshot")
 
@@ -31,5 +31,5 @@ def test_compile_random_sample(
     operation, this test constructs the node directly
     and then compiles it to SQL.
     """
     node = nodes.RandomSampleNode(scalar_types_array_value.node, fraction=0.1)
-    sql = sqlglot.compiler.SQLGlotCompiler().compile(node)
+    sql = compile.sqlglot.compile_sql(compile.CompileRequest(node, sort_rows=True)).sql
     snapshot.assert_match(sql, "out.sql")

From ee59f86db76544f5ce00b9cc05479e530e27113e Mon Sep 17 00:00:00 2001
From: Chelsea Lin
Date: Sat, 8 Nov 2025 00:53:14 +0000
Subject: [PATCH 2/2] remove label assertion

---
 tests/unit/session/test_io_bigquery.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/unit/session/test_io_bigquery.py b/tests/unit/session/test_io_bigquery.py
index 41f3755f13..4349c1b6ee 100644
--- a/tests/unit/session/test_io_bigquery.py
+++ b/tests/unit/session/test_io_bigquery.py
@@ -156,7 +156,6 @@ def test_create_job_configs_labels_length_limit_met():
     )
 
     assert labels is not None
-    assert len(labels) == 56
     assert "dataframe-max" in labels.values()
     assert "dataframe-head" not in labels.values()
     assert "bigframes-api" in labels.keys()