diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index acf21a5b17..76befb6d8d 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -132,7 +132,16 @@ impl Bind for UnaryExpression { } impl UnaryExpression { - pub(crate) fn new(op: PredicateOperator, term: T) -> Self { + /// Creates a unary expression with the given operator and term. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::{PredicateOperator, Reference, UnaryExpression}; + /// + /// UnaryExpression::new(PredicateOperator::IsNull, Reference::new("c")); + /// ``` + pub fn new(op: PredicateOperator, term: T) -> Self { debug_assert!(op.is_unary()); Self { op, term } } @@ -171,7 +180,21 @@ impl Debug for BinaryExpression { } impl BinaryExpression { - pub(crate) fn new(op: PredicateOperator, term: T, literal: Datum) -> Self { + /// Creates a binary expression with the given operator, term and literal. + /// + /// # Example + /// + /// ```rust + /// use iceberg::expr::{BinaryExpression, PredicateOperator, Reference}; + /// use iceberg::spec::Datum; + /// + /// BinaryExpression::new( + /// PredicateOperator::LessThanOrEq, + /// Reference::new("a"), + /// Datum::int(10), + /// ); + /// ``` + pub fn new(op: PredicateOperator, term: T, literal: Datum) -> Self { debug_assert!(op.is_binary()); Self { op, term, literal } } diff --git a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs index 110e4f7e4f..f438308e68 100644 --- a/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs +++ b/crates/integrations/datafusion/src/physical_plan/expr_to_predicate.rs @@ -15,103 +15,179 @@ // specific language governing permissions and limitations // under the License. -use std::collections::VecDeque; +use std::vec; -use datafusion::common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; -use datafusion::common::Column; -use datafusion::error::DataFusionError; use datafusion::logical_expr::{Expr, Operator}; use datafusion::scalar::ScalarValue; -use iceberg::expr::{Predicate, Reference}; +use iceberg::expr::{BinaryExpression, Predicate, PredicateOperator, Reference, UnaryExpression}; use iceberg::spec::Datum; -pub struct ExprToPredicateVisitor { - stack: VecDeque>, +// A datafusion expression could be an Iceberg predicate, column, or literal. +enum TransformedResult { + Predicate(Predicate), + Column(Reference), + Literal(Datum), + NotTransformed, } -impl ExprToPredicateVisitor { - /// Create a new predicate conversion visitor. - pub fn new() -> Self { - Self { - stack: VecDeque::new(), + +enum OpTransformedResult { + Operator(PredicateOperator), + And, + Or, + NotTransformed, +} + +/// Converts DataFusion filters ([`Expr`]) to an iceberg [`Predicate`]. +/// If none of the filters could be converted, return `None` which adds no predicates to the scan operation. +/// If the conversion was successful, return the converted predicates combined with an AND operator. +pub fn convert_filters_to_predicate(filters: &[Expr]) -> Option { + filters + .iter() + .filter_map(convert_filter_to_predicate) + .reduce(Predicate::and) +} + +fn convert_filter_to_predicate(expr: &Expr) -> Option { + match to_iceberg_predicate(expr) { + TransformedResult::Predicate(predicate) => Some(predicate), + TransformedResult::Column(_) | TransformedResult::Literal(_) => { + unreachable!("Not a valid expression: {:?}", expr) } + _ => None, } - /// Get the predicate from the stack. - pub fn get_predicate(&self) -> Option { - self.stack - .iter() - .filter_map(|opt| opt.clone()) - .reduce(Predicate::and) +} + +fn to_iceberg_predicate(expr: &Expr) -> TransformedResult { + match expr { + Expr::BinaryExpr(binary) => { + let left = to_iceberg_predicate(&binary.left); + let right = to_iceberg_predicate(&binary.right); + let op = to_iceberg_operation(binary.op); + match op { + OpTransformedResult::Operator(op) => to_iceberg_binary_predicate(left, right, op), + OpTransformedResult::And => to_iceberg_and_predicate(left, right), + OpTransformedResult::Or => to_iceberg_or_predicate(left, right), + OpTransformedResult::NotTransformed => TransformedResult::NotTransformed, + } + } + Expr::Not(exp) => { + let expr = to_iceberg_predicate(exp); + match expr { + TransformedResult::Predicate(p) => TransformedResult::Predicate(!p), + _ => TransformedResult::NotTransformed, + } + } + Expr::Column(column) => TransformedResult::Column(Reference::new(column.name())), + Expr::Literal(literal) => match scalar_value_to_datum(literal) { + Some(data) => TransformedResult::Literal(data), + None => TransformedResult::NotTransformed, + }, + Expr::InList(inlist) => { + let mut datums = vec![]; + for expr in &inlist.list { + let p = to_iceberg_predicate(expr); + match p { + TransformedResult::Literal(l) => datums.push(l), + _ => return TransformedResult::NotTransformed, + } + } + + let expr = to_iceberg_predicate(&inlist.expr); + match expr { + TransformedResult::Column(r) => match inlist.negated { + false => TransformedResult::Predicate(r.is_in(datums)), + true => TransformedResult::Predicate(r.is_not_in(datums)), + }, + _ => TransformedResult::NotTransformed, + } + } + Expr::IsNull(expr) => { + let p = to_iceberg_predicate(expr); + match p { + TransformedResult::Column(r) => TransformedResult::Predicate(Predicate::Unary( + UnaryExpression::new(PredicateOperator::IsNull, r), + )), + _ => TransformedResult::NotTransformed, + } + } + Expr::IsNotNull(expr) => { + let p = to_iceberg_predicate(expr); + match p { + TransformedResult::Column(r) => TransformedResult::Predicate(Predicate::Unary( + UnaryExpression::new(PredicateOperator::NotNull, r), + )), + _ => TransformedResult::NotTransformed, + } + } + _ => TransformedResult::NotTransformed, } +} - /// Convert a column expression to an iceberg predicate. - fn convert_column_expr( - &self, - col: &Column, - op: &Operator, - lit: &ScalarValue, - ) -> Option { - let reference = Reference::new(col.name.clone()); - let datum = scalar_value_to_datum(lit)?; - Some(binary_op_to_predicate(reference, op, datum)) +fn to_iceberg_operation(op: Operator) -> OpTransformedResult { + match op { + Operator::Eq => OpTransformedResult::Operator(PredicateOperator::Eq), + Operator::NotEq => OpTransformedResult::Operator(PredicateOperator::NotEq), + Operator::Lt => OpTransformedResult::Operator(PredicateOperator::LessThan), + Operator::LtEq => OpTransformedResult::Operator(PredicateOperator::LessThanOrEq), + Operator::Gt => OpTransformedResult::Operator(PredicateOperator::GreaterThan), + Operator::GtEq => OpTransformedResult::Operator(PredicateOperator::GreaterThanOrEq), + // AND OR + Operator::And => OpTransformedResult::And, + Operator::Or => OpTransformedResult::Or, + // Others not supported + _ => OpTransformedResult::NotTransformed, } +} - /// Convert a compound expression to an iceberg predicate. - /// - /// The strategy is to support the following cases: - /// - if its an AND expression then the result will be the valid predicates, whether there are 2 or just 1 - /// - if its an OR expression then a predicate will be returned only if there are 2 valid predicates on both sides - fn convert_compound_expr(&self, valid_preds: &[Predicate], op: &Operator) -> Option { - let valid_preds_count = valid_preds.len(); - match (op, valid_preds_count) { - (Operator::And, 1) => valid_preds.first().cloned(), - (Operator::And, 2) => Some(Predicate::and( - valid_preds[0].clone(), - valid_preds[1].clone(), - )), - (Operator::Or, 2) => Some(Predicate::or( - valid_preds[0].clone(), - valid_preds[1].clone(), - )), - _ => None, +fn to_iceberg_and_predicate( + left: TransformedResult, + right: TransformedResult, +) -> TransformedResult { + match (left, right) { + (TransformedResult::Predicate(left), TransformedResult::Predicate(right)) => { + TransformedResult::Predicate(left.and(right)) } + (TransformedResult::Predicate(left), _) => TransformedResult::Predicate(left), + (_, TransformedResult::Predicate(right)) => TransformedResult::Predicate(right), + _ => TransformedResult::NotTransformed, } } -// Implement TreeNodeVisitor for ExprToPredicateVisitor -impl<'n> TreeNodeVisitor<'n> for ExprToPredicateVisitor { - type Node = Expr; - - fn f_down(&mut self, _node: &'n Expr) -> Result { - Ok(TreeNodeRecursion::Continue) +fn to_iceberg_or_predicate(left: TransformedResult, right: TransformedResult) -> TransformedResult { + match (left, right) { + (TransformedResult::Predicate(left), TransformedResult::Predicate(right)) => { + TransformedResult::Predicate(left.or(right)) + } + _ => TransformedResult::NotTransformed, } +} - fn f_up(&mut self, expr: &'n Expr) -> Result { - if let Expr::BinaryExpr(binary) = expr { - match (&*binary.left, &binary.op, &*binary.right) { - // process simple binary expressions, e.g. col > 1 - (Expr::Column(col), op, Expr::Literal(lit)) => { - let col_pred = self.convert_column_expr(col, op, lit); - self.stack.push_back(col_pred); - } - // // process reversed binary expressions, e.g. 1 < col - (Expr::Literal(lit), op, Expr::Column(col)) => { - let col_pred = op - .swap() - .and_then(|negated_op| self.convert_column_expr(col, &negated_op, lit)); - self.stack.push_back(col_pred); - } - // process compound expressions (involving logical operators. e.g., AND or OR and children) - (_left, op, _right) if op.is_logic_operator() => { - let right_pred = self.stack.pop_back().flatten(); - let left_pred = self.stack.pop_back().flatten(); - let children: Vec<_> = [left_pred, right_pred].into_iter().flatten().collect(); - let compound_pred = self.convert_compound_expr(&children, op); - self.stack.push_back(compound_pred); - } - _ => return Ok(TreeNodeRecursion::Continue), - } +fn to_iceberg_binary_predicate( + left: TransformedResult, + right: TransformedResult, + op: PredicateOperator, +) -> TransformedResult { + let (r, d, op) = match (left, right) { + (TransformedResult::NotTransformed, _) => return TransformedResult::NotTransformed, + (_, TransformedResult::NotTransformed) => return TransformedResult::NotTransformed, + (TransformedResult::Column(r), TransformedResult::Literal(d)) => (r, d, op), + (TransformedResult::Literal(d), TransformedResult::Column(r)) => { + (r, d, reverse_predicate_operator(op)) } - Ok(TreeNodeRecursion::Continue) + _ => return TransformedResult::NotTransformed, + }; + TransformedResult::Predicate(Predicate::Binary(BinaryExpression::new(op, r, d))) +} + +fn reverse_predicate_operator(op: PredicateOperator) -> PredicateOperator { + match op { + PredicateOperator::Eq => PredicateOperator::Eq, + PredicateOperator::NotEq => PredicateOperator::NotEq, + PredicateOperator::GreaterThan => PredicateOperator::LessThan, + PredicateOperator::GreaterThanOrEq => PredicateOperator::LessThanOrEq, + PredicateOperator::LessThan => PredicateOperator::GreaterThan, + PredicateOperator::LessThanOrEq => PredicateOperator::GreaterThanOrEq, + _ => unreachable!("Reverse {}", op), } } @@ -133,93 +209,113 @@ fn scalar_value_to_datum(value: &ScalarValue) -> Option { } } -/// convert the data fusion Exp to an iceberg [`Predicate`] -fn binary_op_to_predicate(reference: Reference, op: &Operator, datum: Datum) -> Predicate { - match op { - Operator::Eq => reference.equal_to(datum), - Operator::NotEq => reference.not_equal_to(datum), - Operator::Lt => reference.less_than(datum), - Operator::LtEq => reference.less_than_or_equal_to(datum), - Operator::Gt => reference.greater_than(datum), - Operator::GtEq => reference.greater_than_or_equal_to(datum), - _ => Predicate::AlwaysTrue, - } -} - #[cfg(test)] mod tests { - use std::collections::VecDeque; - use datafusion::arrow::datatypes::{DataType, Field, Schema}; - use datafusion::common::tree_node::TreeNode; use datafusion::common::DFSchema; - use datafusion::prelude::SessionContext; + use datafusion::logical_expr::utils::split_conjunction; + use datafusion::prelude::{Expr, SessionContext}; use iceberg::expr::{Predicate, Reference}; use iceberg::spec::Datum; - use super::ExprToPredicateVisitor; + use super::convert_filters_to_predicate; fn create_test_schema() -> DFSchema { let arrow_schema = Schema::new(vec![ - Field::new("foo", DataType::Int32, false), - Field::new("bar", DataType::Utf8, false), + Field::new("foo", DataType::Int32, true), + Field::new("bar", DataType::Utf8, true), ]); DFSchema::try_from_qualified_schema("my_table", &arrow_schema).unwrap() } - #[test] - fn test_predicate_conversion_with_single_condition() { - let sql = "foo > 1"; + fn convert_to_iceberg_predicate(sql: &str) -> Option { let df_schema = create_test_schema(); let expr = SessionContext::new() .parse_sql_expr(sql, &df_schema) .unwrap(); - let mut visitor = ExprToPredicateVisitor::new(); - expr.visit(&mut visitor).unwrap(); - let predicate = visitor.get_predicate().unwrap(); + let exprs: Vec = split_conjunction(&expr).into_iter().cloned().collect(); + convert_filters_to_predicate(&exprs[..]) + } + + #[test] + fn test_predicate_conversion_with_single_condition() { + let predicate = convert_to_iceberg_predicate("foo = 1").unwrap(); + assert_eq!(predicate, Reference::new("foo").equal_to(Datum::long(1))); + + let predicate = convert_to_iceberg_predicate("foo != 1").unwrap(); + assert_eq!( + predicate, + Reference::new("foo").not_equal_to(Datum::long(1)) + ); + + let predicate = convert_to_iceberg_predicate("foo > 1").unwrap(); assert_eq!( predicate, Reference::new("foo").greater_than(Datum::long(1)) ); + + let predicate = convert_to_iceberg_predicate("foo >= 1").unwrap(); + assert_eq!( + predicate, + Reference::new("foo").greater_than_or_equal_to(Datum::long(1)) + ); + + let predicate = convert_to_iceberg_predicate("foo < 1").unwrap(); + assert_eq!(predicate, Reference::new("foo").less_than(Datum::long(1))); + + let predicate = convert_to_iceberg_predicate("foo <= 1").unwrap(); + assert_eq!( + predicate, + Reference::new("foo").less_than_or_equal_to(Datum::long(1)) + ); + + let predicate = convert_to_iceberg_predicate("foo is null").unwrap(); + assert_eq!(predicate, Reference::new("foo").is_null()); + + let predicate = convert_to_iceberg_predicate("foo is not null").unwrap(); + assert_eq!(predicate, Reference::new("foo").is_not_null()); + + let predicate = convert_to_iceberg_predicate("foo in (5, 6)").unwrap(); + assert_eq!( + predicate, + Reference::new("foo").is_in([Datum::long(5), Datum::long(6)]) + ); + + let predicate = convert_to_iceberg_predicate("foo not in (5, 6)").unwrap(); + assert_eq!( + predicate, + Reference::new("foo").is_not_in([Datum::long(5), Datum::long(6)]) + ); + + let predicate = convert_to_iceberg_predicate("not foo = 1").unwrap(); + assert_eq!(predicate, !Reference::new("foo").equal_to(Datum::long(1))); } + #[test] fn test_predicate_conversion_with_single_unsupported_condition() { - let sql = "foo is null"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let mut visitor = ExprToPredicateVisitor::new(); - expr.visit(&mut visitor).unwrap(); - let predicate = visitor.get_predicate(); + let predicate = convert_to_iceberg_predicate("foo + 1 = 1"); + assert_eq!(predicate, None); + + let predicate = convert_to_iceberg_predicate("length(bar) = 1"); + assert_eq!(predicate, None); + + let predicate = convert_to_iceberg_predicate("foo in (1, 2, foo)"); assert_eq!(predicate, None); } #[test] fn test_predicate_conversion_with_single_condition_rev() { - let sql = "1 < foo"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let mut visitor = ExprToPredicateVisitor::new(); - expr.visit(&mut visitor).unwrap(); - let predicate = visitor.get_predicate().unwrap(); + let predicate = convert_to_iceberg_predicate("1 < foo").unwrap(); assert_eq!( predicate, Reference::new("foo").greater_than(Datum::long(1)) ); } + #[test] fn test_predicate_conversion_with_and_condition() { let sql = "foo > 1 and bar = 'test'"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let mut visitor = ExprToPredicateVisitor::new(); - expr.visit(&mut visitor).unwrap(); - let predicate = visitor.get_predicate().unwrap(); + let predicate = convert_to_iceberg_predicate(sql).unwrap(); let expected_predicate = Predicate::and( Reference::new("foo").greater_than(Datum::long(1)), Reference::new("bar").equal_to(Datum::string("test")), @@ -229,55 +325,42 @@ mod tests { #[test] fn test_predicate_conversion_with_and_condition_unsupported() { - let sql = "foo > 1 and bar is not null"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let mut visitor = ExprToPredicateVisitor::new(); - expr.visit(&mut visitor).unwrap(); - let predicate = visitor.get_predicate().unwrap(); + let sql = "foo > 1 and length(bar) = 1"; + let predicate = convert_to_iceberg_predicate(sql).unwrap(); let expected_predicate = Reference::new("foo").greater_than(Datum::long(1)); assert_eq!(predicate, expected_predicate); } + #[test] fn test_predicate_conversion_with_and_condition_both_unsupported() { - let sql = "foo in (1, 2, 3) and bar is not null"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let mut visitor = ExprToPredicateVisitor::new(); - expr.visit(&mut visitor).unwrap(); - let predicate = visitor.get_predicate(); - let expected_predicate = None; - assert_eq!(predicate, expected_predicate); + let sql = "foo in (1, 2, foo) and length(bar) = 1"; + let predicate = convert_to_iceberg_predicate(sql); + assert_eq!(predicate, None); } #[test] fn test_predicate_conversion_with_or_condition_unsupported() { - let sql = "foo > 1 or bar is not null"; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let mut visitor = ExprToPredicateVisitor::new(); - expr.visit(&mut visitor).unwrap(); - let predicate = visitor.get_predicate(); - let expected_predicate = None; + let sql = "foo > 1 or length(bar) = 1"; + let predicate = convert_to_iceberg_predicate(sql); + assert_eq!(predicate, None); + } + + #[test] + fn test_predicate_conversion_with_or_condition_supported() { + let sql = "foo > 1 or bar = 'test'"; + let predicate = convert_to_iceberg_predicate(sql).unwrap(); + let expected_predicate = Predicate::or( + Reference::new("foo").greater_than(Datum::long(1)), + Reference::new("bar").equal_to(Datum::string("test")), + ); assert_eq!(predicate, expected_predicate); } #[test] fn test_predicate_conversion_with_complex_binary_expr() { let sql = "(foo > 1 and bar = 'test') or foo < 0 "; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let mut visitor = ExprToPredicateVisitor::new(); - expr.visit(&mut visitor).unwrap(); - let predicate = visitor.get_predicate().unwrap(); + let predicate = convert_to_iceberg_predicate(sql).unwrap(); + let inner_predicate = Predicate::and( Reference::new("foo").greater_than(Datum::long(1)), Reference::new("bar").equal_to(Datum::string("test")), @@ -290,46 +373,23 @@ mod tests { } #[test] - fn test_predicate_conversion_with_complex_binary_expr_unsupported() { - let sql = "(foo > 1 or bar in ('test', 'test2')) and foo < 0 "; - let df_schema = create_test_schema(); - let expr = SessionContext::new() - .parse_sql_expr(sql, &df_schema) - .unwrap(); - let mut visitor = ExprToPredicateVisitor::new(); - expr.visit(&mut visitor).unwrap(); - let predicate = visitor.get_predicate().unwrap(); - let expected_predicate = Reference::new("foo").less_than(Datum::long(0)); - assert_eq!(predicate, expected_predicate); - } + fn test_predicate_conversion_with_one_and_expr_supported() { + let sql = "(foo > 1 and length(bar) = 1 ) or foo < 0 "; + let predicate = convert_to_iceberg_predicate(sql).unwrap(); - #[test] - // test the get result method - fn test_get_result_multiple() { - let predicates = vec![ - Some(Reference::new("foo").greater_than(Datum::long(1))), - None, - Some(Reference::new("bar").equal_to(Datum::string("test"))), - ]; - let stack = VecDeque::from(predicates); - let visitor = ExprToPredicateVisitor { stack }; - assert_eq!( - visitor.get_predicate(), - Some(Predicate::and( - Reference::new("foo").greater_than(Datum::long(1)), - Reference::new("bar").equal_to(Datum::string("test")), - )) + let inner_predicate = Reference::new("foo").greater_than(Datum::long(1)); + let expected_predicate = Predicate::or( + inner_predicate, + Reference::new("foo").less_than(Datum::long(0)), ); + assert_eq!(predicate, expected_predicate); } #[test] - fn test_get_result_single() { - let predicates = vec![Some(Reference::new("foo").greater_than(Datum::long(1)))]; - let stack = VecDeque::from(predicates); - let visitor = ExprToPredicateVisitor { stack }; - assert_eq!( - visitor.get_predicate(), - Some(Reference::new("foo").greater_than(Datum::long(1))) - ); + fn test_predicate_conversion_with_complex_binary_expr_unsupported() { + let sql = "(foo > 1 or length(bar) = 1 ) and foo < 0 "; + let predicate = convert_to_iceberg_predicate(sql).unwrap(); + let expected_predicate = Reference::new("foo").less_than(Datum::long(0)); + assert_eq!(predicate, expected_predicate); } } diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index c53ce76d50..59cf099765 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -22,7 +22,6 @@ use std::vec; use datafusion::arrow::array::RecordBatch; use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; -use datafusion::common::tree_node::TreeNode; use datafusion::error::Result as DFResult; use datafusion::execution::{SendableRecordBatchStream, TaskContext}; use datafusion::physical_expr::EquivalenceProperties; @@ -35,7 +34,7 @@ use futures::{Stream, TryStreamExt}; use iceberg::expr::Predicate; use iceberg::table::Table; -use crate::physical_plan::expr_to_predicate::ExprToPredicateVisitor; +use super::expr_to_predicate::convert_filters_to_predicate; use crate::to_datafusion_error; /// Manages the scanning process of an Iceberg [`Table`], encapsulating the @@ -140,10 +139,13 @@ impl DisplayAs for IcebergTableScan { ) -> std::fmt::Result { write!( f, - "IcebergTableScan projection:[{}]", + "IcebergTableScan projection:[{}] predicate:[{}]", self.projection .clone() - .map_or(String::new(), |v| v.join(",")) + .map_or(String::new(), |v| v.join(",")), + self.predicates + .clone() + .map_or(String::from(""), |p| format!("{}", p)) ) } } @@ -175,22 +177,6 @@ async fn get_batch_stream( Ok(Box::pin(stream)) } -/// Converts DataFusion filters ([`Expr`]) to an iceberg [`Predicate`]. -/// If none of the filters could be converted, return `None` which adds no predicates to the scan operation. -/// If the conversion was successful, return the converted predicates combined with an AND operator. -fn convert_filters_to_predicate(filters: &[Expr]) -> Option { - filters - .iter() - .filter_map(|expr| { - let mut visitor = ExprToPredicateVisitor::new(); - if expr.visit(&mut visitor).is_ok() { - visitor.get_predicate() - } else { - None - } - }) - .reduce(Predicate::and) -} fn get_column_names( schema: ArrowSchemaRef, projection: Option<&Vec>, diff --git a/crates/integrations/datafusion/src/table.rs b/crates/integrations/datafusion/src/table.rs index 2797e12d67..bb24713aa0 100644 --- a/crates/integrations/datafusion/src/table.rs +++ b/crates/integrations/datafusion/src/table.rs @@ -23,7 +23,7 @@ use datafusion::arrow::datatypes::SchemaRef as ArrowSchemaRef; use datafusion::catalog::Session; use datafusion::datasource::{TableProvider, TableType}; use datafusion::error::Result as DFResult; -use datafusion::logical_expr::{BinaryExpr, Expr, TableProviderFilterPushDown}; +use datafusion::logical_expr::{Expr, TableProviderFilterPushDown}; use datafusion::physical_plan::ExecutionPlan; use iceberg::arrow::schema_to_arrow_schema; use iceberg::table::Table; @@ -99,15 +99,8 @@ impl TableProvider for IcebergTableProvider { filters: &[&Expr], ) -> std::result::Result, datafusion::error::DataFusionError> { - let filter_support = filters - .iter() - .map(|e| match e { - Expr::BinaryExpr(BinaryExpr { .. }) => TableProviderFilterPushDown::Inexact, - _ => TableProviderFilterPushDown::Unsupported, - }) - .collect::>(); - - Ok(filter_support) + // Push down all filters, as a single source of truth, the scanner will drop the filters which couldn't be push down + Ok(vec![TableProviderFilterPushDown::Inexact; filters.len()]) } } diff --git a/crates/integrations/datafusion/tests/integration_datafusion_test.rs b/crates/integrations/datafusion/tests/integration_datafusion_test.rs index d6e22d0440..d320c8ef08 100644 --- a/crates/integrations/datafusion/tests/integration_datafusion_test.rs +++ b/crates/integrations/datafusion/tests/integration_datafusion_test.rs @@ -204,10 +204,7 @@ async fn test_table_projection() -> Result<()> { .unwrap(); assert_eq!(2, s.len()); // the first row is logical_plan, the second row is physical_plan - assert_eq!( - "IcebergTableScan projection:[foo1,foo2,foo3]", - s.value(1).trim() - ); + assert!(s.value(1).contains("projection:[foo1,foo2,foo3]")); // datafusion doesn't support query foo3.s_foo1, use foo3 instead let records = table_df @@ -226,7 +223,54 @@ async fn test_table_projection() -> Result<()> { .downcast_ref::() .unwrap(); assert_eq!(2, s.len()); - assert_eq!("IcebergTableScan projection:[foo1,foo3]", s.value(1).trim()); + assert!(s + .value(1) + .contains("IcebergTableScan projection:[foo1,foo3]")); Ok(()) } + +#[tokio::test] +async fn test_table_predict_pushdown() -> Result<()> { + let iceberg_catalog = get_iceberg_catalog(); + let namespace = NamespaceIdent::new("ns".to_string()); + set_test_namespace(&iceberg_catalog, &namespace).await?; + + let schema = Schema::builder() + .with_schema_id(0) + .with_fields(vec![ + NestedField::required(1, "foo", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(2, "bar", Type::Primitive(PrimitiveType::String)).into(), + ]) + .build()?; + let creation = get_table_creation(temp_path(), "t1", Some(schema))?; + iceberg_catalog.create_table(&namespace, creation).await?; + + let client = Arc::new(iceberg_catalog); + let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); + + let ctx = SessionContext::new(); + ctx.register_catalog("catalog", catalog); + let records = ctx + .sql("select * from catalog.ns.t1 where (foo > 1 and length(bar) = 1 ) or bar is null") + .await + .unwrap() + .explain(false, false) + .unwrap() + .collect() + .await + .unwrap(); + assert_eq!(1, records.len()); + let record = &records[0]; + // the first column is plan_type, the second column plan string. + let s = record + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(2, s.len()); + // the first row is logical_plan, the second row is physical_plan + let expected = "predicate:[(foo > 1) OR (bar IS NULL)]"; + assert!(s.value(1).trim().contains(expected)); + Ok(()) +}