diff --git a/doc/source/data/api/expressions.rst b/doc/source/data/api/expressions.rst index b5966c1636c2..1bb4d2488759 100644 --- a/doc/source/data/api/expressions.rst +++ b/doc/source/data/api/expressions.rst @@ -17,6 +17,7 @@ Public API :nosignatures: :toctree: doc/ + star col lit udf @@ -37,4 +38,5 @@ instantiate them directly, but you may encounter them when working with expressi LiteralExpr BinaryExpr UnaryExpr - UDFExpr \ No newline at end of file + UDFExpr + StarColumnsExpr \ No newline at end of file diff --git a/python/ray/data/_expression_evaluator.py b/python/ray/data/_expression_evaluator.py deleted file mode 100644 index 506d3a523cbc..000000000000 --- a/python/ray/data/_expression_evaluator.py +++ /dev/null @@ -1,183 +0,0 @@ -from __future__ import annotations - -import operator -from typing import Any, Callable, Dict, Union - -import numpy as np -import pandas as pd -import pyarrow as pa -import pyarrow.compute as pc - -from ray.data.block import DataBatch -from ray.data.expressions import ( - AliasExpr, - BinaryExpr, - ColumnExpr, - Expr, - LiteralExpr, - Operation, - UDFExpr, - UnaryExpr, -) - - -def _pa_is_in(left: Any, right: Any) -> Any: - if not isinstance(right, (pa.Array, pa.ChunkedArray)): - right = pa.array(right.as_py() if isinstance(right, pa.Scalar) else right) - return pc.is_in(left, right) - - -_PANDAS_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = { - Operation.ADD: operator.add, - Operation.SUB: operator.sub, - Operation.MUL: operator.mul, - Operation.DIV: operator.truediv, - Operation.FLOORDIV: operator.floordiv, - Operation.GT: operator.gt, - Operation.LT: operator.lt, - Operation.GE: operator.ge, - Operation.LE: operator.le, - Operation.EQ: operator.eq, - Operation.NE: operator.ne, - Operation.AND: operator.and_, - Operation.OR: operator.or_, - Operation.NOT: operator.not_, - Operation.IS_NULL: pd.isna, - Operation.IS_NOT_NULL: pd.notna, - Operation.IN: lambda left, right: left.is_in(right), - Operation.NOT_IN: lambda left, right: ~left.is_in(right), -} - - -def _is_pa_string_type(t: pa.DataType) -> bool: - return pa.types.is_string(t) or pa.types.is_large_string(t) - - -def _is_pa_string_like(x: Union[pa.Array, pa.ChunkedArray]) -> bool: - t = x.type - if pa.types.is_dictionary(t): - t = t.value_type - return _is_pa_string_type(t) - - -def _pa_decode_dict_string_array(x: Union[pa.Array, pa.ChunkedArray]) -> Any: - """Convert Arrow dictionary-encoded string arrays to regular string arrays. - - Dictionary encoding stores strings as indices into a dictionary of unique values. - This function converts them back to regular string arrays for string operations. - - Example: - # Input: pa.array(['a', 'b']).dictionary_encode() - # -- dictionary: ["a", "b"] - # -- indices: [0, 1] - # Output: regular string array ["a", "b"] - Args: - x: The input array to convert. - Returns: - The converted string array. - """ - if pa.types.is_dictionary(x.type) and _is_pa_string_type(x.type.value_type): - return pc.cast(x, pa.string()) - return x - - -def _to_pa_string_input(x: Any) -> Any: - if isinstance(x, str): - return pa.scalar(x) - elif _is_pa_string_like(x) and isinstance(x, (pa.Array, pa.ChunkedArray)): - x = _pa_decode_dict_string_array(x) - else: - raise - return x - - -def _pa_add_or_concat(left: Any, right: Any) -> Any: - # If either side is string-like, perform string concatenation. 
- if ( - isinstance(left, str) - or isinstance(right, str) - or (isinstance(left, (pa.Array, pa.ChunkedArray)) and _is_pa_string_like(left)) - or ( - isinstance(right, (pa.Array, pa.ChunkedArray)) and _is_pa_string_like(right) - ) - ): - left_input = _to_pa_string_input(left) - right_input = _to_pa_string_input(right) - return pc.binary_join_element_wise(left_input, right_input, "") - return pc.add(left, right) - - -_ARROW_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = { - Operation.ADD: _pa_add_or_concat, - Operation.SUB: pc.subtract, - Operation.MUL: pc.multiply, - Operation.DIV: pc.divide, - Operation.FLOORDIV: lambda left, right: pc.floor(pc.divide(left, right)), - Operation.GT: pc.greater, - Operation.LT: pc.less, - Operation.GE: pc.greater_equal, - Operation.LE: pc.less_equal, - Operation.EQ: pc.equal, - Operation.NE: pc.not_equal, - Operation.AND: pc.and_kleene, - Operation.OR: pc.or_kleene, - Operation.NOT: pc.invert, - Operation.IS_NULL: pc.is_null, - Operation.IS_NOT_NULL: pc.is_valid, - Operation.IN: _pa_is_in, - Operation.NOT_IN: lambda left, right: pc.invert(_pa_is_in(left, right)), -} - - -def _eval_expr_recursive( - expr: "Expr", batch: DataBatch, ops: Dict["Operation", Callable[..., Any]] -) -> Any: - """Generic recursive expression evaluator.""" - # TODO: Separate unresolved expressions (arbitrary AST with unresolved refs) - # and resolved expressions (bound to a schema) for better error handling - - if isinstance(expr, ColumnExpr): - return batch[expr.name] - if isinstance(expr, LiteralExpr): - return expr.value - if isinstance(expr, BinaryExpr): - return ops[expr.op]( - _eval_expr_recursive(expr.left, batch, ops), - _eval_expr_recursive(expr.right, batch, ops), - ) - if isinstance(expr, UnaryExpr): - # TODO: Use Visitor pattern here and store ops in shared state. - return ops[expr.op](_eval_expr_recursive(expr.operand, batch, ops)) - - if isinstance(expr, UDFExpr): - args = [_eval_expr_recursive(arg, batch, ops) for arg in expr.args] - kwargs = { - k: _eval_expr_recursive(v, batch, ops) for k, v in expr.kwargs.items() - } - result = expr.fn(*args, **kwargs) - - # Can't perform type validation for unions if python version is < 3.10 - if not isinstance(result, (pd.Series, np.ndarray, pa.Array, pa.ChunkedArray)): - function_name = expr.fn.__name__ - raise TypeError( - f"UDF '{function_name}' returned invalid type {type(result).__name__}. " - f"Expected type (pandas.Series, numpy.ndarray, pyarrow.Array, or pyarrow.ChunkedArray)" - ) - - return result - - if isinstance(expr, AliasExpr): - # The renaming of the column is handled in the project op planner stage. 
- return _eval_expr_recursive(expr.expr, batch, ops) - - raise TypeError(f"Unsupported expression node: {type(expr).__name__}") - - -def eval_expr(expr: "Expr", batch: DataBatch) -> Any: - """Recursively evaluate *expr* against a batch of the appropriate type.""" - if isinstance(batch, pd.DataFrame): - return _eval_expr_recursive(expr, batch, _PANDAS_EXPR_OPS_MAP) - elif isinstance(batch, pa.Table): - return _eval_expr_recursive(expr, batch, _ARROW_EXPR_OPS_MAP) - else: - raise TypeError(f"Unsupported batch type: {type(batch).__name__}") diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py index d706e5c7ba63..bedf1d027c50 100644 --- a/python/ray/data/_internal/arrow_block.py +++ b/python/ray/data/_internal/arrow_block.py @@ -58,6 +58,7 @@ _MIN_PYARROW_VERSION_TO_NUMPY_ZERO_COPY_ONLY = parse_version("13.0.0") +_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub" # Set the max chunk size in bytes for Arrow to Batches conversion in @@ -221,7 +222,7 @@ def fill_column(self, name: str, value: Any) -> Block: array = pyarrow.nulls(len(self._table), type=type) array = pc.fill_null(array, value) - return self._table.append_column(name, array) + return self.upsert_column(name, array) @classmethod def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor": @@ -453,7 +454,9 @@ def filter(self, predicate_expr: "Expr") -> "pyarrow.Table": if self._table.num_rows == 0: return self._table - from ray.data._expression_evaluator import eval_expr + from ray.data._internal.planner.plan_expression.expression_evaluator import ( + eval_expr, + ) # Evaluate the expression to get a boolean mask mask = eval_expr(predicate_expr, self._table) diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py index 65d52fc80083..b37be53c0338 100644 --- a/python/ray/data/_internal/datasource/parquet_datasource.py +++ b/python/ray/data/_internal/datasource/parquet_datasource.py @@ -23,7 +23,10 @@ import ray from ray._private.arrow_utils import get_pyarrow_version -from ray.data._internal.arrow_block import ArrowBlockAccessor +from ray.data._internal.arrow_block import ( + _BATCH_SIZE_PRESERVING_STUB_COL_NAME, + ArrowBlockAccessor, +) from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import ( @@ -104,9 +107,6 @@ PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 1024 -_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub" - - class _ParquetFragment: """This wrapper class is created to avoid utilizing `ParquetFileFragment` original serialization protocol that actually does network RPCs during serialization diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 63cf35410237..46b0fdc65848 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -7,7 +7,7 @@ from ray.data._internal.logical.interfaces import LogicalOperator from ray.data._internal.logical.operators.one_to_one_operator import AbstractOneToOne from ray.data.block import UserDefinedFunction -from ray.data.expressions import Expr +from ray.data.expressions import Expr, StarColumnsExpr from ray.data.preprocessor import Preprocessor logger = logging.getLogger(__name__) @@ -268,16 +268,12 @@ def can_modify_num_rows(self) -> bool: class Project(AbstractMap): - """Logical operator for select_columns.""" + """Logical 
operator for select_columns, with_column, rename_columns.""" def __init__( self, input_op: LogicalOperator, - cols: Optional[List[str]] = None, - cols_rename: Optional[Dict[str, str]] = None, - exprs: Optional[ - Dict[str, "Expr"] - ] = None, # TODO Remove cols and cols_rename and replace them with corresponding exprs + exprs: list["Expr"], compute: Optional[ComputeStrategy] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): @@ -288,30 +284,23 @@ def __init__( compute=compute, ) self._batch_size = None - self._cols = cols - self._cols_rename = cols_rename self._exprs = exprs self._batch_format = "pyarrow" self._zero_copy_batch = True - if exprs is not None: - # Validate that all values are expressions - for name, expr in exprs.items(): - if not isinstance(expr, Expr): - raise TypeError( - f"Expected Expr for column '{name}', got {type(expr)}" - ) + for expr in self._exprs: + if expr.name is None and not isinstance(expr, StarColumnsExpr): + raise TypeError( + "All Project expressions must be named (use .alias(name) or col(name)), " + "or be a star() expression." + ) - @property - def cols(self) -> Optional[List[str]]: - return self._cols - - @property - def cols_rename(self) -> Optional[Dict[str, str]]: - return self._cols_rename + def has_all_columns_expr(self) -> bool: + """Check if this projection contains a star() expression.""" + return any(isinstance(expr, StarColumnsExpr) for expr in self._exprs) @property - def exprs(self) -> Optional[Dict[str, "Expr"]]: + def exprs(self) -> List["Expr"]: return self._exprs def can_modify_num_rows(self) -> bool: diff --git a/python/ray/data/_internal/logical/rules/projection_pushdown.py b/python/ray/data/_internal/logical/rules/projection_pushdown.py index df6a4ab6905d..70f839aab64c 100644 --- a/python/ray/data/_internal/logical/rules/projection_pushdown.py +++ b/python/ray/data/_internal/logical/rules/projection_pushdown.py @@ -1,6 +1,4 @@ -import logging -from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import Any, List, Optional, Set, Tuple from ray.data._internal.logical.interfaces import ( LogicalOperator, @@ -9,304 +7,566 @@ Rule, ) from ray.data._internal.logical.operators.map_operator import Project -from ray.data._internal.logical.operators.read_operator import Read -from ray.data.expressions import Expr +from ray.data.expressions import ( + AliasExpr, + BinaryExpr, + ColumnExpr, + Expr, + LiteralExpr, + StarColumnsExpr, + UDFExpr, + UnaryExpr, + _ExprVisitor, +) -logger = logging.getLogger(__name__) +class _ColumnReferenceCollector(_ExprVisitor): + """Visitor that collects all column references from expression trees. -@dataclass(frozen=True) -class _ProjectSpec: - cols: Optional[List[str]] - cols_remap: Optional[Dict[str, str]] - exprs: Optional[Dict[str, Expr]] + This visitor traverses expression trees and accumulates column names + referenced in ColumnExpr nodes. + """ + def __init__(self): + """Initialize with an empty set of referenced columns.""" + self.referenced_columns: Set[str] = set() -class ProjectionPushdown(Rule): - """Optimization rule that pushes down projections across the graph. + def visit(self, expr: Expr) -> Any: + """Visit an expression node and dispatch to the appropriate method. - This rule looks for `Project` operators that are immediately - preceded by a `Read` operator and sets the - projected columns on the `Read` operator. + Extends the base visitor to handle StarColumnsExpr which is not + part of the base _ExprVisitor interface. 
- If there are redundant Project operators, it removes the `Project` operator from - the graph. - """ + Args: + expr: The expression to visit. - def apply(self, plan: LogicalPlan) -> LogicalPlan: - dag = plan.dag - new_dag = dag._apply_transform(self._pushdown_project) + Returns: + None (only collects columns as a side effect). + """ + if isinstance(expr, StarColumnsExpr): + # StarColumnsExpr doesn't reference specific columns + return None + return super().visit(expr) - return LogicalPlan(new_dag, plan.context) if dag is not new_dag else plan + def visit_column(self, expr: ColumnExpr) -> Any: + """Visit a column expression and collect its name. - @classmethod - def _pushdown_project(cls, op: LogicalOperator) -> LogicalOperator: - if isinstance(op, Project): - # Push-down projections into read op - if cls._supports_projection_pushdown(op): - project_op: Project = op - target_op: LogicalOperatorSupportsProjectionPushdown = ( - op.input_dependency - ) + Args: + expr: The column expression. - return cls._try_combine(target_op, project_op) + Returns: + None (only collects columns as a side effect). + """ + self.referenced_columns.add(expr.name) - # Otherwise, fuse projections into a single op - elif isinstance(op.input_dependency, Project): - outer_op: Project = op - inner_op: Project = op.input_dependency + def visit_literal(self, expr: LiteralExpr) -> Any: + """Visit a literal expression (no columns to collect). - return cls._fuse(inner_op, outer_op) + Args: + expr: The literal expression. - return op + Returns: + None. + """ + # Literals don't reference any columns + pass - @classmethod - def _supports_projection_pushdown(cls, op: Project) -> bool: - # NOTE: Currently only projecting into Parquet is supported - input_op = op.input_dependency - return ( - isinstance(input_op, LogicalOperatorSupportsProjectionPushdown) - and input_op.supports_projection_pushdown() - ) + def visit_binary(self, expr: BinaryExpr) -> Any: + """Visit a binary expression and collect from both operands. - @staticmethod - def _fuse(inner_op: Project, outer_op: Project) -> Project: - # Combine expressions from both operators - combined_exprs = _combine_expressions(inner_op.exprs, outer_op.exprs) - - # Only combine projection specs if there are no expressions - # When expressions are present, they take precedence - if combined_exprs: - # When expressions are present, preserve column operations from outer operation - # The logical order is: expressions first, then column operations - outer_cols = outer_op.cols - outer_cols_rename = outer_op.cols_rename - - # If outer operation has no column operations, fall back to inner operation - if outer_cols is None and outer_cols_rename is None: - outer_cols = inner_op.cols - outer_cols_rename = inner_op.cols_rename - - return Project( - inner_op.input_dependency, - cols=outer_cols, - cols_rename=outer_cols_rename, - exprs=combined_exprs, - # Give precedence to outer operator's ray_remote_args - ray_remote_args={ - **inner_op._ray_remote_args, - **outer_op._ray_remote_args, - }, - ) - else: - # Fall back to original behavior for column-only projections - inner_op_spec = _get_projection_spec(inner_op) - outer_op_spec = _get_projection_spec(outer_op) + Args: + expr: The binary expression. - new_spec = _combine_projection_specs( - prev_spec=inner_op_spec, new_spec=outer_op_spec - ) + Returns: + None (only collects columns as a side effect). 
+ """ + self.visit(expr.left) + self.visit(expr.right) - return Project( - inner_op.input_dependency, - cols=new_spec.cols, - cols_rename=new_spec.cols_remap, - exprs=None, - ray_remote_args={ - **inner_op._ray_remote_args, - **outer_op._ray_remote_args, - }, - ) + def visit_unary(self, expr: UnaryExpr) -> Any: + """Visit a unary expression and collect from its operand. - @staticmethod - def _try_combine( - target_op: LogicalOperatorSupportsProjectionPushdown, - project_op: Project, - ) -> LogicalOperator: - # For now, don't push down expressions into `Read` operators - # Only handle traditional column projections - if project_op.exprs: - # Cannot push expressions into `Read`, return unchanged - return project_op - - target_op_spec = _get_projection_spec(target_op) - project_op_spec = _get_projection_spec(project_op) - - new_spec = _combine_projection_specs( - prev_spec=target_op_spec, new_spec=project_op_spec - ) + Args: + expr: The unary expression. - logger.debug( - f"Pushing projection down into read operation " - f"projection columns = {new_spec.cols} (before: {target_op_spec.cols}), " - f"remap = {new_spec.cols_remap} (before: {target_op_spec.cols_remap})" - ) + Returns: + None (only collects columns as a side effect). + """ + self.visit(expr.operand) - return target_op.apply_projection(new_spec.cols) + def visit_udf(self, expr: UDFExpr) -> Any: + """Visit a UDF expression and collect from all arguments. + Args: + expr: The UDF expression. -def _combine_expressions( - inner_exprs: Optional[Dict[str, Expr]], outer_exprs: Optional[Dict[str, Expr]] -) -> Optional[Dict[str, Expr]]: - """Combine expressions from two Project operators. + Returns: + None (only collects columns as a side effect). + """ + for arg in expr.args: + self.visit(arg) + for value in expr.kwargs.values(): + self.visit(value) - Args: - inner_exprs: Expressions from the inner (upstream) Project operator - outer_exprs: Expressions from the outer (downstream) Project operator + def visit_alias(self, expr: AliasExpr) -> Any: + """Visit an alias expression and collect from its inner expression. - Returns: - Combined dictionary of expressions, or None if no expressions + Args: + expr: The alias expression. + + Returns: + None (only collects columns as a side effect). + """ + self.visit(expr.expr) + + def visit_download(self, expr: "Expr") -> Any: + """Visit a download expression (no columns to collect). + + Args: + expr: The download expression. + + Returns: + None. + """ + # DownloadExpr doesn't reference any columns in the projection pushdown context + pass + + +def _collect_referenced_columns(exprs: List[Expr]) -> Optional[Set[str]]: + """ + Extract all column names referenced by the given expressions. + + Recursively traverses expression trees to find all ColumnExpr nodes + and collects their names. + + Example: For expression "col1 + col2", returns {"col1", "col2"} """ - if not inner_exprs and not outer_exprs: + # If any expression is star(), we need all columns + if any(isinstance(expr, StarColumnsExpr) for expr in exprs): + # TODO (goutam): Instead of using None to refer to All columns, resolve the AST against the schema. 
+ # https://github.com/ray-project/ray/issues/57720 return None - combined = {} + collector = _ColumnReferenceCollector() + for expr in exprs or []: + collector.visit(expr) + return collector.referenced_columns + - # Add expressions from inner operator - if inner_exprs: - combined.update(inner_exprs) +class _ColumnRewriter(_ExprVisitor): + """Visitor that rewrites column references in expression trees. - # Add expressions from outer operator - if outer_exprs: - combined.update(outer_exprs) + This visitor traverses expression trees and substitutes column references + according to a provided substitution map, preserving the structure of the tree. + """ - return combined if combined else None + def __init__(self, column_substitutions: dict[str, Expr]): + """Initialize with a column substitution map. + + Args: + column_substitutions: Mapping from column names to replacement expressions. + """ + self.column_substitutions = column_substitutions + + def visit(self, expr: Expr) -> Expr: + """Visit an expression node and return the rewritten expression. + + Extends the base visitor to handle StarColumnsExpr which is not + part of the base _ExprVisitor interface. + + Args: + expr: The expression to visit. + + Returns: + The rewritten expression. + """ + if isinstance(expr, StarColumnsExpr): + # StarColumnsExpr is not rewritten + return expr + return super().visit(expr) + + def visit_column(self, expr: ColumnExpr) -> Expr: + """Visit a column expression and potentially substitute it. + + Args: + expr: The column expression. + + Returns: + The substituted expression or the original if no substitution exists. + """ + substitution = self.column_substitutions.get(expr.name) + if substitution is not None: + # Unwrap aliases to get the actual expression + return ( + substitution.expr + if isinstance(substitution, AliasExpr) + else substitution + ) + return expr + def visit_literal(self, expr: LiteralExpr) -> Expr: + """Visit a literal expression (no rewriting needed). -def _get_projection_spec(op: Union[Project, Read]) -> _ProjectSpec: - assert op is not None + Args: + expr: The literal expression. - if isinstance(op, Project): - return _ProjectSpec( - cols=op.cols, - cols_remap=op.cols_rename, - exprs=op.exprs, - ) - elif isinstance(op, Read): - assert op.supports_projection_pushdown() + Returns: + The original literal expression. + """ + return expr - return _ProjectSpec( - cols=op.get_current_projection(), - cols_remap=None, - exprs=None, + def visit_binary(self, expr: BinaryExpr) -> Expr: + """Visit a binary expression and rewrite its operands. + + Args: + expr: The binary expression. + + Returns: + A new binary expression with rewritten operands. + """ + return type(expr)( + expr.op, + self.visit(expr.left), + self.visit(expr.right), ) - else: - raise ValueError( - f"Operation doesn't have projection spec (supported Project, " - f"Read, got: {op.__class__})" + + def visit_unary(self, expr: UnaryExpr) -> Expr: + """Visit a unary expression and rewrite its operand. + + Args: + expr: The unary expression. + + Returns: + A new unary expression with rewritten operand. + """ + return type(expr)(expr.op, self.visit(expr.operand)) + + def visit_udf(self, expr: UDFExpr) -> Expr: + """Visit a UDF expression and rewrite its arguments. + + Args: + expr: The UDF expression. + + Returns: + A new UDF expression with rewritten arguments. 
+ """ + new_args = [self.visit(arg) for arg in expr.args] + new_kwargs = {key: self.visit(value) for key, value in expr.kwargs.items()} + return type(expr)( + fn=expr.fn, data_type=expr.data_type, args=new_args, kwargs=new_kwargs ) + def visit_alias(self, expr: AliasExpr) -> Expr: + """Visit an alias expression and rewrite its inner expression. -def _combine_projection_specs( - prev_spec: _ProjectSpec, new_spec: _ProjectSpec -) -> _ProjectSpec: - combined_cols_remap = _combine_columns_remap( - prev_spec.cols_remap, - new_spec.cols_remap, - ) + Args: + expr: The alias expression. - # Validate resulting remapping against existing projection (if any) - _validate(combined_cols_remap, prev_spec.cols) + Returns: + A new alias expression with rewritten inner expression and preserved name. + """ + return self.visit(expr.expr).alias(expr.name) - new_projection_cols: Optional[List[str]] + def visit_download(self, expr: "Expr") -> Expr: + """Visit a download expression (no rewriting needed). - if prev_spec.cols is None and new_spec.cols is None: - # If both projections are unset, resulting is unset - new_projection_cols = None - elif prev_spec.cols is not None and new_spec.cols is None: - # If previous projection is set, but the new unset -- fallback to - # existing projection - new_projection_cols = prev_spec.cols - else: - # If new is set (and previous is either set or not) - # - Reconcile new projection - # - Project combined column remapping - assert new_spec.cols is not None - - new_projection_cols = new_spec.cols - - # Remap new projected columns into the schema before remapping (from the - # previous spec) - if prev_spec.cols_remap and new_projection_cols: - # Inverse remapping - inv_cols_remap = {v: k for k, v in prev_spec.cols_remap.items()} - new_projection_cols = [ - inv_cols_remap.get(col, col) for col in new_projection_cols - ] - - prev_cols_set = set(prev_spec.cols or []) - new_cols_set = set(new_projection_cols or []) - - # Validate new projection is a proper subset of the previous one - if prev_cols_set and new_cols_set and not new_cols_set.issubset(prev_cols_set): - raise ValueError( - f"Selected columns '{new_cols_set}' needs to be a subset of " - f"'{prev_cols_set}'" - ) + Args: + expr: The download expression. - # Project remaps to only map relevant columns - if new_projection_cols is not None and combined_cols_remap is not None: - projected_cols_remap = { - k: v for k, v in combined_cols_remap.items() if k in new_projection_cols - } - else: - projected_cols_remap = combined_cols_remap + Returns: + The original download expression. + """ + return expr - # Combine expressions from both specs - combined_exprs = _combine_expressions(prev_spec.exprs, new_spec.exprs) - return _ProjectSpec( - cols=new_projection_cols, cols_remap=projected_cols_remap, exprs=combined_exprs - ) +def _rewrite_column_references( + expr: Expr, column_substitutions: dict[str, Expr] +) -> Expr: + """ + Rewrite an expression by substituting column references. + Recursively replaces ColumnExpr nodes according to the substitution map. + Preserves the structure of the expression tree. -def _combine_columns_remap( - prev_remap: Optional[Dict[str, str]], new_remap: Optional[Dict[str, str]] -) -> Optional[Dict[str, str]]: + Example: If column_substitutions = {"col1": col2_expr}, then + "col1 + 10" becomes "col2 + 10" - if not new_remap and not prev_remap: - return None + Args: + expr: The expression to rewrite. + column_substitutions: Mapping from column names to replacement expressions. 
+ + Returns: + The rewritten expression. + """ + rewriter = _ColumnRewriter(column_substitutions) + return rewriter.visit(expr) - new_remap = new_remap or {} - base_remap = prev_remap or {} - filtered_new_remap = dict(new_remap) - # Apply new remapping to the base remap - updated_base_remap = { - # NOTE: We're removing corresponding chained mapping from the remap - k: filtered_new_remap.pop(v, v) - for k, v in base_remap.items() +def _try_wrap_expression_with_alias(expr: Expr, target_name: str) -> Expr: + """ + Ensure an expression outputs with the specified name. + + If the expression already has the target name, returns it unchanged. + Otherwise, wraps it with an alias to produce the target name. + """ + if expr.name == target_name: + return expr + if isinstance(expr, AliasExpr): + # Re-alias the unwrapped expression + return expr.expr.alias(target_name) + return expr.alias(target_name) + + +def _extract_simple_rename(expr: Expr) -> Optional[Tuple[str, str]]: + """ + Check if an expression is a simple column rename. + + Returns (source_name, dest_name) if the expression is of form: + col("source").alias("dest") + where source != dest. + + Returns None for other expression types. + """ + if isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr): + dest_name = expr.name + source_name = expr.expr.name + if source_name != dest_name: + return source_name, dest_name + return None + + +def _try_fuse_consecutive_projects( + upstream_project: Project, downstream_project: Project +) -> Optional[Project]: + """ + Attempt to merge two consecutive Project operations into one. + + Updated to handle StarColumnsExpr instead of preserve_existing flag. + """ + from ray.data.expressions import StarColumnsExpr + + # Check if projects have star() + upstream_has_all = upstream_project.has_all_columns_expr() + downstream_has_all = downstream_project.has_all_columns_expr() + + # Step 1: Analyze what the upstream project produces + upstream_output_columns = { + expr.name + for expr in upstream_project.exprs + if not isinstance(expr, StarColumnsExpr) + } + upstream_column_definitions = { + expr.name: _try_wrap_expression_with_alias(expr, expr.name) + for expr in upstream_project.exprs + if not isinstance(expr, StarColumnsExpr) } - resolved_remap = dict(updated_base_remap) - resolved_remap.update(filtered_new_remap) + # Step 2: Identify columns removed by upstream renames + # When "col1" is renamed to "col2" and "col1" is not in the output, + # then "col1" is effectively removed and cannot be accessed downstream. + columns_removed_by_renames: Set[str] = set() + for expr in upstream_project.exprs: + if isinstance(expr, StarColumnsExpr): + continue + rename_pair = _extract_simple_rename(expr) + if rename_pair is not None: + source_name, _ = rename_pair + if source_name not in upstream_output_columns: + columns_removed_by_renames.add(source_name) + + # Step 3: Validate and rewrite downstream expressions + rewritten_downstream_exprs: List[Expr] = [] + for expr in downstream_project.exprs: + if isinstance(expr, StarColumnsExpr): + # star() passes through in fusion + rewritten_downstream_exprs.append(expr) + continue + + # Find which columns this expression references + referenced_columns = _collect_referenced_columns([expr]) + + # Separate columns: produced by upstream vs. 
pass-through from original input + columns_from_original_input = referenced_columns - ( + referenced_columns & upstream_output_columns + ) - return resolved_remap + # Validate that downstream can access the columns it needs + if not upstream_has_all: + # Upstream is a selection: only upstream outputs are visible + if columns_from_original_input: + # Fusion not possible: downstream needs columns not in upstream output + return None + else: + # Upstream preserves existing: pass-through columns are allowed, + # except those explicitly removed by renames + if any( + col in columns_removed_by_renames for col in columns_from_original_input + ): + # Fusion not possible: downstream needs a removed column + return None + + # Rewrite the expression to use upstream's definitions + rewritten_expr = _rewrite_column_references(expr, upstream_column_definitions) + rewritten_downstream_exprs.append( + _try_wrap_expression_with_alias(rewritten_expr, expr.name) + ) + # Step 4: Build the fused project based on downstream's behavior + if not downstream_has_all: + # Downstream is a selection: output only what downstream specifies + return Project( + upstream_project.input_dependency, + exprs=rewritten_downstream_exprs, + ray_remote_args=downstream_project._ray_remote_args, + ) -def _validate(remap: Optional[Dict[str, str]], projection_cols: Optional[List[str]]): - if not remap: - return + # Step 5: Downstream has star(): merge both projections + downstream_output_columns = { + expr.name + for expr in downstream_project.exprs + if not isinstance(expr, StarColumnsExpr) + } - # Verify that the remapping is a proper bijection (ie no - # columns are renamed into the same new name) - prev_names_map = {} - for prev_name, new_name in remap.items(): - if new_name in prev_names_map: - raise ValueError( - f"Identified projections with conflict in renaming: '{new_name}' " - f"is mapped from multiple sources: '{prev_names_map[new_name]}' " - f"and '{prev_name}'." 
+ # Start with upstream's column definitions and ordering + column_definitions = { + expr.name: _try_wrap_expression_with_alias(expr, expr.name) + for expr in upstream_project.exprs + if not isinstance(expr, StarColumnsExpr) + } + column_order = [ + expr.name + for expr in upstream_project.exprs + if not isinstance(expr, StarColumnsExpr) + ] + + # Apply downstream's transformations + # + # Example scenario: + # Upstream outputs: {a: col("x"), b: col("y") + 1, c: col("z")} + # Downstream exprs: [col("a").alias("d"), col("b") + 2] + # + # After this loop: + # - "a" is renamed to "d" (source "a" removed if not in downstream output) + # - "b" is overwritten with a new definition: (col("y") + 1) + 2 + # - "c" passes through unchanged from upstream + for expr in downstream_project.exprs: + if isinstance(expr, StarColumnsExpr): + continue + + column_name = expr.name + rename_pair = _extract_simple_rename(expr) + + if rename_pair is not None: + # Handle rename: source -> dest + # Example: col("a").alias("d") means rename "a" to "d" + source_name, dest_name = rename_pair + resolved_expr = upstream_column_definitions.get(source_name, expr) + column_definitions[dest_name] = _try_wrap_expression_with_alias( + resolved_expr, dest_name ) - prev_names_map[new_name] = prev_name + # If source is not kept by downstream, remove it from the output + # Example: After renaming "a" to "d", if "a" is not in downstream outputs, + # we remove "a" from the final output (only "d" remains) + if ( + source_name not in downstream_output_columns + and source_name in column_definitions + ): + del column_definitions[source_name] + + # Update column ordering: replace source with dest or append dest + # Example: If order was ["a", "b", "c"] and "a" renamed to "d", + # order becomes ["d", "b", "c"] (maintaining position) + if ( + source_name in column_order + and source_name not in downstream_output_columns + ): + idx = column_order.index(source_name) + column_order[idx] = dest_name + elif dest_name not in column_order: + column_order.append(dest_name) + continue + + # Handle non-rename: add or overwrite column definition + # Example: If downstream has col("b") + 2, and upstream had b: col("y") + 1, + # we rewrite to: (col("y") + 1) + 2, collapsing both transformations + rewritten_expr = _rewrite_column_references(expr, upstream_column_definitions) + column_definitions[column_name] = _try_wrap_expression_with_alias( + rewritten_expr, column_name + ) + if column_name not in column_order: + column_order.append(column_name) + + # Build final fused project + # Only include star() if upstream also had it (preserving selection semantics) + if upstream_has_all: + # Upstream preserves existing: fused result should too + fused_exprs = [StarColumnsExpr()] + [ + column_definitions[name] for name in column_order + ] + else: + # Upstream is a selection: fused result should only have explicit columns + fused_exprs = [column_definitions[name] for name in column_order] + + return Project( + upstream_project.input_dependency, + exprs=fused_exprs, + ray_remote_args=downstream_project._ray_remote_args, + ) + - # Verify that remapping only references columns available in the projection - if projection_cols is not None: - invalid_cols = [key for key in remap.keys() if key not in projection_cols] +class ProjectionPushdown(Rule): + """ + Optimization rule that pushes projections (column selections) down the query plan. + + This rule performs two optimizations: + 1. Fuses consecutive Project operations to eliminate redundant projections + 2. 
Pushes projections into data sources (e.g., Read operations) to enable + column pruning at the storage layer + """ - if invalid_cols: - raise ValueError( - f"Identified projections with invalid rename " - f"columns: {', '.join(invalid_cols)}" + def apply(self, plan: LogicalPlan) -> LogicalPlan: + """Apply projection pushdown optimization to the entire plan.""" + dag = plan.dag + new_dag = dag._apply_transform(self._optimize_project) + return LogicalPlan(new_dag, plan.context) if dag is not new_dag else plan + + @classmethod + def _optimize_project(cls, op: LogicalOperator) -> LogicalOperator: + """ + Optimize a single Project operator. + + Steps: + 1. Iteratively fuse with upstream Project operations + 2. Push the resulting projection into the data source if possible + """ + if not isinstance(op, Project): + return op + + # Step 1: Iteratively fuse with upstream Project operations + current_project: Project = op + while isinstance(current_project.input_dependency, Project): + upstream_project: Project = current_project.input_dependency # type: ignore[assignment] + fused_project = _try_fuse_consecutive_projects( + upstream_project, current_project ) + if fused_project is None: + # Fusion not possible, stop iterating + break + current_project = fused_project + + # Step 2: Push projection into the data source if supported + # For example, when reading Parquet files, we can pass column names + # to only read the required columns. + input_op = current_project.input_dependency + if ( + not current_project.has_all_columns_expr() # Must be a selection, not additive + and isinstance(input_op, LogicalOperatorSupportsProjectionPushdown) + and input_op.supports_projection_pushdown() + ): + required_columns = _collect_referenced_columns(list(current_project.exprs)) + if required_columns is not None: # None means star() was present + optimized_source = input_op.apply_projection(sorted(required_columns)) + return optimized_source + + return current_project diff --git a/python/ray/data/_internal/pandas_block.py b/python/ray/data/_internal/pandas_block.py index 92ad48ea50b1..9d509ddc44c9 100644 --- a/python/ray/data/_internal/pandas_block.py +++ b/python/ray/data/_internal/pandas_block.py @@ -626,8 +626,9 @@ def filter(self, predicate_expr: "Expr") -> "pandas.DataFrame": if self._table.empty: return self._table - # TODO: Move _expression_evaluator to _internal - from ray.data._expression_evaluator import eval_expr + from ray.data._internal.planner.plan_expression.expression_evaluator import ( + eval_expr, + ) # Evaluate the expression to get a boolean mask mask = eval_expr(predicate_expr, self._table) diff --git a/python/ray/data/_internal/planner/exchange/shuffle_task_spec.py b/python/ray/data/_internal/planner/exchange/shuffle_task_spec.py index 48c4c9f84a54..968a7b7bb5d3 100644 --- a/python/ray/data/_internal/planner/exchange/shuffle_task_spec.py +++ b/python/ray/data/_internal/planner/exchange/shuffle_task_spec.py @@ -93,7 +93,12 @@ def map( # Build a list of slices to return. It's okay to put the results in a # list instead of yielding them as a generator because slicing the # ArrowBlock is zero-copy. 
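For illustration, a minimal standalone sketch of the guarded slice-size math this hunk introduces, assuming a 10-row block split into 3 output blocks (the helper name below is hypothetical, not part of the patch):

    import math

    def slice_bounds(total_num_rows: int, output_num_blocks: int):
        # Mirrors the guarded computation added in this hunk: 0 rows or 0
        # requested output blocks now yields a slice size of 0 instead of 1.
        slice_sz = (
            math.ceil(total_num_rows / output_num_blocks)
            if total_num_rows > 0 and output_num_blocks > 0
            else 0
        )
        return [(i * slice_sz, (i + 1) * slice_sz) for i in range(output_num_blocks)]

    print(slice_bounds(10, 3))  # [(0, 4), (4, 8), (8, 12)] -- slicing clips to the 10 available rows
    print(slice_bounds(0, 3))   # [(0, 0), (0, 0), (0, 0)] -- empty slices, no spurious 1-row slice size
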
- slice_sz = max(1, math.ceil(block.num_rows() / output_num_blocks)) + total_num_rows = block.num_rows() + slice_sz = ( + math.ceil(total_num_rows / output_num_blocks) + if total_num_rows > 0 and output_num_blocks > 0 + else 0 + ) slices = [] for i in range(output_num_blocks): slices.append(block.slice(i * slice_sz, (i + 1) * slice_sz)) @@ -104,8 +109,10 @@ def map( random = np.random.RandomState(seed_i) random.shuffle(slices) + # Verify row count consistency num_rows = sum(BlockAccessor.for_block(s).num_rows() for s in slices) - assert num_rows == block.num_rows(), (num_rows, block.num_rows()) + assert num_rows == total_num_rows, (num_rows, total_num_rows) + from ray.data.block import BlockMetadataWithSchema meta = block.get_metadata(exec_stats=stats.build()) diff --git a/python/ray/data/_internal/planner/plan_download_op.py b/python/ray/data/_internal/planner/plan_download_op.py index a0c2ba776c29..126c54305eb4 100644 --- a/python/ray/data/_internal/planner/plan_download_op.py +++ b/python/ray/data/_internal/planner/plan_download_op.py @@ -148,6 +148,9 @@ def uri_to_path(uri: str) -> str: def _arrow_batcher(table: pa.Table, output_batch_size: int): """Batch a PyArrow table into smaller tables of size n using zero-copy slicing.""" num_rows = table.num_rows + if output_batch_size == 0: + yield table + return for i in range(0, num_rows, output_batch_size): end_idx = min(i + output_batch_size, num_rows) # Use PyArrow's zero-copy slice operation diff --git a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py index df2a1c066c3c..a2fb57a1dda3 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py @@ -1,15 +1,140 @@ +from __future__ import annotations + import ast import logging +import operator +from typing import Any, Callable, Dict, List, Union +import numpy as np +import pandas as pd import pyarrow as pa import pyarrow.compute as pc import pyarrow.dataset as ds -from ray.data.expressions import ColumnExpr, Expr +from ray.data.block import Block, BlockAccessor, BlockColumn +from ray.data.expressions import ( + AliasExpr, + BinaryExpr, + ColumnExpr, + Expr, + LiteralExpr, + Operation, + StarColumnsExpr, + UDFExpr, + UnaryExpr, +) logger = logging.getLogger(__name__) +def _pa_is_in(left: Any, right: Any) -> Any: + if not isinstance(right, (pa.Array, pa.ChunkedArray)): + right = pa.array(right.as_py() if isinstance(right, pa.Scalar) else right) + return pc.is_in(left, right) + + +_PANDAS_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = { + Operation.ADD: operator.add, + Operation.SUB: operator.sub, + Operation.MUL: operator.mul, + Operation.DIV: operator.truediv, + Operation.FLOORDIV: operator.floordiv, + Operation.GT: operator.gt, + Operation.LT: operator.lt, + Operation.GE: operator.ge, + Operation.LE: operator.le, + Operation.EQ: operator.eq, + Operation.NE: operator.ne, + Operation.AND: operator.and_, + Operation.OR: operator.or_, + Operation.NOT: operator.not_, + Operation.IS_NULL: pd.isna, + Operation.IS_NOT_NULL: pd.notna, + Operation.IN: lambda left, right: left.is_in(right), + Operation.NOT_IN: lambda left, right: ~left.is_in(right), +} + + +def _is_pa_string_type(t: pa.DataType) -> bool: + return pa.types.is_string(t) or pa.types.is_large_string(t) + + +def _is_pa_string_like(x: Union[pa.Array, pa.ChunkedArray]) -> bool: + t = x.type + if pa.types.is_dictionary(t): + t = 
t.value_type + return _is_pa_string_type(t) + + +def _pa_decode_dict_string_array(x: Union[pa.Array, pa.ChunkedArray]) -> Any: + """Convert Arrow dictionary-encoded string arrays to regular string arrays. + + Dictionary encoding stores strings as indices into a dictionary of unique values. + This function converts them back to regular string arrays for string operations. + + Example: + # Input: pa.array(['a', 'b']).dictionary_encode() + # -- dictionary: ["a", "b"] + # -- indices: [0, 1] + # Output: regular string array ["a", "b"] + Args: + x: The input array to convert. + Returns: + The converted string array. + """ + if pa.types.is_dictionary(x.type) and _is_pa_string_type(x.type.value_type): + return pc.cast(x, pa.string()) + return x + + +def _to_pa_string_input(x: Any) -> Any: + if isinstance(x, str): + return pa.scalar(x) + elif _is_pa_string_like(x) and isinstance(x, (pa.Array, pa.ChunkedArray)): + x = _pa_decode_dict_string_array(x) + else: + raise + return x + + +def _pa_add_or_concat(left: Any, right: Any) -> Any: + # If either side is string-like, perform string concatenation. + if ( + isinstance(left, str) + or isinstance(right, str) + or (isinstance(left, (pa.Array, pa.ChunkedArray)) and _is_pa_string_like(left)) + or ( + isinstance(right, (pa.Array, pa.ChunkedArray)) and _is_pa_string_like(right) + ) + ): + left_input = _to_pa_string_input(left) + right_input = _to_pa_string_input(right) + return pc.binary_join_element_wise(left_input, right_input, "") + return pc.add(left, right) + + +_ARROW_EXPR_OPS_MAP: Dict[Operation, Callable[..., Any]] = { + Operation.ADD: _pa_add_or_concat, + Operation.SUB: pc.subtract, + Operation.MUL: pc.multiply, + Operation.DIV: pc.divide, + Operation.FLOORDIV: lambda left, right: pc.floor(pc.divide(left, right)), + Operation.GT: pc.greater, + Operation.LT: pc.less, + Operation.GE: pc.greater_equal, + Operation.LE: pc.less_equal, + Operation.EQ: pc.equal, + Operation.NE: pc.not_equal, + Operation.AND: pc.and_kleene, + Operation.OR: pc.or_kleene, + Operation.NOT: pc.invert, + Operation.IS_NULL: pc.is_null, + Operation.IS_NOT_NULL: pc.is_valid, + Operation.IN: _pa_is_in, + Operation.NOT_IN: lambda left, right: pc.invert(_pa_is_in(left, right)), +} + + # NOTE: (srinathk) There are 3 distinct stages of handling passed in exprs: # 1. Parsing it (as text) # 2. Resolving unbound names (to schema) @@ -405,3 +530,255 @@ def visit_Call(self, node: ast.Call) -> "Expr": return BinaryExpr(Operation.IN, left, right) else: raise ValueError(f"Unsupported function: {func_name}") + + +class NativeExpressionEvaluator: + """Visitor-based expression evaluator that uses Block and BlockColumns + + This evaluator implements the visitor pattern to traverse expression trees + and evaluate them against Block data structures. It maintains operation + mappings in shared state and returns consistent BlockColumn types. + """ + + def __init__(self, block: Block): + """Initialize the evaluator with a block and operation mappings. + + Args: + block: The Block to evaluate expressions against. 
+ """ + self.block = block + self.block_accessor = BlockAccessor.for_block(block) + + # Use BlockAccessor to determine operation mappings + block_type = self.block_accessor.block_type() + if block_type.value == "pandas": + self.ops = _PANDAS_EXPR_OPS_MAP + elif block_type.value == "arrow": + self.ops = _ARROW_EXPR_OPS_MAP + else: + raise TypeError(f"Unsupported block type: {block_type}") + + def visit(self, expr: Expr) -> Union[BlockColumn, Any]: + """Visit an expression node and return the evaluated result. + + Args: + expr: The expression to evaluate. + + Returns: + The evaluated result as a BlockColumn or scalar value. + """ + if isinstance(expr, ColumnExpr): + return self.visit_column(expr) + elif isinstance(expr, LiteralExpr): + return self.visit_literal(expr) + elif isinstance(expr, BinaryExpr): + return self.visit_binary(expr) + elif isinstance(expr, UnaryExpr): + return self.visit_unary(expr) + elif isinstance(expr, UDFExpr): + return self.visit_udf(expr) + elif isinstance(expr, AliasExpr): + return self.visit_alias(expr) + elif isinstance(expr, StarColumnsExpr): + # star() should not be evaluated directly - it's handled at Project level + raise TypeError( + "StarColumnsExpr cannot be evaluated as a regular expression. " + "It should only be used in Project operations." + ) + else: + raise TypeError(f"Unsupported expression node: {type(expr).__name__}") + + def visit_column(self, expr: ColumnExpr) -> BlockColumn: + """Visit a column expression and return the column data. + + Args: + expr: The column expression. + + Returns: + The column data as a BlockColumn. + """ + return self.block[expr.name] + + def visit_literal(self, expr: LiteralExpr) -> Any: + """Visit a literal expression and return the literal value. + + Args: + expr: The literal expression. + + Returns: + The literal value. + """ + return expr.value + + def visit_binary(self, expr: BinaryExpr) -> BlockColumn: + """Visit a binary expression and return the result of the operation. + + Args: + expr: The binary expression. + + Returns: + The result of the binary operation as a BlockColumn. + """ + left_result = self.visit(expr.left) + right_result = self.visit(expr.right) + + return self.ops[expr.op](left_result, right_result) + + def visit_unary(self, expr: UnaryExpr) -> BlockColumn: + """Visit a unary expression and return the result of the operation. + + Args: + expr: The unary expression. + + Returns: + The result of the unary operation as a BlockColumn. + """ + operand_result = self.visit(expr.operand) + return self.ops[expr.op](operand_result) + + def visit_udf(self, expr: UDFExpr) -> BlockColumn: + """Visit a UDF expression and return the result of the function call. + + Args: + expr: The UDF expression. + + Returns: + The result of the UDF call as a BlockColumn. + """ + args = [self.visit(arg) for arg in expr.args] + kwargs = {k: self.visit(v) for k, v in expr.kwargs.items()} + result = expr.fn(*args, **kwargs) + + # Validate return type + if not isinstance(result, (pd.Series, np.ndarray, pa.Array, pa.ChunkedArray)): + function_name = expr.fn.__name__ + raise TypeError( + f"UDF '{function_name}' returned invalid type {type(result).__name__}. " + f"Expected type (pandas.Series, numpy.ndarray, pyarrow.Array, or pyarrow.ChunkedArray)" + ) + + return result + + def visit_alias(self, expr: AliasExpr) -> Union[BlockColumn, Any]: + """Visit an alias expression and return the renamed result. + + Args: + expr: The alias expression. + + Returns: + A Block with the data from the inner expression. 
+ """ + # Evaluate the inner expression + return self.visit(expr.expr) + + +def eval_expr(expr: Expr, block: Block) -> BlockColumn: + """Evaluate an expression against a block using the visitor pattern. + + Args: + expr: The expression to evaluate. + block: The Block to evaluate against. + + Returns: + The evaluated result as a BlockColumn. + """ + evaluator = NativeExpressionEvaluator(block) + return evaluator.visit(expr) + + +def eval_projection(exprs: List[Expr], block: Block) -> Block: + """ + Evaluate a projection (list of expressions) against a block. + + Handles projection semantics including: + - Empty projections + - Star() expressions for preserving existing columns + - Rename detection + - Column ordering + + Args: + exprs: List of expressions to evaluate (may include StarColumnsExpr) + block: The block to project + + Returns: + A new block with the projected schema + """ + block_accessor = BlockAccessor.for_block(block) + + # Skip projection only for schema-less empty blocks + if block_accessor.num_rows() == 0 and len(block_accessor.column_names()) == 0: + return block + + has_star = any(isinstance(expr, StarColumnsExpr) for expr in exprs) + existing_cols = list(block_accessor.column_names()) + + # Empty projection + if len(exprs) == 0: + # No expressions at all - return empty projection + if block_accessor.num_rows() > 0: + from ray.data._internal.arrow_block import ( + _BATCH_SIZE_PRESERVING_STUB_COL_NAME as _STUB, + ) + + return BlockAccessor.for_block(block).fill_column(_STUB, None) + return block_accessor.select([]) + + # Identity projection: single star() with no other expressions + if len(exprs) == 1 and isinstance(exprs[0], StarColumnsExpr): + return block + + # Phase 1: Compute non-star() expressions + new_output_cols: List[str] = [] + seen_output_names = set() + rename_map: Dict[str, str] = {} + computed_outputs: Dict[str, Any] = {} + + for expr in exprs: + if isinstance(expr, StarColumnsExpr): + continue + + output_name = expr.name + + # Detect simple renames + if ( + isinstance(expr, AliasExpr) + and expr.expr.name != output_name + and expr.expr.name in existing_cols + ): + rename_map[expr.expr.name] = output_name + + computed_outputs[output_name] = eval_expr(expr, block) + if output_name in seen_output_names: + raise ValueError(f"Column name '{output_name}' is a duplicate.") + new_output_cols.append(output_name) + seen_output_names.add(output_name) + + # Phase 2: Upsert computed columns + cur_block = block + for output_name in new_output_cols: + cur_block = BlockAccessor.for_block(cur_block).fill_column( + output_name, computed_outputs[output_name] + ) + + # Phase 3: Finalize output schema based on star() position + if has_star: + final_cols: List[str] = [] + final_seen = set() + + # Preserve existing columns with renames + for col in existing_cols: + renamed_col = rename_map.get(col, col) + if renamed_col not in final_seen: + final_cols.append(renamed_col) + final_seen.add(renamed_col) + + # Append new columns + for col in new_output_cols: + if col not in final_seen: + final_cols.append(col) + final_seen.add(col) + else: + # No star(): only requested outputs + final_cols = new_output_cols + + return BlockAccessor.for_block(cur_block).select(final_cols) diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index 0578881620ca..be6b289564f8 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -24,7 +24,6 @@ import ray from 
ray._common.utils import get_or_create_event_loop from ray._private.ray_constants import env_integer -from ray.data._expression_evaluator import eval_expr from ray.data._internal.compute import get_compute from ray.data._internal.execution.interfaces import PhysicalOperator from ray.data._internal.execution.interfaces.task_context import TaskContext @@ -113,54 +112,25 @@ def plan_project_op( assert len(physical_children) == 1 input_physical_dag = physical_children[0] - columns = op.cols - columns_rename = op.cols_rename - exprs = op.exprs - def _project_block(block: Block) -> Block: try: - block_accessor = BlockAccessor.for_block(block) - if not block_accessor.num_rows(): - return block - - # 1. evaluate / add expressions - if exprs: - block_accessor = BlockAccessor.for_block(block) - # Add/update with expression results - result_block = block - for name, expr in exprs.items(): - # Use expr.name if available, otherwise fall back to the dict key name - actual_name = expr.name if expr.name is not None else name - result = eval_expr(expr, result_block) - result_block_accessor = BlockAccessor.for_block(result_block) - # fill_column handles both scalars and arrays - result_block = result_block_accessor.fill_column( - actual_name, result - ) - block = result_block - - # 2. (optional) column projection - if columns: - block = BlockAccessor.for_block(block).select(columns) + from ray.data._internal.planner.plan_expression.expression_evaluator import ( + eval_projection, + ) - # 3. (optional) rename - if columns_rename: - block = block.rename_columns( - [columns_rename.get(col, col) for col in block.schema.names] - ) - - return block + return eval_projection(op.exprs, block) except Exception as e: _try_wrap_udf_exception(e) compute = get_compute(op._compute) - map_transformer = MapTransformer( [ - BlockMapTransformFn(_generate_transform_fn_for_map_block(_project_block)), + BlockMapTransformFn( + _generate_transform_fn_for_map_block(_project_block), + disable_block_shaping=(len(op.exprs) == 0), + ) ] ) - return MapOperator.create( map_transformer, input_physical_dag, @@ -504,15 +474,19 @@ def transform_fn( ) -> Iterable[DataBatch]: for batch in batches: try: - if ( - not isinstance(batch, collections.abc.Mapping) - and BlockAccessor.for_block(batch).num_rows() == 0 - ): - # For empty input blocks, we directly output them without - # calling the UDF. - # TODO(hchen): This workaround is because some all-to-all - # operators output empty blocks with no schema. - res = [batch] + if not isinstance(batch, collections.abc.Mapping): + acc = BlockAccessor.for_block(batch) + # Only skip UDF for truly schema-less empty blocks + if acc.num_rows() == 0 and len(acc.column_names()) == 0: + # For empty input blocks, we directly output them without + # calling the UDF. + # TODO(hchen): This workaround is because some all-to-all + # operators output empty blocks with no schema. 
+ res = [batch] + else: + res = fn(batch) + if not isinstance(res, GeneratorType): + res = [res] else: res = fn(batch) if not isinstance(res, GeneratorType): diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 55cbdf7d5195..45ca1867f637 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -136,7 +136,7 @@ from ray.data._internal.execution.interfaces import Executor, NodeIdStr from ray.data.grouped_data import GroupedData -from ray.data.expressions import Expr +from ray.data.expressions import Expr, star as rd_star logger = logging.getLogger(__name__) @@ -834,9 +834,7 @@ def with_column( else: project_op = Project( self._logical_plan.dag, - cols=None, - cols_rename=None, - exprs={column_name: expr}, + exprs=[rd_star(), expr.alias(column_name)], ray_remote_args=ray_remote_args, ) logical_plan = LogicalPlan(project_op, self.context) @@ -1063,24 +1061,25 @@ def select_columns( Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See :func:`ray.remote` for details. """ # noqa: E501 + from ray.data.expressions import col + if isinstance(cols, str): - cols = [cols] + exprs = [col(cols)] elif isinstance(cols, list): if not all(isinstance(col, str) for col in cols): raise ValueError( "select_columns requires all elements of 'cols' to be strings." ) + if len(cols) != len(set(cols)): + raise ValueError( + "select_columns expected unique column names, " + f"got duplicate column names: {cols}" + ) + exprs = [col(c) for c in cols] else: raise TypeError( "select_columns requires 'cols' to be a string or a list of strings." ) - - if len(cols) != len(set(cols)): - raise ValueError( - "select_columns expected unique column names, " - f"got duplicate column names: {cols}" - ) - # Don't feel like we really need this from ray.data._internal.compute import TaskPoolStrategy @@ -1089,8 +1088,7 @@ def select_columns( plan = self._plan.copy() select_op = Project( self._logical_plan.dag, - cols=cols, - cols_rename=None, + exprs=exprs, compute=compute, ray_remote_args=ray_remote_args, ) @@ -1152,7 +1150,9 @@ def rename_columns( Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See :func:`ray.remote` for details. """ # noqa: E501 + from ray.data.expressions import col + exprs = [] if isinstance(names, dict): if not names: raise ValueError("rename_columns received 'names' with no entries.") @@ -1170,7 +1170,7 @@ def rename_columns( "to be strings." ) - cols_rename = names + exprs = [col(old).alias(new) for old, new in names.items()] elif isinstance(names, list): if not names: raise ValueError( @@ -1194,7 +1194,10 @@ def rename_columns( f"schema names: {current_names}." 
) - cols_rename = dict(zip(current_names, names)) + exprs = [ + col(old).alias(new) + for old, new in dict(zip(current_names, names)).items() + ] else: raise TypeError( f"rename_columns expected names to be either List[str] or " @@ -1215,8 +1218,7 @@ def rename_columns( plan = self._plan.copy() select_op = Project( self._logical_plan.dag, - cols=None, - cols_rename=cols_rename, + exprs=[rd_star()] + exprs, compute=compute, ray_remote_args=ray_remote_args, ) @@ -3584,7 +3586,7 @@ def count(self) -> int: # NOTE: Project the dataset to avoid the need to carry actual # data when we're only interested in the total count - count_op = Count(Project(self._logical_plan.dag, cols=[])) + count_op = Count(Project(self._logical_plan.dag, exprs=[])) logical_plan = LogicalPlan(count_op, self.context) count_ds = Dataset(plan, logical_plan) @@ -4006,8 +4008,8 @@ def write_json( concurrency=concurrency, ) - @ConsumptionAPI @PublicAPI(stability="alpha", api_group=IOC_API_GROUP) + @ConsumptionAPI def write_iceberg( self, table_identifier: str, diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index c74005f55dcf..4c90aee16415 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -147,7 +147,9 @@ def visit_binary(self, expr: "BinaryExpr") -> "pyarrow.compute.Expression": left = self.visit(expr.left) right = self.visit(expr.right) - from ray.data._expression_evaluator import _ARROW_EXPR_OPS_MAP + from ray.data._internal.planner.plan_expression.expression_evaluator import ( + _ARROW_EXPR_OPS_MAP, + ) if expr.op in _ARROW_EXPR_OPS_MAP: return _ARROW_EXPR_OPS_MAP[expr.op](left, right) @@ -155,7 +157,9 @@ def visit_binary(self, expr: "BinaryExpr") -> "pyarrow.compute.Expression": def visit_unary(self, expr: "UnaryExpr") -> "pyarrow.compute.Expression": operand = self.visit(expr.operand) - from ray.data._expression_evaluator import _ARROW_EXPR_OPS_MAP + from ray.data._internal.planner.plan_expression.expression_evaluator import ( + _ARROW_EXPR_OPS_MAP, + ) if expr.op in _ARROW_EXPR_OPS_MAP: return _ARROW_EXPR_OPS_MAP[expr.op](operand) @@ -493,7 +497,7 @@ class UnaryExpr(Expr): op: Operation operand: Expr - data_type: DataType = field(init=False) + data_type: DataType = field(default_factory=lambda: DataType(object), init=False) def structurally_equals(self, other: Any) -> bool: return ( @@ -681,6 +685,29 @@ def structurally_equals(self, other: Any) -> bool: ) +@DeveloperAPI(stability="alpha") +@dataclass(frozen=True, eq=False) +class StarColumnsExpr(Expr): + """Expression that represents all columns from the input. + + This is a special expression used in projections to indicate that + all existing columns should be preserved at this position in the output. + It's typically used internally by operations like with_column() and + rename_columns() to maintain existing columns. + + Example: + When with_column("new_col", expr) is called, it creates: + Project(exprs=[star(), expr.alias("new_col")]) + + This means: keep all existing columns, then add/overwrite "new_col" + """ + + data_type: DataType = field(default_factory=lambda: DataType(object), init=False) + + def structurally_equals(self, other: Any) -> bool: + return isinstance(other, StarColumnsExpr) + + @PublicAPI(stability="beta") def col(name: str) -> ColumnExpr: """ @@ -743,6 +770,21 @@ def lit(value: Any) -> LiteralExpr: return LiteralExpr(value) +@PublicAPI(stability="beta") +def star() -> StarColumnsExpr: + """ + References all input columns from the input. 
+ + This is a special expression used in projections to preserve all + existing columns. It's typically used with operations that want to + add or modify columns while keeping the rest. + + Returns: + A StarColumnsExpr that represents all input columns. + """ + return StarColumnsExpr() + + @DeveloperAPI(stability="alpha") def download(uri_column_name: str) -> DownloadExpr: """ @@ -788,8 +830,10 @@ def download(uri_column_name: str) -> DownloadExpr: "UDFExpr", "DownloadExpr", "AliasExpr", + "StarColumnsExpr", "udf", "col", "lit", "download", + "star", ] diff --git a/python/ray/data/tests/test_dataset_aggregrations.py b/python/ray/data/tests/test_dataset_aggregrations.py index 34bbb2431394..fd416d5931bd 100644 --- a/python/ray/data/tests/test_dataset_aggregrations.py +++ b/python/ray/data/tests/test_dataset_aggregrations.py @@ -33,7 +33,7 @@ def test_count_edge_case(ray_start_regular): ds = ray.data.range(10) ds.count() - actual_count = ds.filter(lambda row: row["id"] % 2 == 0).count() + actual_count = ds.filter(fn=lambda row: row["id"] % 2 == 0).count() assert actual_count == 5 diff --git a/python/ray/data/tests/test_execution_optimizer_basic.py b/python/ray/data/tests/test_execution_optimizer_basic.py index 959d5d0c4af6..cc52bd9771eb 100644 --- a/python/ray/data/tests/test_execution_optimizer_basic.py +++ b/python/ray/data/tests/test_execution_optimizer_basic.py @@ -31,6 +31,7 @@ from ray.data.context import DataContext from ray.data.datasource import Datasource from ray.data.datasource.datasource import ReadTask +from ray.data.expressions import col from ray.data.tests.conftest import * # noqa from ray.data.tests.test_util import _check_usage_record, get_parquet_read_logical_op from ray.data.tests.util import column_udf, extract_values, named_values @@ -283,7 +284,7 @@ def test_project_operator_select(ray_start_regular_shared_2_cpus): logical_plan = ds._plan._logical_plan op = logical_plan.dag assert isinstance(op, Project), op.name - assert op.cols == cols + assert op.exprs == [col("sepal.length"), col("petal.width")] physical_plan = create_planner().plan(logical_plan) physical_plan = PhysicalOptimizer().optimize(physical_plan) @@ -297,6 +298,8 @@ def test_project_operator_rename(ray_start_regular_shared_2_cpus): Checks that the physical plan is properly generated for the Project operator from rename columns. 
""" + from ray.data.expressions import star as rd_star + path = "example://iris.parquet" ds = ray.data.read_parquet(path) ds = ds.map_batches(lambda d: d) @@ -306,8 +309,11 @@ def test_project_operator_rename(ray_start_regular_shared_2_cpus): logical_plan = ds._plan._logical_plan op = logical_plan.dag assert isinstance(op, Project), op.name - assert not op.cols - assert op.cols_rename == cols_rename + assert op.exprs == [ + rd_star(), + col("sepal.length").alias("sepal_length"), + col("petal.width").alias("pedal_width"), + ] physical_plan = create_planner().plan(logical_plan) physical_plan = PhysicalOptimizer().optimize(physical_plan) diff --git a/python/ray/data/tests/test_expressions.py b/python/ray/data/tests/test_expressions.py index 29b429056682..82a35382e22c 100644 --- a/python/ray/data/tests/test_expressions.py +++ b/python/ray/data/tests/test_expressions.py @@ -66,7 +66,9 @@ def test_alias_functionality(expr, alias_name, expected_alias): """Test alias functionality with various expression types.""" import pandas as pd - from ray.data._expression_evaluator import eval_expr + from ray.data._internal.planner.plan_expression.expression_evaluator import ( + eval_expr, + ) # Test alias creation aliased_expr = expr.alias(alias_name) diff --git a/python/ray/data/tests/test_operator_fusion.py b/python/ray/data/tests/test_operator_fusion.py index 28db544c9f04..8f43c669a6e0 100644 --- a/python/ray/data/tests/test_operator_fusion.py +++ b/python/ray/data/tests/test_operator_fusion.py @@ -26,6 +26,7 @@ from ray.data._internal.stats import DatasetStats from ray.data.context import DataContext from ray.data.dataset import Dataset +from ray.data.expressions import star as rd_star from ray.data.tests.conftest import * # noqa from ray.data.tests.test_util import _check_usage_record, get_parquet_read_logical_op from ray.data.tests.util import column_udf, extract_values @@ -66,7 +67,7 @@ def test_read_map_chain_operator_fusion(ray_start_regular_shared_2_cpus): map1 = MapRows(read_op, lambda x: x) map2 = MapBatches(map1, lambda x: x) map3 = FlatMap(map2, lambda x: x) - map4 = Filter(map3, lambda x: x) + map4 = Filter(map3, fn=lambda x: x) logical_plan = LogicalPlan(map4, ctx) physical_plan = planner.plan(logical_plan) physical_plan = PhysicalOptimizer().optimize(physical_plan) @@ -296,7 +297,7 @@ def test_read_with_map_batches_fused_successfully( ), ( # No fusion (could drastically reduce dataset) - Filter(InputData([]), lambda x: False), + Filter(InputData([]), fn=lambda x: False), False, ), ( @@ -316,7 +317,7 @@ def test_read_with_map_batches_fused_successfully( ), ( # Fusion - Project(InputData([])), + Project(InputData([]), exprs=[rd_star()]), True, ), ], @@ -429,7 +430,8 @@ def fn(batch): ds = ds.map_batches(fn, batch_size=None) assert set(extract_values("id", ds.take_all())) == set(range(1, n + 1)) assert "ReadRange->MapBatches(fn)->RandomizeBlockOrder" not in ds.stats() - assert "ReadRange->MapBatches(fn)" in ds.stats() + # ReadRange cannot fuse with MapBatches due to RandomizeBlockOrder in between + assert "ReadRange->MapBatches(fn)" not in ds.stats() _check_usage_record(["ReadRange", "MapBatches", "RandomizeBlockOrder"]) @@ -575,7 +577,7 @@ def test_read_map_chain_operator_fusion_e2e( ray_start_regular_shared_2_cpus, ): ds = ray.data.range(10, override_num_blocks=2) - ds = ds.filter(lambda x: x["id"] % 2 == 0) + ds = ds.filter(fn=lambda x: x["id"] % 2 == 0) ds = ds.map(column_udf("id", lambda x: x + 1)) ds = ds.map_batches( lambda batch: {"id": [2 * x for x in batch["id"]]}, batch_size=None 
diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py index 3e3047691cd2..a04e9a78e6eb 100644 --- a/python/ray/data/tests/test_parquet.py +++ b/python/ray/data/tests/test_parquet.py @@ -635,7 +635,6 @@ def test_projection_pushdown_non_partitioned(ray_start_regular_shared, temp_dir) summary = ds.materialize()._plan.stats().to_summary() - assert "ReadParquet" in summary.base_name assert summary.extra_metrics["bytes_task_outputs_generated"] == 0 diff --git a/python/ray/data/tests/test_projection_fusion.py b/python/ray/data/tests/test_projection_fusion.py index 9855e93e31ab..77be05ffcc97 100644 --- a/python/ray/data/tests/test_projection_fusion.py +++ b/python/ray/data/tests/test_projection_fusion.py @@ -14,7 +14,7 @@ ProjectionPushdown, ) from ray.data.context import DataContext -from ray.data.expressions import DataType, col, udf +from ray.data.expressions import DataType, col, star, udf @dataclass @@ -38,7 +38,7 @@ class DependencyTestCase: description: str -class TestPorjectionFusion: +class TestProjectionFusion: """Test topological sorting in projection pushdown fusion.""" @pytest.fixture(autouse=True) @@ -114,12 +114,14 @@ def _create_project_chain(self, input_op, expressions_list: List[Dict[str, str]] current_op = input_op for expr_dict in expressions_list: - exprs = { - name: self._parse_expression(desc) for name, desc in expr_dict.items() - } - current_op = Project( - current_op, cols=None, cols_rename=None, exprs=exprs, ray_remote_args={} - ) + # Convert dictionary to list of named expressions + exprs = [] + for name, desc in expr_dict.items(): + expr = self._parse_expression(desc) + named_expr = expr.alias(name) + exprs.append(named_expr) + + current_op = Project(current_op, exprs=[star()] + exprs, ray_remote_args={}) return current_op @@ -130,7 +132,8 @@ def _extract_levels_from_plan(self, plan: LogicalPlan) -> List[Set[str]]: while isinstance(current, Project): if current.exprs: - levels.append(set(current.exprs.keys())) + # Extract names from list of expressions instead of dictionary keys + levels.append({expr.name for expr in current.exprs}) current = current.input_dependency return list(reversed(levels)) # Return bottom-up order @@ -604,7 +607,7 @@ def test_chained_udf_dependencies(self): assert self._count_project_operators(optimized_plan) == 1 assert ( self._describe_plan_structure(optimized_plan) - == "Project(3 exprs) -> FromItems" # Changed from multiple operators + == "Project(4 exprs) -> FromItems" # Changed from multiple operators ) # Verify execution correctness @@ -666,13 +669,375 @@ def test_performance_impact_of_udf_chains(self): ) # Changed from 3 to 1 assert ( self._describe_plan_structure(optimized_independent) - == "Project(3 exprs) -> FromItems" + == "Project(4 exprs) -> FromItems" ) assert ( self._describe_plan_structure(optimized_chained) - == "Project(3 exprs) -> FromItems" # Changed from multiple operators + == "Project(4 exprs) -> FromItems" # Changed from multiple operators + ) + + @pytest.mark.parametrize( + "operations,expected", + [ + # Single operations + ([("rename", {"a": "A"})], {"A": 1, "b": 2, "c": 3}), + ([("select", ["a", "b"])], {"a": 1, "b": 2}), + ([("with_column", "d", 4)], {"a": 1, "b": 2, "c": 3, "d": 4}), + # Two operations - rename then select + ([("rename", {"a": "A"}), ("select", ["A"])], {"A": 1}), + ([("rename", {"a": "A"}), ("select", ["b"])], {"b": 2}), + ( + [("rename", {"a": "A", "b": "B"}), ("select", ["A", "B"])], + {"A": 1, "B": 2}, + ), + # Two operations - select then rename + 
([("select", ["a", "b"]), ("rename", {"a": "A"})], {"A": 1, "b": 2}), + ([("select", ["a"]), ("rename", {"a": "x"})], {"x": 1}), + # Two operations - with_column combinations + ([("with_column", "d", 4), ("select", ["a", "d"])], {"a": 1, "d": 4}), + ([("select", ["a"]), ("with_column", "d", 4)], {"a": 1, "d": 4}), + ( + [("rename", {"a": "A"}), ("with_column", "d", 4)], + {"A": 1, "b": 2, "c": 3, "d": 4}, + ), + ( + [("with_column", "d", 4), ("rename", {"d": "D"})], + {"a": 1, "b": 2, "c": 3, "D": 4}, + ), + # Three operations + ( + [ + ("rename", {"a": "A"}), + ("select", ["A", "b"]), + ("with_column", "d", 4), + ], + {"A": 1, "b": 2, "d": 4}, + ), + ( + [ + ("with_column", "d", 4), + ("rename", {"a": "A"}), + ("select", ["A", "d"]), + ], + {"A": 1, "d": 4}, + ), + ( + [ + ("select", ["a", "b"]), + ("rename", {"a": "x"}), + ("with_column", "d", 4), + ], + {"x": 1, "b": 2, "d": 4}, + ), + # Column swap + ([("rename", {"a": "b", "b": "a"}), ("select", ["a"])], {"a": 2}), + ([("rename", {"a": "b", "b": "a"}), ("select", ["b"])], {"b": 1}), + # Multiple same operations + ( + [("rename", {"a": "x"}), ("rename", {"x": "y"})], + {"y": 1, "b": 2, "c": 3}, + ), + ([("select", ["a", "b"]), ("select", ["a"])], {"a": 1}), + ( + [("with_column", "d", 4), ("with_column", "e", 5)], + {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}, + ), + # Complex expressions with with_column + ( + [("rename", {"a": "x"}), ("with_column_expr", "sum", "x", 10)], + {"x": 1, "b": 2, "c": 3, "sum": 10}, + ), + ( + [ + ("with_column", "d", 4), + ("with_column", "e", 5), + ("select", ["d", "e"]), + ], + {"d": 4, "e": 5}, + ), + ], + ) + def test_projection_operations_comprehensive(self, operations, expected): + """Comprehensive test for projection operations combinations.""" + from ray.data.expressions import col, lit + + # Create initial dataset + ds = ray.data.range(1).map(lambda row: {"a": 1, "b": 2, "c": 3}) + + # Apply operations + for op in operations: + if op[0] == "rename": + ds = ds.rename_columns(op[1]) + elif op[0] == "select": + ds = ds.select_columns(op[1]) + elif op[0] == "with_column": + ds = ds.with_column(op[1], lit(op[2])) + elif op[0] == "with_column_expr": + # Special case for expressions referencing columns + ds = ds.with_column(op[1], col(op[2]) * op[3]) + + # Verify result + result = ds.take_all() + assert len(result) == 1 + assert result[0] == expected + + @pytest.mark.parametrize( + "operations,expected", + [ + # Basic count operations + ([("count",)], 3), # All 3 rows + ([("rename", {"a": "A"}), ("count",)], 3), + ([("select", ["a", "b"]), ("count",)], 3), + ([("with_column", "d", 4), ("count",)], 3), + # Filter operations affecting count + ([("filter", col("a") > 1), ("count",)], 2), # 2 rows have a > 1 + ([("filter", col("b") == 2), ("count",)], 3), # All rows have b == 2 + ([("filter", col("c") < 10), ("count",)], 3), # All rows have c < 10 + ([("filter", col("a") == 1), ("count",)], 1), # 1 row has a == 1 + # Projection then filter then count + ([("rename", {"a": "A"}), ("filter", col("A") > 1), ("count",)], 2), + ([("select", ["a", "b"]), ("filter", col("a") > 1), ("count",)], 2), + ([("with_column", "d", 4), ("filter", col("d") == 4), ("count",)], 3), + # Filter then projection then count + ([("filter", col("a") > 1), ("rename", {"a": "A"}), ("count",)], 2), + ([("filter", col("b") == 2), ("select", ["a", "b"]), ("count",)], 3), + ([("filter", col("c") < 10), ("with_column", "d", 4), ("count",)], 3), + # Multiple projections with filter and count + ( + [ + ("rename", {"a": "A"}), + ("select", ["A", 
"b"]), + ("filter", col("A") > 1), + ("count",), + ], + 2, + ), + ( + [ + ("with_column", "d", 4), + ("rename", {"d": "D"}), + ("filter", col("D") == 4), + ("count",), + ], + 3, + ), + ( + [ + ("select", ["a", "b"]), + ("filter", col("a") > 1), + ("rename", {"a": "x"}), + ("count",), + ], + 2, + ), + # Complex combinations + ( + [ + ("filter", col("a") > 0), + ("rename", {"b": "B"}), + ("select", ["a", "B"]), + ("filter", col("B") == 2), + ("count",), + ], + 3, + ), + ( + [ + ("with_column", "sum", 99), + ("filter", col("a") > 1), + ("select", ["a", "sum"]), + ("count",), + ], + 2, + ), + ( + [ + ("rename", {"a": "A", "b": "B"}), + ("filter", (col("A") + col("B")) > 3), + ("select", ["A"]), + ("count",), + ], + 2, + ), + ], + ) + def test_projection_fusion_with_count_and_filter(self, operations, expected): + """Test projection fusion with count operations including filters.""" + from ray.data.expressions import lit + + # Create dataset with 3 rows: {"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 2, "c": 3}, {"a": 3, "b": 2, "c": 3} + ds = ray.data.from_items( + [ + {"a": 1, "b": 2, "c": 3}, + {"a": 2, "b": 2, "c": 3}, + {"a": 3, "b": 2, "c": 3}, + ] + ) + + # Apply operations + for op in operations: + if op[0] == "rename": + ds = ds.rename_columns(op[1]) + elif op[0] == "select": + ds = ds.select_columns(op[1]) + elif op[0] == "with_column": + ds = ds.with_column(op[1], lit(op[2])) + elif op[0] == "filter": + # Use the predicate expression directly + ds = ds.filter(expr=op[1]) + elif op[0] == "count": + # Count returns a scalar, not a dataset + result = ds.count() + assert result == expected + return # Early return since count() terminates the pipeline + + # This should not be reached for count operations + assert False, "Count operation should have returned early" + + @pytest.mark.parametrize( + "invalid_operations,error_type,error_message_contains", + [ + # Try to filter on a column that doesn't exist yet + ( + [("filter", col("d") > 0), ("with_column", "d", 4)], + (KeyError, ray.exceptions.RayTaskError), + "d", + ), + # Try to filter on a renamed column before the rename + ( + [("filter", col("A") > 1), ("rename", {"a": "A"})], + (KeyError, ray.exceptions.RayTaskError), + "A", + ), + # Try to use a column that was removed by select + ( + [("select", ["a"]), ("filter", col("b") == 2)], + (KeyError, ray.exceptions.RayTaskError), + "b", + ), + # Try to filter on a column after it was removed by select + ( + [("select", ["a", "b"]), ("filter", col("c") < 10)], + (KeyError, ray.exceptions.RayTaskError), + "c", + ), + # Try to use with_column referencing a non-existent column + ( + [("select", ["a"]), ("with_column", "new_col", col("b") + 1)], + (KeyError, ray.exceptions.RayTaskError), + "b", + ), + # Try to filter on a column that was renamed away + ( + [("rename", {"b": "B"}), ("filter", col("b") == 2)], + (KeyError, ray.exceptions.RayTaskError), + "b", + ), + # Try to use with_column with old column name after rename + ( + [("rename", {"a": "A"}), ("with_column", "result", col("a") + 1)], + (KeyError, ray.exceptions.RayTaskError), + "a", + ), + # Try to select using old column name after rename + ( + [("rename", {"b": "B"}), ("select", ["a", "b", "c"])], + (KeyError, ray.exceptions.RayTaskError), + "b", + ), + # Try to filter on a computed column that was removed by select + ( + [ + ("with_column", "d", 4), + ("select", ["a", "b"]), + ("filter", col("d") == 4), + ], + (KeyError, ray.exceptions.RayTaskError), + "d", + ), + # Try to rename a column that was removed by select + ( + [("select", ["a", 
"b"]), ("rename", {"c": "C"})], + (KeyError, ray.exceptions.RayTaskError), + "c", + ), + # Complex: rename, select (removing renamed source), then use old name + ( + [ + ("rename", {"a": "A"}), + ("select", ["b", "c"]), + ("filter", col("a") > 0), + ], + (KeyError, ray.exceptions.RayTaskError), + "a", + ), + # Complex: with_column, select (keeping new column), filter on removed original + ( + [ + ("with_column", "sum", col("a") + col("b")), + ("select", ["sum"]), + ("filter", col("a") > 0), + ], + (KeyError, ray.exceptions.RayTaskError), + "a", + ), + # Try to use column in with_column expression after it was removed + ( + [ + ("select", ["a", "c"]), + ("with_column", "result", col("a") + col("b")), + ], + (KeyError, ray.exceptions.RayTaskError), + "b", + ), + ], + ) + def test_projection_operations_invalid_order( + self, invalid_operations, error_type, error_message_contains + ): + """Test that operations fail gracefully when referencing non-existent columns.""" + import ray + from ray.data.expressions import lit + + # Create dataset with 3 rows: {"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 2, "c": 3}, {"a": 3, "b": 2, "c": 3} + ds = ray.data.from_items( + [ + {"a": 1, "b": 2, "c": 3}, + {"a": 2, "b": 2, "c": 3}, + {"a": 3, "b": 2, "c": 3}, + ] ) + # Apply operations and expect them to fail + with pytest.raises(error_type) as exc_info: + for op in invalid_operations: + if op[0] == "rename": + ds = ds.rename_columns(op[1]) + elif op[0] == "select": + ds = ds.select_columns(op[1]) + elif op[0] == "with_column": + if len(op) == 3 and not isinstance(op[2], (int, float, str)): + # Expression-based with_column (op[2] is an expression) + ds = ds.with_column(op[1], op[2]) + else: + # Literal-based with_column + ds = ds.with_column(op[1], lit(op[2])) + elif op[0] == "filter": + ds = ds.filter(expr=op[1]) + elif op[0] == "count": + ds.count() + return + + # Force execution to trigger the error + result = ds.take_all() + print(f"Unexpected success: {result}") + + # Verify the error message contains the expected column name + error_str = str(exc_info.value).lower() + assert ( + error_message_contains.lower() in error_str + ), f"Expected '{error_message_contains}' in error message: {error_str}" + if __name__ == "__main__": pytest.main([__file__, "-v"])