diff --git a/python/ray/data/_internal/arrow_block.py b/python/ray/data/_internal/arrow_block.py index 9370398afe07..ec7ff0510f23 100644 --- a/python/ray/data/_internal/arrow_block.py +++ b/python/ray/data/_internal/arrow_block.py @@ -58,6 +58,7 @@ _MIN_PYARROW_VERSION_TO_NUMPY_ZERO_COPY_ONLY = parse_version("13.0.0") +_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub" # Set the max chunk size in bytes for Arrow to Batches conversion in @@ -221,7 +222,7 @@ def fill_column(self, name: str, value: Any) -> Block: array = pyarrow.nulls(len(self._table), type=type) array = pc.fill_null(array, value) - return self._table.append_column(name, array) + return self.upsert_column(name, array) @classmethod def from_bytes(cls, data: bytes) -> "ArrowBlockAccessor": @@ -366,6 +367,10 @@ def select(self, columns: List[str]) -> "pyarrow.Table": "Columns must be a list of column name strings when aggregating on " f"Arrow blocks, but got: {columns}." ) + if len(columns) == 0: + # Applicable for count which does an empty projection. + # Pyarrow returns a table with 0 columns and num_rows rows. + return self.fill_column(_BATCH_SIZE_PRESERVING_STUB_COL_NAME, None) return self._table.select(columns) def rename_columns(self, columns_rename: Dict[str, str]) -> "pyarrow.Table": diff --git a/python/ray/data/_internal/datasource/parquet_datasource.py b/python/ray/data/_internal/datasource/parquet_datasource.py index 65d52fc80083..b37be53c0338 100644 --- a/python/ray/data/_internal/datasource/parquet_datasource.py +++ b/python/ray/data/_internal/datasource/parquet_datasource.py @@ -23,7 +23,10 @@ import ray from ray._private.arrow_utils import get_pyarrow_version -from ray.data._internal.arrow_block import ArrowBlockAccessor +from ray.data._internal.arrow_block import ( + _BATCH_SIZE_PRESERVING_STUB_COL_NAME, + ArrowBlockAccessor, +) from ray.data._internal.progress_bar import ProgressBar from ray.data._internal.remote_fn import cached_remote_fn from ray.data._internal.util import ( @@ -104,9 +107,6 @@ PARQUET_ENCODING_RATIO_ESTIMATE_NUM_ROWS = 1024 -_BATCH_SIZE_PRESERVING_STUB_COL_NAME = "__bsp_stub" - - class _ParquetFragment: """This wrapper class is created to avoid utilizing `ParquetFileFragment` original serialization protocol that actually does network RPCs during serialization diff --git a/python/ray/data/_internal/logical/operators/map_operator.py b/python/ray/data/_internal/logical/operators/map_operator.py index 63cf35410237..401dfda8009a 100644 --- a/python/ray/data/_internal/logical/operators/map_operator.py +++ b/python/ray/data/_internal/logical/operators/map_operator.py @@ -7,7 +7,7 @@ from ray.data._internal.logical.interfaces import LogicalOperator from ray.data._internal.logical.operators.one_to_one_operator import AbstractOneToOne from ray.data.block import UserDefinedFunction -from ray.data.expressions import Expr +from ray.data.expressions import Expr, StarExpr from ray.data.preprocessor import Preprocessor logger = logging.getLogger(__name__) @@ -268,16 +268,12 @@ def can_modify_num_rows(self) -> bool: class Project(AbstractMap): - """Logical operator for select_columns.""" + """Logical operator for all Projection Operations.""" def __init__( self, input_op: LogicalOperator, - cols: Optional[List[str]] = None, - cols_rename: Optional[Dict[str, str]] = None, - exprs: Optional[ - Dict[str, "Expr"] - ] = None, # TODO Remove cols and cols_rename and replace them with corresponding exprs + exprs: list["Expr"], compute: Optional[ComputeStrategy] = None, ray_remote_args: Optional[Dict[str, Any]] = None, ): @@ -288,30 +284,23 @@ def __init__( compute=compute, ) self._batch_size = None - self._cols = cols - self._cols_rename = cols_rename self._exprs = exprs self._batch_format = "pyarrow" self._zero_copy_batch = True - if exprs is not None: - # Validate that all values are expressions - for name, expr in exprs.items(): - if not isinstance(expr, Expr): - raise TypeError( - f"Expected Expr for column '{name}', got {type(expr)}" - ) + for expr in self._exprs: + if expr.name is None and not isinstance(expr, StarExpr): + raise TypeError( + "All Project expressions must be named (use .alias(name) or col(name)), " + "or be a star() expression." + ) - @property - def cols(self) -> Optional[List[str]]: - return self._cols - - @property - def cols_rename(self) -> Optional[Dict[str, str]]: - return self._cols_rename + def has_star_expr(self) -> bool: + """Check if this projection contains a star() expression.""" + return any(isinstance(expr, StarExpr) for expr in self._exprs) @property - def exprs(self) -> Optional[Dict[str, "Expr"]]: + def exprs(self) -> List["Expr"]: return self._exprs def can_modify_num_rows(self) -> bool: diff --git a/python/ray/data/_internal/logical/rules/projection_pushdown.py b/python/ray/data/_internal/logical/rules/projection_pushdown.py index df6a4ab6905d..0cb7950a10d5 100644 --- a/python/ray/data/_internal/logical/rules/projection_pushdown.py +++ b/python/ray/data/_internal/logical/rules/projection_pushdown.py @@ -1,6 +1,4 @@ -import logging -from dataclasses import dataclass -from typing import Dict, List, Optional, Union +from typing import List, Optional, Set, Tuple from ray.data._internal.logical.interfaces import ( LogicalOperator, @@ -9,304 +7,300 @@ Rule, ) from ray.data._internal.logical.operators.map_operator import Project -from ray.data._internal.logical.operators.read_operator import Read -from ray.data.expressions import Expr +from ray.data._internal.planner.plan_expression.expression_visitors import ( + _ColumnReferenceCollector, + _ColumnRewriter, +) +from ray.data.expressions import ( + AliasExpr, + ColumnExpr, + Expr, + StarExpr, +) -logger = logging.getLogger(__name__) +def _collect_referenced_columns(exprs: List[Expr]) -> Optional[Set[str]]: + """ + Extract all column names referenced by the given expressions. -@dataclass(frozen=True) -class _ProjectSpec: - cols: Optional[List[str]] - cols_remap: Optional[Dict[str, str]] - exprs: Optional[Dict[str, Expr]] + Recursively traverses expression trees to find all ColumnExpr nodes + and collects their names. + Example: For expression "col1 + col2", returns {"col1", "col2"} + """ + # If any expression is star(), we need all columns + if any(isinstance(expr, StarExpr) for expr in exprs): + # TODO (goutam): Instead of using None to refer to All columns, resolve the AST against the schema. + # https://github.com/ray-project/ray/issues/57720 + return None -class ProjectionPushdown(Rule): - """Optimization rule that pushes down projections across the graph. + collector = _ColumnReferenceCollector() + for expr in exprs or []: + collector.visit(expr) + return collector.referenced_columns - This rule looks for `Project` operators that are immediately - preceded by a `Read` operator and sets the - projected columns on the `Read` operator. - If there are redundant Project operators, it removes the `Project` operator from - the graph. +def _extract_simple_rename(expr: Expr) -> Optional[Tuple[str, str]]: """ + Check if an expression is a simple column rename. - def apply(self, plan: LogicalPlan) -> LogicalPlan: - dag = plan.dag - new_dag = dag._apply_transform(self._pushdown_project) + Returns (source_name, dest_name) if the expression is of form: + col("source").alias("dest") + where source != dest. - return LogicalPlan(new_dag, plan.context) if dag is not new_dag else plan + Returns None for other expression types. + """ + if isinstance(expr, AliasExpr) and isinstance(expr.expr, ColumnExpr): + dest_name = expr.name + source_name = expr.expr.name + if source_name != dest_name: + return source_name, dest_name + return None - @classmethod - def _pushdown_project(cls, op: LogicalOperator) -> LogicalOperator: - if isinstance(op, Project): - # Push-down projections into read op - if cls._supports_projection_pushdown(op): - project_op: Project = op - target_op: LogicalOperatorSupportsProjectionPushdown = ( - op.input_dependency - ) - return cls._try_combine(target_op, project_op) +def _analyze_upstream_project( + upstream_project: Project, +) -> Tuple[Set[str], dict[str, Expr], Set[str]]: + """ + Analyze what the upstream project produces and identifies removed columns. - # Otherwise, fuse projections into a single op - elif isinstance(op.input_dependency, Project): - outer_op: Project = op - inner_op: Project = op.input_dependency + Example: Upstream exprs [col("x").alias("y")] → removed_by_renames = {"x"} if "x" not in output + """ + output_columns = { + expr.name for expr in upstream_project.exprs if not isinstance(expr, StarExpr) + } + column_definitions = { + expr.name: expr + for expr in upstream_project.exprs + if not isinstance(expr, StarExpr) + } - return cls._fuse(inner_op, outer_op) + # Identify columns removed by renames (source not in output) + removed_by_renames: Set[str] = set() + for expr in upstream_project.exprs: + if isinstance(expr, StarExpr): + continue + rename_pair = _extract_simple_rename(expr) + if rename_pair is not None: + source_name, _ = rename_pair + if source_name not in output_columns: + removed_by_renames.add(source_name) + + return output_columns, column_definitions, removed_by_renames + + +def _validate_fusion( + downstream_project: Project, + upstream_has_all: bool, + upstream_output_columns: Set[str], + removed_by_renames: Set[str], +) -> Tuple[bool, Set[str]]: + """ + Validate if fusion is possible without rewriting expressions. - return op + Args: + downstream_project: The downstream Project operator + upstream_has_all: True if the upstream Project has all columns, False otherwise + upstream_output_columns: Set of column names that are available in the upstream Project + removed_by_renames: Set of column names that are removed by renames in the upstream Project - @classmethod - def _supports_projection_pushdown(cls, op: Project) -> bool: - # NOTE: Currently only projecting into Parquet is supported - input_op = op.input_dependency - return ( - isinstance(input_op, LogicalOperatorSupportsProjectionPushdown) - and input_op.supports_projection_pushdown() - ) + Returns: + Tuple of (is_valid, missing_columns) + - is_valid: True if all expressions can be fused, False otherwise + - missing_columns: Set of column names that are referenced but not available - @staticmethod - def _fuse(inner_op: Project, outer_op: Project) -> Project: - # Combine expressions from both operators - combined_exprs = _combine_expressions(inner_op.exprs, outer_op.exprs) - - # Only combine projection specs if there are no expressions - # When expressions are present, they take precedence - if combined_exprs: - # When expressions are present, preserve column operations from outer operation - # The logical order is: expressions first, then column operations - outer_cols = outer_op.cols - outer_cols_rename = outer_op.cols_rename - - # If outer operation has no column operations, fall back to inner operation - if outer_cols is None and outer_cols_rename is None: - outer_cols = inner_op.cols - outer_cols_rename = inner_op.cols_rename - - return Project( - inner_op.input_dependency, - cols=outer_cols, - cols_rename=outer_cols_rename, - exprs=combined_exprs, - # Give precedence to outer operator's ray_remote_args - ray_remote_args={ - **inner_op._ray_remote_args, - **outer_op._ray_remote_args, - }, - ) - else: - # Fall back to original behavior for column-only projections - inner_op_spec = _get_projection_spec(inner_op) - outer_op_spec = _get_projection_spec(outer_op) - - new_spec = _combine_projection_specs( - prev_spec=inner_op_spec, new_spec=outer_op_spec - ) - - return Project( - inner_op.input_dependency, - cols=new_spec.cols, - cols_rename=new_spec.cols_remap, - exprs=None, - ray_remote_args={ - **inner_op._ray_remote_args, - **outer_op._ray_remote_args, - }, - ) - - @staticmethod - def _try_combine( - target_op: LogicalOperatorSupportsProjectionPushdown, - project_op: Project, - ) -> LogicalOperator: - # For now, don't push down expressions into `Read` operators - # Only handle traditional column projections - if project_op.exprs: - # Cannot push expressions into `Read`, return unchanged - return project_op - - target_op_spec = _get_projection_spec(target_op) - project_op_spec = _get_projection_spec(project_op) - - new_spec = _combine_projection_specs( - prev_spec=target_op_spec, new_spec=project_op_spec - ) + Example: Downstream refs "x" but upstream renamed "x" to "y" and dropped "x" + → (False, {"x"}) + """ + missing_columns = set() - logger.debug( - f"Pushing projection down into read operation " - f"projection columns = {new_spec.cols} (before: {target_op_spec.cols}), " - f"remap = {new_spec.cols_remap} (before: {target_op_spec.cols_remap})" + for expr in downstream_project.exprs: + if isinstance(expr, StarExpr): + continue + + referenced_columns = _collect_referenced_columns([expr]) or set() + columns_from_original = referenced_columns - ( + referenced_columns & upstream_output_columns ) - return target_op.apply_projection(new_spec.cols) + # Validate accessibility + if not upstream_has_all and columns_from_original: + # Example: Upstream selects ["a", "b"], Downstream refs "c" → can't fuse + missing_columns.update(columns_from_original) + if any(col in removed_by_renames for col in columns_from_original): + # Example: Upstream renames "x" to "y" (dropping "x"), Downstream refs "x" → can't fuse + removed_cols = { + col for col in columns_from_original if col in removed_by_renames + } + missing_columns.update(removed_cols) -def _combine_expressions( - inner_exprs: Optional[Dict[str, Expr]], outer_exprs: Optional[Dict[str, Expr]] -) -> Optional[Dict[str, Expr]]: - """Combine expressions from two Project operators. + is_valid = len(missing_columns) == 0 + return is_valid, missing_columns - Args: - inner_exprs: Expressions from the inner (upstream) Project operator - outer_exprs: Expressions from the outer (downstream) Project operator - Returns: - Combined dictionary of expressions, or None if no expressions +def _compose_projects( + upstream_project: Project, + downstream_project: Project, + upstream_has_star: bool, +) -> List[Expr]: """ - if not inner_exprs and not outer_exprs: - return None + Compose two Projects when the downstream has star(). + + Strategy: + - Emit a single star() only if the upstream had star() as well. + - Evaluate upstream non-star expressions first, then downstream non-star expressions. + With sequential projection evaluation, downstream expressions can reference + upstream outputs without explicit rewriting. + - Rename-of-computed columns will be dropped from final output by the evaluator + when there's no later explicit mention of the source name. + """ + fused_exprs: List[Expr] = [] - combined = {} + # Include star only if upstream had star; otherwise, don't reintroduce dropped cols. + if upstream_has_star: + fused_exprs.append(StarExpr()) - # Add expressions from inner operator - if inner_exprs: - combined.update(inner_exprs) + # Then upstream non-star expressions in order. + for expr in upstream_project.exprs: + if not isinstance(expr, StarExpr): + fused_exprs.append(expr) - # Add expressions from outer operator - if outer_exprs: - combined.update(outer_exprs) + # Then downstream non-star expressions in order. + for expr in downstream_project.exprs: + if not isinstance(expr, StarExpr): + fused_exprs.append(expr) - return combined if combined else None + return fused_exprs -def _get_projection_spec(op: Union[Project, Read]) -> _ProjectSpec: - assert op is not None +def _try_fuse_consecutive_projects( + upstream_project: Project, downstream_project: Project +) -> Project: + """ + Attempt to merge two consecutive Project operations into one. - if isinstance(op, Project): - return _ProjectSpec( - cols=op.cols, - cols_remap=op.cols_rename, - exprs=op.exprs, - ) - elif isinstance(op, Read): - assert op.supports_projection_pushdown() + Example: Upstream: [star(), col("x").alias("y")], Downstream: [star(), (col("y") + 1).alias("z")] → Fused: [star(), (col("x") + 1).alias("z")] + """ + upstream_has_star: bool = upstream_project.has_star_expr() + downstream_has_star: bool = downstream_project.has_star_expr() + + # Analyze upstream + ( + upstream_output_columns, + upstream_column_definitions, + removed_by_renames, + ) = _analyze_upstream_project(upstream_project) + + # Validate fusion possibility + is_valid, missing_columns = _validate_fusion( + downstream_project, + upstream_has_star, + upstream_output_columns, + removed_by_renames, + ) - return _ProjectSpec( - cols=op.get_current_projection(), - cols_remap=None, - exprs=None, + if not is_valid: + # Raise KeyError to match expected error type in tests + raise KeyError( + f"Column(s) {sorted(missing_columns)} not found. " + f"Available columns: {sorted(upstream_output_columns) if not upstream_has_star else 'all columns (has star)'}" ) + + rewritten_exprs: List[Expr] = [] + # Intersection case: This is when downstream is a selection (no star), and we need to recursively rewrite the downstream expressions into the upstream column definitions. + # Example: Upstream: [col("a").alias("b")], Downstream: [col("b").alias("c")] → Rewritten: [col("a").alias("c")] + if not downstream_has_star: + for expr in downstream_project.exprs: + rewritten = _ColumnRewriter(upstream_column_definitions).visit(expr) + rewritten_exprs.append(rewritten) else: - raise ValueError( - f"Operation doesn't have projection spec (supported Project, " - f"Read, got: {op.__class__})" + # Composition case: downstream has star(), and we need to merge both upstream and downstream expressions. + # Example: + # Upstream: [star(), col("a").alias("b")], Downstream: [star(), col("b").alias("c")] → Rewritten: [star(), col("a").alias("b"), col("b").alias("c")] + rewritten_exprs = _compose_projects( + upstream_project, + downstream_project, + upstream_has_star, ) - -def _combine_projection_specs( - prev_spec: _ProjectSpec, new_spec: _ProjectSpec -) -> _ProjectSpec: - combined_cols_remap = _combine_columns_remap( - prev_spec.cols_remap, - new_spec.cols_remap, + return Project( + upstream_project.input_dependency, + exprs=rewritten_exprs, + ray_remote_args=downstream_project._ray_remote_args, ) - # Validate resulting remapping against existing projection (if any) - _validate(combined_cols_remap, prev_spec.cols) - - new_projection_cols: Optional[List[str]] - - if prev_spec.cols is None and new_spec.cols is None: - # If both projections are unset, resulting is unset - new_projection_cols = None - elif prev_spec.cols is not None and new_spec.cols is None: - # If previous projection is set, but the new unset -- fallback to - # existing projection - new_projection_cols = prev_spec.cols - else: - # If new is set (and previous is either set or not) - # - Reconcile new projection - # - Project combined column remapping - assert new_spec.cols is not None - - new_projection_cols = new_spec.cols - - # Remap new projected columns into the schema before remapping (from the - # previous spec) - if prev_spec.cols_remap and new_projection_cols: - # Inverse remapping - inv_cols_remap = {v: k for k, v in prev_spec.cols_remap.items()} - new_projection_cols = [ - inv_cols_remap.get(col, col) for col in new_projection_cols - ] - - prev_cols_set = set(prev_spec.cols or []) - new_cols_set = set(new_projection_cols or []) - - # Validate new projection is a proper subset of the previous one - if prev_cols_set and new_cols_set and not new_cols_set.issubset(prev_cols_set): - raise ValueError( - f"Selected columns '{new_cols_set}' needs to be a subset of " - f"'{prev_cols_set}'" - ) - - # Project remaps to only map relevant columns - if new_projection_cols is not None and combined_cols_remap is not None: - projected_cols_remap = { - k: v for k, v in combined_cols_remap.items() if k in new_projection_cols - } - else: - projected_cols_remap = combined_cols_remap - - # Combine expressions from both specs - combined_exprs = _combine_expressions(prev_spec.exprs, new_spec.exprs) - return _ProjectSpec( - cols=new_projection_cols, cols_remap=projected_cols_remap, exprs=combined_exprs - ) +class ProjectionPushdown(Rule): + """ + Optimization rule that pushes projections (column selections) down the query plan. + This rule performs two optimizations: + 1. Fuses consecutive Project operations to eliminate redundant projections + 2. Pushes projections into data sources (e.g., Read operations) to enable + column pruning at the storage layer + """ -def _combine_columns_remap( - prev_remap: Optional[Dict[str, str]], new_remap: Optional[Dict[str, str]] -) -> Optional[Dict[str, str]]: + def apply(self, plan: LogicalPlan) -> LogicalPlan: + """Apply projection pushdown optimization to the entire plan.""" + dag = plan.dag + new_dag = dag._apply_transform(self._try_fuse_projects) + new_dag = new_dag._apply_transform(self._push_projection_into_read_op) + return LogicalPlan(new_dag, plan.context) if dag is not new_dag else plan - if not new_remap and not prev_remap: - return None + @classmethod + def _try_fuse_projects(cls, op: LogicalOperator) -> LogicalOperator: + """ + Optimize a single Project operator. - new_remap = new_remap or {} - base_remap = prev_remap or {} + Steps: + 1. Iteratively fuse with upstream Project operations + 2. Push the resulting projection into the data source if possible + """ + if not isinstance(op, Project): + return op - filtered_new_remap = dict(new_remap) - # Apply new remapping to the base remap - updated_base_remap = { - # NOTE: We're removing corresponding chained mapping from the remap - k: filtered_new_remap.pop(v, v) - for k, v in base_remap.items() - } + # Step 1: Iteratively fuse with upstream Project operations + current_project: Project = op - resolved_remap = dict(updated_base_remap) - resolved_remap.update(filtered_new_remap) + if not isinstance(current_project.input_dependency, Project): + return op - return resolved_remap + upstream_project: Project = current_project.input_dependency # type: ignore[assignment] + return _try_fuse_consecutive_projects(upstream_project, current_project) + @classmethod + def _push_projection_into_read_op(cls, op: LogicalOperator) -> LogicalOperator: -def _validate(remap: Optional[Dict[str, str]], projection_cols: Optional[List[str]]): - if not remap: - return + if not isinstance(op, Project): + return op - # Verify that the remapping is a proper bijection (ie no - # columns are renamed into the same new name) - prev_names_map = {} - for prev_name, new_name in remap.items(): - if new_name in prev_names_map: - raise ValueError( - f"Identified projections with conflict in renaming: '{new_name}' " - f"is mapped from multiple sources: '{prev_names_map[new_name]}' " - f"and '{prev_name}'." - ) + current_project: Project = op - prev_names_map[new_name] = prev_name + # Step 2: Push projection into the data source if supported + input_op = current_project.input_dependency + if ( + not current_project.has_star_expr() # Must be a selection, not additive + and isinstance(input_op, LogicalOperatorSupportsProjectionPushdown) + and input_op.supports_projection_pushdown() + ): + required_columns = _collect_referenced_columns(list(current_project.exprs)) + if required_columns is not None: # None means star() was present + optimized_source = input_op.apply_projection(list(required_columns)) - # Verify that remapping only references columns available in the projection - if projection_cols is not None: - invalid_cols = [key for key in remap.keys() if key not in projection_cols] + is_simple_selection = all( + isinstance(expr, ColumnExpr) for expr in current_project.exprs + ) - if invalid_cols: - raise ValueError( - f"Identified projections with invalid rename " - f"columns: {', '.join(invalid_cols)}" - ) + if is_simple_selection: + # Simple column selection: Read handles everything + return optimized_source + else: + # Has transformations: Keep Project on top of optimized Read + return Project( + optimized_source, + exprs=current_project.exprs, + ray_remote_args=current_project._ray_remote_args, + ) + + return current_project diff --git a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py index 58595fcce385..71fd567b433b 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py @@ -3,7 +3,7 @@ import ast import logging import operator -from typing import Any, Callable, Dict, TypeVar, Union +from typing import Any, Callable, Dict, List, TypeVar, Union import numpy as np import pandas as pd @@ -691,3 +691,96 @@ def eval_expr(expr: Expr, block: Block) -> Union[BlockColumn, ScalarType]: """ evaluator = NativeExpressionEvaluator(block) return evaluator.visit(expr) + + +def eval_projection(exprs: List[Expr], block: Block) -> Block: + """ + Evaluate a projection (list of expressions) against a block. + + Handles projection semantics including: + - Empty projections + - Star() expressions for preserving existing columns + - Rename detection + - Column ordering + + Args: + exprs: List of expressions to evaluate (may include StarExpr) + block: The block to project + + Returns: + A new block with the projected schema + """ + block_accessor = BlockAccessor.for_block(block) + + # Skip projection only for schema-less empty blocks. + if block_accessor.num_rows() == 0 and len(block_accessor.column_names()) == 0: + return block + + existing_cols = list(block_accessor.column_names()) + + # Handle simple cases early. + if len(exprs) == 0: + return block_accessor.select([]) + + if len(exprs) == 1 and isinstance(exprs[0], StarExpr): + return block + + # Helper function to check if expression is a simple rename of existing column. + def is_simple_rename(expr: Expr) -> bool: + return isinstance(expr, AliasExpr) and expr.expr.name in existing_cols + + # Build rename map: src_name -> dest_name for simple renames. + rename_map: Dict[str, str] = {} + for expr in exprs: + if is_simple_rename(expr) and expr.expr.name != expr.name: + rename_map[expr.expr.name] = expr.name + + # Expand stars to determine output column order. + non_star_exprs = [e for e in exprs if not isinstance(e, StarExpr)] + output_order: List[str] = [] + for expr in exprs: + if isinstance(expr, StarExpr): + output_order.extend(rename_map.get(c, c) for c in existing_cols) + else: + output_order.append(expr.name) + + # Determine which source columns should be dropped in final output. + # ANY alias where src != dest should drop 'src' if no later expression outputs 'src'. + drop_sources: set[str] = set() + for idx, expr in enumerate(non_star_exprs): + # Check if this is ANY alias (not just simple renames of existing cols) + if isinstance(expr, AliasExpr): + src = expr.expr.name + dest = expr.name + if src != dest: + # Check if any expression AFTER this one outputs the source name. + if not any(e.name == src for e in non_star_exprs[idx + 1 :]): + drop_sources.add(src) + + # Evaluate expressions sequentially. + cur_block = block + seen_outputs = set() + + for expr in non_star_exprs: + output_name = expr.name + + # Check for duplicate output names. + if output_name in seen_outputs: + raise ValueError(f"Column name '{output_name}' is a duplicate.") + seen_outputs.add(output_name) + + # Simple renames evaluate against original block to preserve swap semantics. + # Other expressions evaluate against current block to access prior outputs. + source_block = block if is_simple_rename(expr) else cur_block + value = eval_expr(expr, source_block) + cur_block = BlockAccessor.for_block(cur_block).fill_column(output_name, value) + + # Build final column list: deduplicate output_order and exclude dropped sources. + final_cols: List[str] = [] + seen_final = set() + for name in output_order: + if name not in drop_sources and name not in seen_final: + final_cols.append(name) + seen_final.add(name) + + return BlockAccessor.for_block(cur_block).select(final_cols) diff --git a/python/ray/data/_internal/planner/plan_expression/expression_visitors.py b/python/ray/data/_internal/planner/plan_expression/expression_visitors.py new file mode 100644 index 000000000000..9162f417f3e5 --- /dev/null +++ b/python/ray/data/_internal/planner/plan_expression/expression_visitors.py @@ -0,0 +1,244 @@ +from typing import Set, TypeVar + +from ray.data.expressions import ( + AliasExpr, + BinaryExpr, + ColumnExpr, + Expr, + LiteralExpr, + StarExpr, + UDFExpr, + UnaryExpr, + _ExprVisitor, +) + +T = TypeVar("T") + + +class _ExprVisitorBase(_ExprVisitor[None]): + """Base visitor that provides automatic recursive traversal. + + This class extends _ExprVisitor and provides default implementations + for composite nodes that automatically traverse child expressions. + """ + + def visit_binary(self, expr: "BinaryExpr") -> None: + """Default implementation: recursively visit both operands.""" + super().visit(expr.left) + super().visit(expr.right) + + def visit_unary(self, expr: "UnaryExpr") -> None: + """Default implementation: recursively visit the operand.""" + super().visit(expr.operand) + + def visit_alias(self, expr: "AliasExpr") -> None: + """Default implementation: recursively visit the inner expression.""" + super().visit(expr.expr) + + def visit_udf(self, expr: "UDFExpr") -> None: + """Default implementation: recursively visit all arguments.""" + for arg in expr.args: + super().visit(arg) + for value in expr.kwargs.values(): + super().visit(value) + + +class _ColumnReferenceCollector(_ExprVisitorBase): + """Visitor that collects all column references from expression trees. + + This visitor traverses expression trees and accumulates column names + referenced in ColumnExpr nodes. + """ + + def __init__(self): + """Initialize with an empty set of referenced columns.""" + self.referenced_columns: Set[str] = set() + + def visit_column(self, expr: ColumnExpr) -> None: + """Visit a column expression and collect its name. + + Args: + expr: The column expression. + + Returns: + None (only collects columns as a side effect). + """ + self.referenced_columns.add(expr.name) + + def visit_alias(self, expr: AliasExpr) -> None: + """Visit an alias expression and collect from its inner expression. + + Args: + expr: The alias expression. + + Returns: + None (only collects columns as a side effect). + """ + self.visit(expr.expr) + + def visit_literal(self, expr: LiteralExpr) -> None: + """Visit a literal expression (no columns to collect).""" + pass + + def visit_star(self, expr: StarExpr) -> None: + """Visit a star expression (no columns to collect).""" + pass + + def visit_download(self, expr: "Expr") -> None: + """Visit a download expression (no columns to collect).""" + pass + + +class _ColumnRewriter(_ExprVisitor[Expr]): + """Visitor that rewrites column references in expression trees. + + This visitor traverses expression trees and substitutes column references + according to a provided substitution map, preserving the structure of the tree. + """ + + def __init__(self, column_substitutions: dict[str, Expr]): + """Initialize with a column substitution map. + + Args: + column_substitutions: Mapping from column names to replacement expressions. + """ + self.column_substitutions = column_substitutions + self._currently_substituting: Set[ + str + ] = set() # Track columns being substituted to prevent cycles + + def visit(self, expr: Expr) -> Expr: + """Visit an expression node and return the rewritten expression. + + Args: + expr: The expression to visit. + + Returns: + The rewritten expression. + """ + return super().visit(expr) + + def visit_column(self, expr: ColumnExpr) -> Expr: + """Visit a column expression and substitute it. + + Args: + expr: The column expression. + + Returns: + The substituted expression or the original if no substitution exists. + """ + # Check for cycles: if we're already substituting this column, stop + if expr.name in self._currently_substituting: + return expr + + substitution = self.column_substitutions.get(expr.name) + if substitution is None: + return expr + + # Mark this column as being substituted + self._currently_substituting.add(expr.name) + + try: + if not isinstance(substitution, AliasExpr): + # Non-aliased expression: recursively rewrite + return self.visit(substitution) + + inner = substitution.expr + if isinstance(inner, ColumnExpr): + inner_def = self.column_substitutions.get(inner.name) + if isinstance(inner_def, AliasExpr) and isinstance( + inner_def.expr, ColumnExpr + ): + # Preserve simple rename chain (swap semantics -> Example: [col("a").alias("b"), col("b").alias("a")]) + return substitution + + # Aliased expression: rewrite inner and preserve alias (unless preserved above) + return self.visit(inner).alias(substitution.name) + finally: + # Remove from tracking when done + self._currently_substituting.discard(expr.name) + + def visit_literal(self, expr: LiteralExpr) -> Expr: + """Visit a literal expression (no rewriting needed). + + Args: + expr: The literal expression. + + Returns: + The original literal expression. + """ + return expr + + def visit_binary(self, expr: BinaryExpr) -> Expr: + """Visit a binary expression and rewrite its operands. + + Args: + expr: The binary expression. + + Returns: + A new binary expression with rewritten operands. + """ + return BinaryExpr( + expr.op, + self.visit(expr.left), + self.visit(expr.right), + ) + + def visit_unary(self, expr: UnaryExpr) -> Expr: + """Visit a unary expression and rewrite its operand. + + Args: + expr: The unary expression. + + Returns: + A new unary expression with rewritten operand. + """ + return UnaryExpr(expr.op, self.visit(expr.operand)) + + def visit_udf(self, expr: UDFExpr) -> Expr: + """Visit a UDF expression and rewrite its arguments. + + Args: + expr: The UDF expression. + + Returns: + A new UDF expression with rewritten arguments. + """ + new_args = [self.visit(arg) for arg in expr.args] + new_kwargs = {key: self.visit(value) for key, value in expr.kwargs.items()} + return UDFExpr( + fn=expr.fn, data_type=expr.data_type, args=new_args, kwargs=new_kwargs + ) + + def visit_alias(self, expr: AliasExpr) -> Expr: + """Visit an alias expression and rewrite its inner expression. + + Args: + expr: The alias expression. + + Returns: + A new alias expression with rewritten inner expression and preserved name. + """ + return self.visit(expr.expr).alias(expr.name) + + def visit_download(self, expr: "Expr") -> Expr: + """Visit a download expression (no rewriting needed). + + Args: + expr: The download expression. + + Returns: + The original download expression. + """ + return expr + + def visit_star(self, expr: StarExpr) -> Expr: + """Visit a star expression (no rewriting needed). + + Args: + expr: The star expression. + + Returns: + The original star expression. + """ + return expr diff --git a/python/ray/data/_internal/planner/plan_udf_map_op.py b/python/ray/data/_internal/planner/plan_udf_map_op.py index d6a0bde03f78..148dbf168aaa 100644 --- a/python/ray/data/_internal/planner/plan_udf_map_op.py +++ b/python/ray/data/_internal/planner/plan_udf_map_op.py @@ -48,7 +48,6 @@ ) from ray.data._internal.numpy_support import _is_valid_column_values from ray.data._internal.output_buffer import OutputBlockSizeOption -from ray.data._internal.planner.plan_expression.expression_evaluator import eval_expr from ray.data._internal.util import _truncated_repr from ray.data.block import ( Block, @@ -113,54 +112,25 @@ def plan_project_op( assert len(physical_children) == 1 input_physical_dag = physical_children[0] - columns = op.cols - columns_rename = op.cols_rename - exprs = op.exprs - def _project_block(block: Block) -> Block: try: - block_accessor = BlockAccessor.for_block(block) - if not block_accessor.num_rows(): - return block - - # 1. evaluate / add expressions - if exprs: - block_accessor = BlockAccessor.for_block(block) - # Add/update with expression results - result_block = block - for name, expr in exprs.items(): - # Use expr.name if available, otherwise fall back to the dict key name - actual_name = expr.name if expr.name is not None else name - result = eval_expr(expr, result_block) - result_block_accessor = BlockAccessor.for_block(result_block) - # fill_column handles both scalars and arrays - result_block = result_block_accessor.fill_column( - actual_name, result - ) - block = result_block - - # 2. (optional) column projection - if columns: - block = BlockAccessor.for_block(block).select(columns) + from ray.data._internal.planner.plan_expression.expression_evaluator import ( + eval_projection, + ) - # 3. (optional) rename - if columns_rename: - block = block.rename_columns( - [columns_rename.get(col, col) for col in block.schema.names] - ) - - return block + return eval_projection(op.exprs, block) except Exception as e: _try_wrap_udf_exception(e) compute = get_compute(op._compute) - map_transformer = MapTransformer( [ - BlockMapTransformFn(_generate_transform_fn_for_map_block(_project_block)), + BlockMapTransformFn( + _generate_transform_fn_for_map_block(_project_block), + disable_block_shaping=(len(op.exprs) == 0), + ) ] ) - return MapOperator.create( map_transformer, input_physical_dag, diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index ed497a3b15a8..9b1f26e14617 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -136,7 +136,7 @@ from ray.data._internal.execution.interfaces import Executor, NodeIdStr from ray.data.grouped_data import GroupedData -from ray.data.expressions import Expr +from ray.data.expressions import Expr, star logger = logging.getLogger(__name__) @@ -838,9 +838,7 @@ def with_column( else: project_op = Project( self._logical_plan.dag, - cols=None, - cols_rename=None, - exprs={column_name: expr}, + exprs=[star(), expr.alias(column_name)], ray_remote_args=ray_remote_args, ) logical_plan = LogicalPlan(project_op, self.context) @@ -1067,24 +1065,25 @@ def select_columns( Ray (e.g., num_gpus=1 to request GPUs for the map tasks). See :func:`ray.remote` for details. """ # noqa: E501 + from ray.data.expressions import col + if isinstance(cols, str): - cols = [cols] + exprs = [col(cols)] elif isinstance(cols, list): if not all(isinstance(col, str) for col in cols): raise ValueError( "select_columns requires all elements of 'cols' to be strings." ) + if len(cols) != len(set(cols)): + raise ValueError( + "select_columns expected unique column names, " + f"got duplicate column names: {cols}" + ) + exprs = [col(c) for c in cols] else: raise TypeError( "select_columns requires 'cols' to be a string or a list of strings." ) - - if len(cols) != len(set(cols)): - raise ValueError( - "select_columns expected unique column names, " - f"got duplicate column names: {cols}" - ) - # Don't feel like we really need this from ray.data._internal.compute import TaskPoolStrategy @@ -1093,8 +1092,7 @@ def select_columns( plan = self._plan.copy() select_op = Project( self._logical_plan.dag, - cols=cols, - cols_rename=None, + exprs=exprs, compute=compute, ray_remote_args=ray_remote_args, ) @@ -1157,6 +1155,10 @@ def rename_columns( :func:`ray.remote` for details. """ # noqa: E501 + from ray.data.expressions import col + + exprs = [] + if isinstance(names, dict): if not names: raise ValueError("rename_columns received 'names' with no entries.") @@ -1173,8 +1175,8 @@ def rename_columns( "rename_columns requires both keys and values in the 'names' " "to be strings." ) + exprs = [col(old).alias(new) for old, new in names.items()] - cols_rename = names elif isinstance(names, list): if not names: raise ValueError( @@ -1198,7 +1200,10 @@ def rename_columns( f"schema names: {current_names}." ) - cols_rename = dict(zip(current_names, names)) + exprs = [ + col(old).alias(new) + for old, new in dict(zip(current_names, names)).items() + ] else: raise TypeError( f"rename_columns expected names to be either List[str] or " @@ -1219,8 +1224,7 @@ def rename_columns( plan = self._plan.copy() select_op = Project( self._logical_plan.dag, - cols=None, - cols_rename=cols_rename, + exprs=[star()] + exprs, compute=compute, ray_remote_args=ray_remote_args, ) @@ -3592,7 +3596,7 @@ def count(self) -> int: # NOTE: Project the dataset to avoid the need to carry actual # data when we're only interested in the total count - count_op = Count(Project(self._logical_plan.dag, cols=[])) + count_op = Count(Project(self._logical_plan.dag, exprs=[])) logical_plan = LogicalPlan(count_op, self.context) count_ds = Dataset(plan, logical_plan) diff --git a/python/ray/data/expressions.py b/python/ray/data/expressions.py index e1a915b466f8..5f598370237c 100644 --- a/python/ray/data/expressions.py +++ b/python/ray/data/expressions.py @@ -508,7 +508,10 @@ class UnaryExpr(Expr): op: Operation operand: Expr - data_type: DataType = field(init=False) + # Default to bool return dtype for unary operations like is_null() and NOT. + # This enables chaining operations such as col("x").is_not_null().alias("valid"), + # where downstream expressions (like AliasExpr) need the data type. + data_type: DataType = field(default_factory=lambda: DataType.bool(), init=False) def structurally_equals(self, other: Any) -> bool: return ( diff --git a/python/ray/data/tests/test_execution_optimizer_basic.py b/python/ray/data/tests/test_execution_optimizer_basic.py index 959d5d0c4af6..1f149aa5d1db 100644 --- a/python/ray/data/tests/test_execution_optimizer_basic.py +++ b/python/ray/data/tests/test_execution_optimizer_basic.py @@ -31,6 +31,7 @@ from ray.data.context import DataContext from ray.data.datasource import Datasource from ray.data.datasource.datasource import ReadTask +from ray.data.expressions import col from ray.data.tests.conftest import * # noqa from ray.data.tests.test_util import _check_usage_record, get_parquet_read_logical_op from ray.data.tests.util import column_udf, extract_values, named_values @@ -283,7 +284,7 @@ def test_project_operator_select(ray_start_regular_shared_2_cpus): logical_plan = ds._plan._logical_plan op = logical_plan.dag assert isinstance(op, Project), op.name - assert op.cols == cols + assert op.exprs == [col("sepal.length"), col("petal.width")] physical_plan = create_planner().plan(logical_plan) physical_plan = PhysicalOptimizer().optimize(physical_plan) @@ -297,6 +298,8 @@ def test_project_operator_rename(ray_start_regular_shared_2_cpus): Checks that the physical plan is properly generated for the Project operator from rename columns. """ + from ray.data.expressions import star + path = "example://iris.parquet" ds = ray.data.read_parquet(path) ds = ds.map_batches(lambda d: d) @@ -306,9 +309,11 @@ def test_project_operator_rename(ray_start_regular_shared_2_cpus): logical_plan = ds._plan._logical_plan op = logical_plan.dag assert isinstance(op, Project), op.name - assert not op.cols - assert op.cols_rename == cols_rename - + assert op.exprs == [ + star(), + col("sepal.length").alias("sepal_length"), + col("petal.width").alias("pedal_width"), + ] physical_plan = create_planner().plan(logical_plan) physical_plan = PhysicalOptimizer().optimize(physical_plan) physical_op = physical_plan.dag diff --git a/python/ray/data/tests/test_operator_fusion.py b/python/ray/data/tests/test_operator_fusion.py index 578a0e8be099..ee9abb1bd7f3 100644 --- a/python/ray/data/tests/test_operator_fusion.py +++ b/python/ray/data/tests/test_operator_fusion.py @@ -26,6 +26,7 @@ from ray.data._internal.stats import DatasetStats from ray.data.context import DataContext from ray.data.dataset import Dataset +from ray.data.expressions import star from ray.data.tests.conftest import * # noqa from ray.data.tests.test_util import _check_usage_record, get_parquet_read_logical_op from ray.data.tests.util import column_udf, extract_values @@ -316,7 +317,7 @@ def test_read_with_map_batches_fused_successfully( ), ( # Fusion - Project(InputData([])), + Project(InputData([]), exprs=[star()]), True, ), ], diff --git a/python/ray/data/tests/test_projection_fusion.py b/python/ray/data/tests/test_projection_fusion.py index 9855e93e31ab..5649f12dbbb0 100644 --- a/python/ray/data/tests/test_projection_fusion.py +++ b/python/ray/data/tests/test_projection_fusion.py @@ -14,7 +14,7 @@ ProjectionPushdown, ) from ray.data.context import DataContext -from ray.data.expressions import DataType, col, udf +from ray.data.expressions import DataType, StarExpr, col, star, udf @dataclass @@ -38,7 +38,7 @@ class DependencyTestCase: description: str -class TestPorjectionFusion: +class TestProjectionFusion: """Test topological sorting in projection pushdown fusion.""" @pytest.fixture(autouse=True) @@ -114,12 +114,14 @@ def _create_project_chain(self, input_op, expressions_list: List[Dict[str, str]] current_op = input_op for expr_dict in expressions_list: - exprs = { - name: self._parse_expression(desc) for name, desc in expr_dict.items() - } - current_op = Project( - current_op, cols=None, cols_rename=None, exprs=exprs, ray_remote_args={} - ) + # Convert dictionary to list of named expressions + exprs = [] + for name, desc in expr_dict.items(): + expr = self._parse_expression(desc) + named_expr = expr.alias(name) + exprs.append(named_expr) + + current_op = Project(current_op, exprs=[star()] + exprs, ray_remote_args={}) return current_op @@ -129,8 +131,10 @@ def _extract_levels_from_plan(self, plan: LogicalPlan) -> List[Set[str]]: levels = [] while isinstance(current, Project): - if current.exprs: - levels.append(set(current.exprs.keys())) + # Extract names, ignoring StarExpr (not a named column) + levels.append( + {expr.name for expr in current.exprs if not isinstance(expr, StarExpr)} + ) current = current.input_dependency return list(reversed(levels)) # Return bottom-up order @@ -604,7 +608,7 @@ def test_chained_udf_dependencies(self): assert self._count_project_operators(optimized_plan) == 1 assert ( self._describe_plan_structure(optimized_plan) - == "Project(3 exprs) -> FromItems" # Changed from multiple operators + == "Project(4 exprs) -> FromItems" # Changed from multiple operators ) # Verify execution correctness @@ -666,13 +670,677 @@ def test_performance_impact_of_udf_chains(self): ) # Changed from 3 to 1 assert ( self._describe_plan_structure(optimized_independent) - == "Project(3 exprs) -> FromItems" + == "Project(4 exprs) -> FromItems" ) assert ( self._describe_plan_structure(optimized_chained) - == "Project(3 exprs) -> FromItems" # Changed from multiple operators + == "Project(4 exprs) -> FromItems" # Changed from multiple operators ) + @pytest.mark.parametrize( + "operations,expected", + [ + # Single operations + ([("rename", {"a": "A"})], {"A": 1, "b": 2, "c": 3}), + ([("select", ["a", "b"])], {"a": 1, "b": 2}), + ([("with_column", "d", 4)], {"a": 1, "b": 2, "c": 3, "d": 4}), + # Two operations - rename then select + ([("rename", {"a": "A"}), ("select", ["A"])], {"A": 1}), + ([("rename", {"a": "A"}), ("select", ["b"])], {"b": 2}), + ( + [("rename", {"a": "A", "b": "B"}), ("select", ["A", "B"])], + {"A": 1, "B": 2}, + ), + # Two operations - select then rename + ([("select", ["a", "b"]), ("rename", {"a": "A"})], {"A": 1, "b": 2}), + ([("select", ["a"]), ("rename", {"a": "x"})], {"x": 1}), + # Two operations - with_column combinations + ([("with_column", "d", 4), ("select", ["a", "d"])], {"a": 1, "d": 4}), + ([("select", ["a"]), ("with_column", "d", 4)], {"a": 1, "d": 4}), + ( + [("rename", {"a": "A"}), ("with_column", "d", 4)], + {"A": 1, "b": 2, "c": 3, "d": 4}, + ), + ( + [("with_column", "d", 4), ("rename", {"d": "D"})], + {"a": 1, "b": 2, "c": 3, "D": 4}, + ), + # Three operations + ( + [ + ("rename", {"a": "A"}), + ("select", ["A", "b"]), + ("with_column", "d", 4), + ], + {"A": 1, "b": 2, "d": 4}, + ), + ( + [ + ("with_column", "d", 4), + ("rename", {"a": "A"}), + ("select", ["A", "d"]), + ], + {"A": 1, "d": 4}, + ), + ( + [ + ("select", ["a", "b"]), + ("rename", {"a": "x"}), + ("with_column", "d", 4), + ], + {"x": 1, "b": 2, "d": 4}, + ), + # Column swap + ([("rename", {"a": "b", "b": "a"}), ("select", ["a"])], {"a": 2}), + ([("rename", {"a": "b", "b": "a"}), ("select", ["b"])], {"b": 1}), + # Multiple same operations + ( + [("rename", {"a": "x"}), ("rename", {"x": "y"})], + {"y": 1, "b": 2, "c": 3}, + ), + ([("select", ["a", "b"]), ("select", ["a"])], {"a": 1}), + ( + [("with_column", "d", 4), ("with_column", "e", 5)], + {"a": 1, "b": 2, "c": 3, "d": 4, "e": 5}, + ), + # Complex expressions with with_column + ( + [("rename", {"a": "x"}), ("with_column_expr", "sum", "x", 10)], + {"x": 1, "b": 2, "c": 3, "sum": 10}, + ), + ( + [ + ("with_column", "d", 4), + ("with_column", "e", 5), + ("select", ["d", "e"]), + ], + {"d": 4, "e": 5}, + ), + ], + ) + def test_projection_operations_comprehensive(self, operations, expected): + """Comprehensive test for projection operations combinations.""" + from ray.data.expressions import col, lit + + # Create initial dataset + ds = ray.data.range(1).map(lambda row: {"a": 1, "b": 2, "c": 3}) + + # Apply operations + for op in operations: + if op[0] == "rename": + ds = ds.rename_columns(op[1]) + elif op[0] == "select": + ds = ds.select_columns(op[1]) + elif op[0] == "with_column": + ds = ds.with_column(op[1], lit(op[2])) + elif op[0] == "with_column_expr": + # Special case for expressions referencing columns + ds = ds.with_column(op[1], col(op[2]) * op[3]) + + # Verify result + result = ds.take_all() + assert len(result) == 1 + assert result[0] == expected + + @pytest.mark.parametrize( + "operations,expected", + [ + # Basic count operations + ([("count",)], 3), # All 3 rows + ([("rename", {"a": "A"}), ("count",)], 3), + ([("select", ["a", "b"]), ("count",)], 3), + ([("with_column", "d", 4), ("count",)], 3), + # Filter operations affecting count + ([("filter", col("a") > 1), ("count",)], 2), # 2 rows have a > 1 + ([("filter", col("b") == 2), ("count",)], 3), # All rows have b == 2 + ([("filter", col("c") < 10), ("count",)], 3), # All rows have c < 10 + ([("filter", col("a") == 1), ("count",)], 1), # 1 row has a == 1 + # Projection then filter then count + ([("rename", {"a": "A"}), ("filter", col("A") > 1), ("count",)], 2), + ([("select", ["a", "b"]), ("filter", col("a") > 1), ("count",)], 2), + ([("with_column", "d", 4), ("filter", col("d") == 4), ("count",)], 3), + # Filter then projection then count + ([("filter", col("a") > 1), ("rename", {"a": "A"}), ("count",)], 2), + ([("filter", col("b") == 2), ("select", ["a", "b"]), ("count",)], 3), + ([("filter", col("c") < 10), ("with_column", "d", 4), ("count",)], 3), + # Multiple projections with filter and count + ( + [ + ("rename", {"a": "A"}), + ("select", ["A", "b"]), + ("filter", col("A") > 1), + ("count",), + ], + 2, + ), + ( + [ + ("with_column", "d", 4), + ("rename", {"d": "D"}), + ("filter", col("D") == 4), + ("count",), + ], + 3, + ), + ( + [ + ("select", ["a", "b"]), + ("filter", col("a") > 1), + ("rename", {"a": "x"}), + ("count",), + ], + 2, + ), + # Complex combinations + ( + [ + ("filter", col("a") > 0), + ("rename", {"b": "B"}), + ("select", ["a", "B"]), + ("filter", col("B") == 2), + ("count",), + ], + 3, + ), + ( + [ + ("with_column", "sum", 99), + ("filter", col("a") > 1), + ("select", ["a", "sum"]), + ("count",), + ], + 2, + ), + ( + [ + ("rename", {"a": "A", "b": "B"}), + ("filter", (col("A") + col("B")) > 3), + ("select", ["A"]), + ("count",), + ], + 2, + ), + ], + ) + def test_projection_fusion_with_count_and_filter(self, operations, expected): + """Test projection fusion with count operations including filters.""" + from ray.data.expressions import lit + + # Create dataset with 3 rows: {"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 2, "c": 3}, {"a": 3, "b": 2, "c": 3} + ds = ray.data.from_items( + [ + {"a": 1, "b": 2, "c": 3}, + {"a": 2, "b": 2, "c": 3}, + {"a": 3, "b": 2, "c": 3}, + ] + ) + + # Apply operations + for op in operations: + if op[0] == "rename": + ds = ds.rename_columns(op[1]) + elif op[0] == "select": + ds = ds.select_columns(op[1]) + elif op[0] == "with_column": + ds = ds.with_column(op[1], lit(op[2])) + elif op[0] == "filter": + # Use the predicate expression directly + ds = ds.filter(expr=op[1]) + elif op[0] == "count": + # Count returns a scalar, not a dataset + result = ds.count() + assert result == expected + return # Early return since count() terminates the pipeline + + # This should not be reached for count operations + assert False, "Count operation should have returned early" + + @pytest.mark.parametrize( + "invalid_operations,error_type,error_message_contains", + [ + # Try to filter on a column that doesn't exist yet + ( + [("filter", col("d") > 0), ("with_column", "d", 4)], + (KeyError, ray.exceptions.RayTaskError), + "d", + ), + # Try to filter on a renamed column before the rename + ( + [("filter", col("A") > 1), ("rename", {"a": "A"})], + (KeyError, ray.exceptions.RayTaskError), + "A", + ), + # Try to use a column that was removed by select + ( + [("select", ["a"]), ("filter", col("b") == 2)], + (KeyError, ray.exceptions.RayTaskError), + "b", + ), + # Try to filter on a column after it was removed by select + ( + [("select", ["a", "b"]), ("filter", col("c") < 10)], + (KeyError, ray.exceptions.RayTaskError), + "c", + ), + # Try to use with_column referencing a non-existent column + ( + [("select", ["a"]), ("with_column", "new_col", col("b") + 1)], + (KeyError, ray.exceptions.RayTaskError), + "b", + ), + # Try to filter on a column that was renamed away + ( + [("rename", {"b": "B"}), ("filter", col("b") == 2)], + (KeyError, ray.exceptions.RayTaskError), + "b", + ), + # Try to use with_column with old column name after rename + ( + [("rename", {"a": "A"}), ("with_column", "result", col("a") + 1)], + (KeyError, ray.exceptions.RayTaskError), + "a", + ), + # Try to select using old column name after rename + ( + [("rename", {"b": "B"}), ("select", ["a", "b", "c"])], + (KeyError, ray.exceptions.RayTaskError), + "b", + ), + # Try to filter on a computed column that was removed by select + ( + [ + ("with_column", "d", 4), + ("select", ["a", "b"]), + ("filter", col("d") == 4), + ], + (KeyError, ray.exceptions.RayTaskError), + "d", + ), + # Try to rename a column that was removed by select + ( + [("select", ["a", "b"]), ("rename", {"c": "C"})], + (KeyError, ray.exceptions.RayTaskError), + "c", + ), + # Complex: rename, select (removing renamed source), then use old name + ( + [ + ("rename", {"a": "A"}), + ("select", ["b", "c"]), + ("filter", col("a") > 0), + ], + (KeyError, ray.exceptions.RayTaskError), + "a", + ), + # Complex: with_column, select (keeping new column), filter on removed original + ( + [ + ("with_column", "sum", col("a") + col("b")), + ("select", ["sum"]), + ("filter", col("a") > 0), + ], + (KeyError, ray.exceptions.RayTaskError), + "a", + ), + # Try to use column in with_column expression after it was removed + ( + [ + ("select", ["a", "c"]), + ("with_column", "result", col("a") + col("b")), + ], + (KeyError, ray.exceptions.RayTaskError), + "b", + ), + ], + ) + def test_projection_operations_invalid_order( + self, invalid_operations, error_type, error_message_contains + ): + """Test that operations fail gracefully when referencing non-existent columns.""" + import ray + from ray.data.expressions import lit + + # Create dataset with 3 rows: {"a": 1, "b": 2, "c": 3}, {"a": 2, "b": 2, "c": 3}, {"a": 3, "b": 2, "c": 3} + ds = ray.data.from_items( + [ + {"a": 1, "b": 2, "c": 3}, + {"a": 2, "b": 2, "c": 3}, + {"a": 3, "b": 2, "c": 3}, + ] + ) + + # Apply operations and expect them to fail + with pytest.raises(error_type) as exc_info: + for op in invalid_operations: + if op[0] == "rename": + ds = ds.rename_columns(op[1]) + elif op[0] == "select": + ds = ds.select_columns(op[1]) + elif op[0] == "with_column": + if len(op) == 3 and not isinstance(op[2], (int, float, str)): + # Expression-based with_column (op[2] is an expression) + ds = ds.with_column(op[1], op[2]) + else: + # Literal-based with_column + ds = ds.with_column(op[1], lit(op[2])) + elif op[0] == "filter": + ds = ds.filter(expr=op[1]) + elif op[0] == "count": + ds.count() + return + + # Force execution to trigger the error + result = ds.take_all() + print(f"Unexpected success: {result}") + + # Verify the error message contains the expected column name + error_str = str(exc_info.value).lower() + assert ( + error_message_contains.lower() in error_str + ), f"Expected '{error_message_contains}' in error message: {error_str}" + + @pytest.mark.parametrize( + "operations,expected_output", + [ + # === Basic Select Operations === + pytest.param( + [("select", ["a"])], + [{"a": 1}, {"a": 2}, {"a": 3}], + id="select_single_column", + ), + pytest.param( + [("select", ["a", "b"])], + [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}], + id="select_two_columns", + ), + pytest.param( + [("select", ["a", "b", "c"])], + [ + {"a": 1, "b": 4, "c": 7}, + {"a": 2, "b": 5, "c": 8}, + {"a": 3, "b": 6, "c": 9}, + ], + id="select_all_columns", + ), + pytest.param( + [("select", ["c", "a"])], + [{"c": 7, "a": 1}, {"c": 8, "a": 2}, {"c": 9, "a": 3}], + id="select_reordered_columns", + ), + # === Basic Rename Operations === + pytest.param( + [("rename", {"a": "alpha"})], + [ + {"alpha": 1, "b": 4, "c": 7}, + {"alpha": 2, "b": 5, "c": 8}, + {"alpha": 3, "b": 6, "c": 9}, + ], + id="rename_single_column", + ), + pytest.param( + [("rename", {"a": "alpha", "b": "beta"})], + [ + {"alpha": 1, "beta": 4, "c": 7}, + {"alpha": 2, "beta": 5, "c": 8}, + {"alpha": 3, "beta": 6, "c": 9}, + ], + id="rename_multiple_columns", + ), + # === Basic with_column Operations === + pytest.param( + [("with_column_expr", "sum", "add", "a", "b")], + [ + {"a": 1, "b": 4, "c": 7, "sum": 5}, + {"a": 2, "b": 5, "c": 8, "sum": 7}, + {"a": 3, "b": 6, "c": 9, "sum": 9}, + ], + id="with_column_add_keep_all", + ), + pytest.param( + [("with_column_expr", "product", "multiply", "b", "c")], + [ + {"a": 1, "b": 4, "c": 7, "product": 28}, + {"a": 2, "b": 5, "c": 8, "product": 40}, + {"a": 3, "b": 6, "c": 9, "product": 54}, + ], + id="with_column_multiply_keep_all", + ), + # === Chained Selects === + pytest.param( + [("select", ["a", "b", "c"]), ("select", ["a", "b"])], + [{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}], + id="chained_selects_two_levels", + ), + pytest.param( + [ + ("select", ["a", "b", "c"]), + ("select", ["a", "b"]), + ("select", ["a"]), + ], + [{"a": 1}, {"a": 2}, {"a": 3}], + id="chained_selects_three_levels", + ), + # === Rename → Select === + pytest.param( + [("rename", {"a": "x"}), ("select", ["x", "b"])], + [{"x": 1, "b": 4}, {"x": 2, "b": 5}, {"x": 3, "b": 6}], + id="rename_then_select", + ), + pytest.param( + [("rename", {"a": "x", "c": "z"}), ("select", ["x", "z"])], + [{"x": 1, "z": 7}, {"x": 2, "z": 8}, {"x": 3, "z": 9}], + id="rename_multiple_then_select", + ), + # === Select → Rename === + pytest.param( + [("select", ["a", "b"]), ("rename", {"a": "x"})], + [{"x": 1, "b": 4}, {"x": 2, "b": 5}, {"x": 3, "b": 6}], + id="select_then_rename", + ), + pytest.param( + [("select", ["a", "b", "c"]), ("rename", {"a": "x", "b": "y"})], + [ + {"x": 1, "y": 4, "c": 7}, + {"x": 2, "y": 5, "c": 8}, + {"x": 3, "y": 6, "c": 9}, + ], + id="select_all_then_rename_some", + ), + # === Multiple Renames === + pytest.param( + [("rename", {"a": "x"}), ("rename", {"x": "y"})], + [ + {"y": 1, "b": 4, "c": 7}, + {"y": 2, "b": 5, "c": 8}, + {"y": 3, "b": 6, "c": 9}, + ], + id="chained_renames", + ), + # === with_column → Select === + pytest.param( + [("with_column_expr", "sum", "add", "a", "b"), ("select", ["sum"])], + [{"sum": 5}, {"sum": 7}, {"sum": 9}], + id="with_column_then_select_only_computed", + ), + pytest.param( + [ + ("with_column_expr", "sum", "add", "a", "b"), + ("select", ["a", "sum"]), + ], + [{"a": 1, "sum": 5}, {"a": 2, "sum": 7}, {"a": 3, "sum": 9}], + id="with_column_then_select_mixed", + ), + pytest.param( + [ + ("with_column_expr", "result", "multiply", "b", "c"), + ("select", ["a", "result"]), + ], + [ + {"a": 1, "result": 28}, + {"a": 2, "result": 40}, + {"a": 3, "result": 54}, + ], + id="with_column_select_source_and_computed", + ), + # === Multiple with_column Operations === + pytest.param( + [ + ("with_column_expr", "sum", "add", "a", "b"), + ("with_column_expr", "product", "multiply", "a", "c"), + ], + [ + {"a": 1, "b": 4, "c": 7, "sum": 5, "product": 7}, + {"a": 2, "b": 5, "c": 8, "sum": 7, "product": 16}, + {"a": 3, "b": 6, "c": 9, "sum": 9, "product": 27}, + ], + id="multiple_with_column_keep_all", + ), + pytest.param( + [ + ("with_column_expr", "sum", "add", "a", "b"), + ("with_column_expr", "product", "multiply", "a", "c"), + ("select", ["sum", "product"]), + ], + [ + {"sum": 5, "product": 7}, + {"sum": 7, "product": 16}, + {"sum": 9, "product": 27}, + ], + id="multiple_with_column_then_select", + ), + pytest.param( + [ + ("with_column_expr", "sum", "add", "a", "b"), + ("with_column_expr", "diff", "add", "c", "a"), + ("select", ["sum", "diff"]), + ], + [{"sum": 5, "diff": 8}, {"sum": 7, "diff": 10}, {"sum": 9, "diff": 12}], + id="multiple_with_column_independent_sources", + ), + # === with_column → Rename === + pytest.param( + [ + ("with_column_expr", "sum", "add", "a", "b"), + ("rename", {"sum": "total"}), + ], + [ + {"a": 1, "b": 4, "c": 7, "total": 5}, + {"a": 2, "b": 5, "c": 8, "total": 7}, + {"a": 3, "b": 6, "c": 9, "total": 9}, + ], + id="with_column_then_rename_computed", + ), + # === Rename → with_column === + pytest.param( + [ + ("rename", {"a": "x"}), + ("with_column_expr", "x_plus_b", "add", "x", "b"), + ], + [ + {"x": 1, "b": 4, "c": 7, "x_plus_b": 5}, + {"x": 2, "b": 5, "c": 8, "x_plus_b": 7}, + {"x": 3, "b": 6, "c": 9, "x_plus_b": 9}, + ], + id="rename_then_with_column_using_renamed", + ), + pytest.param( + [ + ("rename", {"a": "x"}), + ("with_column_expr", "result", "add", "x", "b"), + ("select", ["result"]), + ], + [{"result": 5}, {"result": 7}, {"result": 9}], + id="rename_with_column_select_chain", + ), + # === Select → with_column → Select === + pytest.param( + [ + ("select", ["a", "b"]), + ("with_column_expr", "sum", "add", "a", "b"), + ("select", ["a", "sum"]), + ], + [{"a": 1, "sum": 5}, {"a": 2, "sum": 7}, {"a": 3, "sum": 9}], + id="select_with_column_select_chain", + ), + pytest.param( + [ + ("select", ["a", "b", "c"]), + ("with_column_expr", "x", "add", "a", "b"), + ("with_column_expr", "y", "multiply", "b", "c"), + ("select", ["x", "y"]), + ], + [{"x": 5, "y": 28}, {"x": 7, "y": 40}, {"x": 9, "y": 54}], + id="select_multiple_with_column_select_chain", + ), + # === Complex Multi-Step Chains === + pytest.param( + [ + ("select", ["a", "b", "c"]), + ("rename", {"a": "x"}), + ("with_column_expr", "result", "add", "x", "b"), + ("select", ["result", "c"]), + ], + [{"result": 5, "c": 7}, {"result": 7, "c": 8}, {"result": 9, "c": 9}], + id="complex_select_rename_with_column_select", + ), + pytest.param( + [ + ("rename", {"a": "alpha", "b": "beta"}), + ("select", ["alpha", "beta", "c"]), + ("with_column_expr", "sum", "add", "alpha", "beta"), + ("rename", {"sum": "total"}), + ("select", ["total", "c"]), + ], + [{"total": 5, "c": 7}, {"total": 7, "c": 8}, {"total": 9, "c": 9}], + id="complex_five_step_chain", + ), + pytest.param( + [ + ("select", ["a", "b", "c"]), + ("select", ["b", "c"]), + ("select", ["c"]), + ], + [{"c": 7}, {"c": 8}, {"c": 9}], + id="select_chain", + ), + ], + ) + def test_projection_pushdown_into_parquet_read( + self, tmp_path, operations, expected_output + ): + """Test that projection operations fuse and push down into parquet reads. + + Verifies: + - Multiple projections fuse into single operator + - Fused projection pushes down into Read operator + - Only necessary columns are read from parquet + - Results are correct for select, rename, and with_column operations + """ + from ray.data.expressions import col + + # Create test parquet file + df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}) + parquet_path = tmp_path / "test.parquet" + df.to_parquet(parquet_path, index=False) + + # Build pipeline with operations + ds = ray.data.read_parquet(str(parquet_path)) + + for op_type, *op_args in operations: + if op_type == "select": + ds = ds.select_columns(op_args[0]) + elif op_type == "rename": + ds = ds.rename_columns(op_args[0]) + elif op_type == "with_column_expr": + col_name, operator, col1, col2 = op_args + if operator == "add": + ds = ds.with_column(col_name, col(col1) + col(col2)) + elif operator == "multiply": + ds = ds.with_column(col_name, col(col1) * col(col2)) + + result_df = ds.take_all() + assert result_df == expected_output + if __name__ == "__main__": pytest.main([__file__, "-v"])