diff --git a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py index 4af5146ae3231..9413767c7a395 100644 --- a/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py +++ b/python/ray/data/_internal/planner/plan_expression/expression_evaluator.py @@ -1,9 +1,10 @@ import ast import logging -from typing import Any, List, Union import pyarrow as pa import pyarrow.compute as pc +import pyarrow.dataset as ds + logger = logging.getLogger(__name__) @@ -16,7 +17,8 @@ class ExpressionEvaluator: - def get_filters(self, expression: str) -> pc.Expression: + @staticmethod + def get_filters(expression: str) -> ds.Expression: """Parse and evaluate the expression to generate a filter condition. Args: @@ -28,29 +30,16 @@ def get_filters(self, expression: str) -> pc.Expression: """ try: tree = ast.parse(expression, mode="eval") - return self._build_filter_condition(tree.body) + return _ConvertToArrowExpressionVisitor().visit(tree.body) except SyntaxError as e: raise ValueError(f"Invalid syntax in the expression: {expression}") from e except Exception as e: logger.exception(f"Error processing expression: {e}") raise - def _build_filter_condition(self, node) -> Union[pc.Expression, List[Any], str]: - """Recursively evaluate an AST node to build the filter condition. - - Args: - node: The AST node to evaluate, representing part of the expression. - - Returns: - The evaluated result for the node, which could be a - filter condition, list, or field name. - """ - visitor = _ConvertToArrowExpressionVisitor() - return visitor.visit(node) - class _ConvertToArrowExpressionVisitor(ast.NodeVisitor): - def visit_Compare(self, node: ast.Compare) -> pc.Expression: + def visit_Compare(self, node: ast.Compare) -> ds.Expression: """Handle comparison operations (e.g., a == b, a < b, a in b). 
Args: @@ -62,11 +51,14 @@ def visit_Compare(self, node: ast.Compare) -> pc.Expression: # Handle left operand # TODO Validate columns if isinstance(node.left, ast.Attribute): - left_expr = self.visit(node.left) # Visit and handle attributes + # Visit and handle attributes + left_expr = self.visit(node.left) elif isinstance(node.left, ast.Name): - left_expr = self.visit(node.left) # Treat as a simple field + # Treat as a simple field + left_expr = self.visit(node.left) elif isinstance(node.left, ast.Constant): - left_expr = node.left.value # Constant values are used directly + # Constant values are used directly + left_expr = node.left.value else: raise ValueError(f"Unsupported left operand type: {type(node.left)}") @@ -92,7 +84,7 @@ def visit_Compare(self, node: ast.Compare) -> pc.Expression: else: raise ValueError(f"Unsupported operator type: {op}") - def visit_BoolOp(self, node: ast.BoolOp) -> pc.Expression: + def visit_BoolOp(self, node: ast.BoolOp) -> ds.Expression: """Handle logical operations (e.g., a and b, a or b). Args: @@ -118,8 +110,8 @@ def visit_BoolOp(self, node: ast.BoolOp) -> pc.Expression: return combined_expr - def visit_Name(self, node: ast.Name) -> pc.Expression: - """Handle variable (name) nodes and return them as pc.Expression. + def visit_Name(self, node: ast.Name) -> ds.Expression: + """Handle variable (name) nodes and return them as pa.dataset.Expression. Even if the name contains periods, it's treated as a single string. @@ -127,11 +119,10 @@ def visit_Name(self, node: ast.Name) -> pc.Expression: node: The AST node representing a variable. Returns: - The variable wrapped as a pc.Expression. + The variable wrapped as a pa.dataset.Expression. 
""" - field_name = ( - node.id - ) # Directly use the field name as a string (even if it contains periods) + # Directly use the field name as a string (even if it contains periods) + field_name = node.id return pc.field(field_name) def visit_Attribute(self, node: ast.Attribute) -> object: @@ -159,21 +150,21 @@ def visit_Attribute(self, node: ast.Attribute) -> object: raise ValueError(f"Unsupported attribute: {node.attr}") - def visit_List(self, node: ast.List) -> pc.Expression: + def visit_List(self, node: ast.List) -> ds.Expression: """Handle list literals. Args: node: The AST node representing a list. Returns: - The list of elements wrapped as a pc.Expression. + The list of elements converted to a pyarrow Array. """ elements = [self.visit(elt) for elt in node.elts] return pa.array(elements) - # TODO (srinathk) Note that visit_Constant does not return pc.Expression + # TODO (srinathk) Note that visit_Constant does not return pa.dataset.Expression # because to support function in() which takes in a List, the elements in the List - # needs to values instead of pc.Expression per pyarrow.dataset.Expression + # need to be values instead of pa.dataset.Expression per pyarrow.dataset.Expression # specification. May be down the road, we can update it as Arrow relaxes this # constraint. def visit_Constant(self, node: ast.Constant) -> object: @@ -187,7 +178,7 @@ def visit_Constant(self, node: ast.Constant) -> object: """ return node.value # Return the constant value directly. - def visit_Call(self, node: ast.Call) -> pc.Expression: + def visit_Call(self, node: ast.Call) -> ds.Expression: """Handle function calls (e.g., is_nan(a), is_valid(b)). Args: diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 82d403ec012d7..60deb88128adf 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -1283,8 +1283,7 @@ def filter( # TODO: (srinathk) bind the expression to the actual schema. 
# If fn is a string, convert it to a pyarrow.dataset.Expression # Initialize ExpressionEvaluator with valid columns, if available - evaluator = ExpressionEvaluator() - resolved_expr = evaluator.get_filters(expression=expr) + resolved_expr = ExpressionEvaluator.get_filters(expression=expr) compute = TaskPoolStrategy(size=concurrency) else: diff --git a/python/ray/data/tests/test_expression_evaluator.py b/python/ray/data/tests/test_expression_evaluator.py index 08e29a78ad42d..040b42e192466 100644 --- a/python/ray/data/tests/test_expression_evaluator.py +++ b/python/ray/data/tests/test_expression_evaluator.py @@ -33,8 +33,13 @@ def sample_data(tmpdir_factory): "is_student": [False, True, False, False, True, None], # Including a None value } + # Define the schema explicitly + schema = pa.schema( + [("age", pa.float64()), ("city", pa.string()), ("is_student", pa.bool_())] + ) + # Create a PyArrow table from the sample data - table = pa.table(data) + table = pa.table(data, schema=schema) # Use tmpdir_factory to create a temporary directory temp_dir = tmpdir_factory.mktemp("data") @@ -44,7 +49,7 @@ def sample_data(tmpdir_factory): pq.write_table(table, str(parquet_file)) # Yield the path to the Parquet file for testing - yield str(parquet_file) + yield str(parquet_file), schema expressions_and_expected_data = [ @@ -290,13 +295,13 @@ def sample_data(tmpdir_factory): @pytest.mark.parametrize("expression, expected_data", expressions_and_expected_data) def test_filter(sample_data, expression, expected_data): """Test the filter functionality of the ExpressionEvaluator.""" - # Instantiate the ExpressionEvaluator with valid column names - evaluator = ExpressionEvaluator() - filters = evaluator.get_filters(expression) + # Unpack the Parquet file path, then parse the expression into a filter + sample_data_path, _ = sample_data + filters = ExpressionEvaluator.get_filters(expression=expression) # Read the table from the Parquet file with the applied filters - filtered_table = 
pq.read_table(sample_data, filters=filters) + filtered_table = pq.read_table(sample_data_path, filters=filters) # Convert the filtered table back to a list of dictionaries for comparison result = filtered_table.to_pandas().to_dict(orient="records") @@ -314,11 +319,11 @@ def convert_nan_to_none(data): def test_filter_bad_expression(sample_data): - evaluator = ExpressionEvaluator() with pytest.raises(ValueError, match="Invalid syntax in the expression"): - evaluator.get_filters("bad filter") + ExpressionEvaluator.get_filters(expression="bad filter") - filters = evaluator.get_filters("hi > 3") + filters = ExpressionEvaluator.get_filters(expression="hi > 3") + sample_data_path, _ = sample_data with pytest.raises(pa.ArrowInvalid): - pq.read_table(sample_data, filters=filters) + pq.read_table(sample_data_path, filters=filters)