Skip to content

Commit

Permalink
Clean up Expression Evaluator interfaces (#50110)
Browse files Browse the repository at this point in the history
## Why are these changes needed?
Clean up Expression Evaluator interfaces

---------

Signed-off-by: Srinath Krishnamachari <[email protected]>
  • Loading branch information
srinathk10 authored Jan 28, 2025
1 parent e136361 commit 55cfc24
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import ast
import logging
from typing import Any, List, Union

import pyarrow as pa
import pyarrow.compute as pc
import pyarrow.dataset as ds


logger = logging.getLogger(__name__)

Expand All @@ -16,7 +17,8 @@


class ExpressionEvaluator:
def get_filters(self, expression: str) -> pc.Expression:
@staticmethod
def get_filters(expression: str) -> ds.Expression:
"""Parse and evaluate the expression to generate a filter condition.
Args:
Expand All @@ -28,29 +30,16 @@ def get_filters(self, expression: str) -> pc.Expression:
"""
try:
tree = ast.parse(expression, mode="eval")
return self._build_filter_condition(tree.body)
return _ConvertToArrowExpressionVisitor().visit(tree.body)
except SyntaxError as e:
raise ValueError(f"Invalid syntax in the expression: {expression}") from e
except Exception as e:
logger.exception(f"Error processing expression: {e}")
raise

def _build_filter_condition(self, node) -> Union[pc.Expression, List[Any], str]:
"""Recursively evaluate an AST node to build the filter condition.
Args:
node: The AST node to evaluate, representing part of the expression.
Returns:
The evaluated result for the node, which could be a
filter condition, list, or field name.
"""
visitor = _ConvertToArrowExpressionVisitor()
return visitor.visit(node)


class _ConvertToArrowExpressionVisitor(ast.NodeVisitor):
def visit_Compare(self, node: ast.Compare) -> pc.Expression:
def visit_Compare(self, node: ast.Compare) -> ds.Expression:
"""Handle comparison operations (e.g., a == b, a < b, a in b).
Args:
Expand All @@ -62,11 +51,14 @@ def visit_Compare(self, node: ast.Compare) -> pc.Expression:
# Handle left operand
# TODO Validate columns
if isinstance(node.left, ast.Attribute):
left_expr = self.visit(node.left) # Visit and handle attributes
# Visit and handle attributes
left_expr = self.visit(node.left)
elif isinstance(node.left, ast.Name):
left_expr = self.visit(node.left) # Treat as a simple field
# Treat as a simple field
left_expr = self.visit(node.left)
elif isinstance(node.left, ast.Constant):
left_expr = node.left.value # Constant values are used directly
# Constant values are used directly
left_expr = node.left.value
else:
raise ValueError(f"Unsupported left operand type: {type(node.left)}")

Expand All @@ -92,7 +84,7 @@ def visit_Compare(self, node: ast.Compare) -> pc.Expression:
else:
raise ValueError(f"Unsupported operator type: {op}")

def visit_BoolOp(self, node: ast.BoolOp) -> pc.Expression:
def visit_BoolOp(self, node: ast.BoolOp) -> ds.Expression:
"""Handle logical operations (e.g., a and b, a or b).
Args:
Expand All @@ -118,20 +110,19 @@ def visit_BoolOp(self, node: ast.BoolOp) -> pc.Expression:

return combined_expr

def visit_Name(self, node: ast.Name) -> pc.Expression:
"""Handle variable (name) nodes and return them as pc.Expression.
def visit_Name(self, node: ast.Name) -> ds.Expression:
"""Handle variable (name) nodes and return them as pa.dataset.Expression.
Even if the name contains periods, it's treated as a single string.
Args:
node: The AST node representing a variable.
Returns:
The variable wrapped as a pc.Expression.
The variable wrapped as a pa.dataset.Expression.
"""
field_name = (
node.id
) # Directly use the field name as a string (even if it contains periods)
# Directly use the field name as a string (even if it contains periods)
field_name = node.id
return pc.field(field_name)

def visit_Attribute(self, node: ast.Attribute) -> object:
Expand Down Expand Up @@ -159,21 +150,21 @@ def visit_Attribute(self, node: ast.Attribute) -> object:

raise ValueError(f"Unsupported attribute: {node.attr}")

def visit_List(self, node: ast.List) -> pc.Expression:
def visit_List(self, node: ast.List) -> ds.Expression:
"""Handle list literals.
Args:
node: The AST node representing a list.
Returns:
The list of elements wrapped as a pc.Expression.
The list of elements wrapped as a pa.dataset.Expression.
"""
elements = [self.visit(elt) for elt in node.elts]
return pa.array(elements)

# TODO (srinathk) Note that visit_Constant does not return pc.Expression
# TODO (srinathk) Note that visit_Constant does not return pa.dataset.Expression
# because to support function in() which takes in a List, the elements in the List
# needs to values instead of pc.Expression per pyarrow.dataset.Expression
# need to be values instead of pa.dataset.Expression per pyarrow.dataset.Expression
# specification. Maybe down the road, we can update it as Arrow relaxes this
# constraint.
def visit_Constant(self, node: ast.Constant) -> object:
Expand All @@ -187,7 +178,7 @@ def visit_Constant(self, node: ast.Constant) -> object:
"""
return node.value # Return the constant value directly.

def visit_Call(self, node: ast.Call) -> pc.Expression:
def visit_Call(self, node: ast.Call) -> ds.Expression:
"""Handle function calls (e.g., is_nan(a), is_valid(b)).
Args:
Expand Down
3 changes: 1 addition & 2 deletions python/ray/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,8 +1283,7 @@ def filter(
# TODO: (srinathk) bind the expression to the actual schema.
# If fn is a string, convert it to a pyarrow.dataset.Expression
# Resolve it via the static ExpressionEvaluator.get_filters (no instance needed)
evaluator = ExpressionEvaluator()
resolved_expr = evaluator.get_filters(expression=expr)
resolved_expr = ExpressionEvaluator.get_filters(expression=expr)

compute = TaskPoolStrategy(size=concurrency)
else:
Expand Down
25 changes: 15 additions & 10 deletions python/ray/data/tests/test_expression_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,13 @@ def sample_data(tmpdir_factory):
"is_student": [False, True, False, False, True, None], # Including a None value
}

# Define the schema explicitly
schema = pa.schema(
[("age", pa.float64()), ("city", pa.string()), ("is_student", pa.bool_())]
)

# Create a PyArrow table from the sample data
table = pa.table(data)
table = pa.table(data, schema=schema)

# Use tmpdir_factory to create a temporary directory
temp_dir = tmpdir_factory.mktemp("data")
Expand All @@ -44,7 +49,7 @@ def sample_data(tmpdir_factory):
pq.write_table(table, str(parquet_file))

# Yield the path to the Parquet file for testing
yield str(parquet_file)
yield str(parquet_file), schema


expressions_and_expected_data = [
Expand Down Expand Up @@ -290,13 +295,13 @@ def sample_data(tmpdir_factory):
@pytest.mark.parametrize("expression, expected_data", expressions_and_expected_data)
def test_filter(sample_data, expression, expected_data):
"""Test the filter functionality of the ExpressionEvaluator."""
# Instantiate the ExpressionEvaluator with valid column names
evaluator = ExpressionEvaluator()

filters = evaluator.get_filters(expression)
# Unpack the fixture and build the filter via the static ExpressionEvaluator API
sample_data_path, _ = sample_data
filters = ExpressionEvaluator.get_filters(expression=expression)

# Read the table from the Parquet file with the applied filters
filtered_table = pq.read_table(sample_data, filters=filters)
filtered_table = pq.read_table(sample_data_path, filters=filters)

# Convert the filtered table back to a list of dictionaries for comparison
result = filtered_table.to_pandas().to_dict(orient="records")
Expand All @@ -314,11 +319,11 @@ def convert_nan_to_none(data):


def test_filter_bad_expression(sample_data):
evaluator = ExpressionEvaluator()
with pytest.raises(ValueError, match="Invalid syntax in the expression"):
evaluator.get_filters("bad filter")
ExpressionEvaluator.get_filters(expression="bad filter")

filters = evaluator.get_filters("hi > 3")
filters = ExpressionEvaluator.get_filters(expression="hi > 3")

sample_data_path, _ = sample_data
with pytest.raises(pa.ArrowInvalid):
pq.read_table(sample_data, filters=filters)
pq.read_table(sample_data_path, filters=filters)

0 comments on commit 55cfc24

Please sign in to comment.