diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs index 6770fccd7dda5..88b4d95ec82cb 100644 --- a/datafusion/expr/src/binary_rule.rs +++ b/datafusion/expr/src/binary_rule.rs @@ -150,7 +150,11 @@ fn bitwise_coercion(left_type: &DataType, right_type: &DataType) -> Option Option { +/// Get the coerced data type for `eq` or `not eq` operation +pub fn comparison_eq_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { // can't compare dictionaries directly due to // https://github.com/apache/arrow-rs/issues/1201 if lhs_type == rhs_type && !is_dictionary(lhs_type) { diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 7d15dd5e99658..346eea472c768 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -33,11 +33,14 @@ use arrow::{ record_batch::RecordBatch, }; +use crate::expressions::try_cast; use crate::{expressions, PhysicalExpr}; use arrow::array::*; use arrow::buffer::{Buffer, MutableBuffer}; use datafusion_common::ScalarValue; +use datafusion_common::ScalarValue::Decimal128; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::binary_rule::comparison_eq_coercion; use datafusion_expr::ColumnarValue; /// Size at which to use a Set rather than Vec for `IN` / `NOT IN` @@ -82,6 +85,8 @@ pub struct InListExpr { /// InSet #[derive(Debug)] pub struct InSet { + // TODO: optimization: In the `IN` or `NOT IN` we don't need to consider the NULL value + // The data type is same, we can use set: HashSet set: HashSet, } @@ -160,6 +165,7 @@ macro_rules! make_contains_primitive { ColumnarValue::Scalar(s) => match s { ScalarValue::$SCALAR_VALUE(Some(v)) => Some(*v), ScalarValue::$SCALAR_VALUE(None) => None, + // TODO this is bug, for primitive the expr list should be cast to the same data type ScalarValue::Utf8(None) => None, datatype => unimplemented!("Unexpected type {} for InList", datatype), }, @@ -300,6 +306,90 @@ fn cast_static_filter_to_set(list: &[Arc]) -> HashSet, + negated: bool, +) -> BooleanArray { + let contains_null = list + .iter() + .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null())); + let values = list + .iter() + .flat_map(|v| match v { + ColumnarValue::Scalar(s) => match s { + Decimal128(v128op, _, _) => *v128op, + _ => { + unreachable!( + "InList can't reach other data type for decimal data type." + ) + } + }, + ColumnarValue::Array(_) => { + unimplemented!("InList does not yet support nested columns.") + } + }) + .collect::>(); + + if !negated { + // In + array + .iter() + .map(|v| v.map(|v128| values.contains(&v128))) + .collect::() + } else { + // Not in + if contains_null { + // If the expr is NOT IN and the list contains NULL value + // All the result must be NONE + BooleanArray::from(vec![None; array.len()]) + } else { + array + .iter() + .map(|v| v.map(|v128| !values.contains(&v128))) + .collect::() + } + } +} + +fn make_set_contains_decimal( + array: &DecimalArray, + set: &HashSet, + negated: bool, +) -> BooleanArray { + let contains_null = set.iter().any(|v| v.is_null()); + let native_array = set + .iter() + .flat_map(|v| match v { + Decimal128(v128op, _, _) => *v128op, + _ => { + unreachable!("InList can't reach other data type for decimal data type.") + } + }) + .collect::>(); + let native_set: HashSet = HashSet::from_iter(native_array); + + if !negated { + // In + array + .iter() + .map(|v| v.map(|v128| native_set.contains(&v128))) + .collect::() + } else { + // Not in + if contains_null { + // If the expr is NOT IN and the list contains NULL value + // All the result must be NONE + BooleanArray::from(vec![None; array.len()]) + } else { + array + .iter() + .map(|v| v.map(|v128| !native_set.contains(&v128))) + .collect::() + } + } +} + impl InListExpr { /// Create a new InList expression pub fn new( @@ -504,6 +594,11 @@ impl PhysicalExpr for InListExpr { .unwrap(); set_contains_with_negated!(array, set, self.negated) } + DataType::Decimal(_, _) => { + let array = array.as_any().downcast_ref::().unwrap(); + let result = make_set_contains_decimal(array, set, self.negated); + Ok(ColumnarValue::Array(Arc::new(result))) + } datatype => Result::Err(DataFusionError::NotImplemented(format!( "InSet does not support datatype {:?}.", datatype @@ -631,6 +726,16 @@ impl PhysicalExpr for InListExpr { let null_array = new_null_array(&DataType::Boolean, array.len()); Ok(ColumnarValue::Array(Arc::new(null_array))) } + DataType::Decimal(_, _) => { + let decimal_array = + array.as_any().downcast_ref::().unwrap(); + let result = make_list_contains_decimal( + decimal_array, + list_values, + self.negated, + ); + Ok(ColumnarValue::Array(Arc::new(result))) + } datatype => Result::Err(DataFusionError::NotImplemented(format!( "InList does not support datatype {:?}.", datatype @@ -640,13 +745,63 @@ impl PhysicalExpr for InListExpr { } } +type InListCastResult = (Arc, Vec>); + /// Creates a unary expression InList pub fn in_list( expr: Arc, list: Vec>, negated: &bool, + input_schema: &Schema, ) -> Result> { - Ok(Arc::new(InListExpr::new(expr, list, *negated))) + let (cast_expr, cast_list) = in_list_cast(expr, list, input_schema)?; + Ok(Arc::new(InListExpr::new(cast_expr, cast_list, *negated))) +} + +fn in_list_cast( + expr: Arc, + list: Vec>, + input_schema: &Schema, +) -> Result { + let expr_type = &expr.data_type(input_schema)?; + let list_types: Vec = list + .iter() + .map(|list_expr| list_expr.data_type(input_schema).unwrap()) + .collect(); + // TODO in the arrow-rs, should support NULL type to Decimal Data type + // TODO support in the arrow-rs, NULL value cast to Decimal Value + // https://github.com/apache/arrow-datafusion/issues/2759 + let result_type = get_coerce_type(expr_type, &list_types); + match result_type { + None => Err(DataFusionError::Internal(format!( + "In expr can find the coerced type for {:?} in {:?}", + expr_type, list_types + ))), + Some(data_type) => { + // find the coerced type + let cast_expr = try_cast(expr, input_schema, data_type.clone())?; + let cast_list_expr = list + .into_iter() + .map(|list_expr| { + try_cast(list_expr, input_schema, data_type.clone()).unwrap() + }) + .collect(); + Ok((cast_expr, cast_list_expr)) + } + } +} + +fn get_coerce_type(expr_type: &DataType, list_type: &[DataType]) -> Option { + // get the equal coerced data type + list_type + .iter() + .fold(Some(expr_type.clone()), |left, right_type| { + match left { + None => None, + // TODO refactor a framework to do the data type coercion + Some(left_type) => comparison_eq_coercion(&left_type, right_type), + } + }) } #[cfg(test)] @@ -659,8 +814,8 @@ mod tests { // applies the in_list expr to an input batch and list macro_rules! in_list { - ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr) => {{ - let expr = in_list($COL, $LIST, $NEGATED).unwrap(); + ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr, $SCHEMA:expr) => {{ + let expr = in_list($COL, $LIST, $NEGATED, $SCHEMA).unwrap(); let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows()); let result = result .as_any() @@ -676,7 +831,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); let a = StringArray::from(vec![Some("a"), Some("d"), None]); let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; // expression: "a in ("a", "b")" let list = vec![ @@ -688,7 +843,8 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone() + col_a.clone(), + &schema ); // expression: "a not in ("a", "b")" @@ -701,7 +857,8 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone() + col_a.clone(), + &schema ); // expression: "a not in ("a", "b")" @@ -715,7 +872,8 @@ mod tests { list, &false, vec![Some(true), None, None], - col_a.clone() + col_a.clone(), + &schema ); // expression: "a not in ("a", "b")" @@ -729,7 +887,8 @@ mod tests { list, &true, vec![Some(false), None, None], - col_a.clone() + col_a.clone(), + &schema ); Ok(()) @@ -740,7 +899,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]); let a = Int64Array::from(vec![Some(0), Some(2), None]); let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; // expression: "a in (0, 1)" let list = vec![ @@ -752,7 +911,8 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone() + col_a.clone(), + &schema ); // expression: "a not in (0, 1)" @@ -765,35 +925,38 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone() + col_a.clone(), + &schema ); // expression: "a in (0, 1, NULL)" let list = vec![ lit(ScalarValue::Int64(Some(0))), lit(ScalarValue::Int64(Some(1))), - lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Null), ]; in_list!( batch, list, &false, vec![Some(true), None, None], - col_a.clone() + col_a.clone(), + &schema ); // expression: "a not in (0, 1, NULL)" let list = vec![ lit(ScalarValue::Int64(Some(0))), lit(ScalarValue::Int64(Some(1))), - lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Null), ]; in_list!( batch, list, &true, vec![Some(false), None, None], - col_a.clone() + col_a.clone(), + &schema ); Ok(()) @@ -804,7 +967,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]); let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; // expression: "a in (0.0, 0.2)" let list = vec![ @@ -816,7 +979,8 @@ mod tests { list, &false, vec![Some(true), Some(false), None], - col_a.clone() + col_a.clone(), + &schema ); // expression: "a not in (0.0, 0.2)" @@ -829,35 +993,38 @@ mod tests { list, &true, vec![Some(false), Some(true), None], - col_a.clone() + col_a.clone(), + &schema ); // expression: "a in (0.0, 0.2, NULL)" let list = vec![ lit(ScalarValue::Float64(Some(0.0))), lit(ScalarValue::Float64(Some(0.1))), - lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Null), ]; in_list!( batch, list, &false, vec![Some(true), None, None], - col_a.clone() + col_a.clone(), + &schema ); // expression: "a not in (0.0, 0.2, NULL)" let list = vec![ lit(ScalarValue::Float64(Some(0.0))), lit(ScalarValue::Float64(Some(0.1))), - lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Null), ]; in_list!( batch, list, &true, vec![Some(false), None, None], - col_a.clone() + col_a.clone(), + &schema ); Ok(()) @@ -868,29 +1035,157 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]); let a = BooleanArray::from(vec![Some(true), None]); let col_a = col("a", &schema)?; - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; // expression: "a in (true)" let list = vec![lit(ScalarValue::Boolean(Some(true)))]; - in_list!(batch, list, &false, vec![Some(true), None], col_a.clone()); + in_list!( + batch, + list, + &false, + vec![Some(true), None], + col_a.clone(), + &schema + ); // expression: "a not in (true)" let list = vec![lit(ScalarValue::Boolean(Some(true)))]; - in_list!(batch, list, &true, vec![Some(false), None], col_a.clone()); + in_list!( + batch, + list, + &true, + vec![Some(false), None], + col_a.clone(), + &schema + ); // expression: "a in (true, NULL)" let list = vec![ lit(ScalarValue::Boolean(Some(true))), - lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Null), ]; - in_list!(batch, list, &false, vec![Some(true), None], col_a.clone()); + in_list!( + batch, + list, + &false, + vec![Some(true), None], + col_a.clone(), + &schema + ); // expression: "a not in (true, NULL)" let list = vec![ lit(ScalarValue::Boolean(Some(true))), - lit(ScalarValue::Utf8(None)), + lit(ScalarValue::Null), + ]; + in_list!( + batch, + list, + &true, + vec![Some(false), None], + col_a.clone(), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_decimal() -> Result<()> { + // Now, we can check the NULL type + let schema = Schema::new(vec![Field::new("a", DataType::Decimal(13, 4), true)]); + let array = vec![Some(100_0000_i128), None, Some(200_5000_i128)] + .into_iter() + .collect::(); + let array = array.with_precision_and_scale(13, 4).unwrap(); + let col_a = col("a", &schema)?; + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?; + + // expression: "a in (100,200), the data type of list is INT32 + let list = vec![ + lit(ScalarValue::Int32(Some(100))), + lit(ScalarValue::Int32(Some(200))), + ]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, Some(false)], + col_a.clone(), + &schema + ); + // expression: "a not in (100,200) + let list = vec![ + lit(ScalarValue::Int32(Some(100))), + lit(ScalarValue::Int32(Some(200))), + ]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, Some(true)], + col_a.clone(), + &schema + ); + + // expression: "a in (200,NULL), the data type of list is INT32 AND NULL + // TODO support: NULL data type to decimal in arrow-rs + // let list = vec![lit(ScalarValue::Int32(Some(100))), lit(ScalarValue::Null)]; + // in_list!(batch, list, &false, vec![Some(true), None, Some(false)], col_a.clone(), &schema); + + // expression: "a in (200.5, 100), the data type of list is FLOAT32 and INT32 + let list = vec![ + lit(ScalarValue::Float32(Some(200.50f32))), + lit(ScalarValue::Int32(Some(100))), ]; - in_list!(batch, list, &true, vec![Some(false), None], col_a.clone()); + in_list!( + batch, + list, + &false, + vec![Some(true), None, Some(true)], + col_a.clone(), + &schema + ); + + // expression: "a not in (200.5, 100), the data type of list is FLOAT32 and INT32 + let list = vec![ + lit(ScalarValue::Float32(Some(200.50f32))), + lit(ScalarValue::Int32(Some(101))), + ]; + in_list!( + batch, + list, + &true, + vec![Some(true), None, Some(false)], + col_a.clone(), + &schema + ); + + // test the optimization: set + // expression: "a in (99..300), the data type of list is INT32 + let list = (99..300) + .into_iter() + .map(|v| lit(ScalarValue::Int32(Some(v)))) + .collect::>(); + + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), None, Some(false)], + col_a.clone(), + &schema + ); + + in_list!( + batch, + list, + &true, + vec![Some(false), None, Some(true)], + col_a.clone(), + &schema + ); Ok(()) } diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 92580fce01091..26583cd28a16a 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -294,6 +294,9 @@ pub fn create_physical_expr( input_schema, execution_props, ), + // TODO refactor the logic of coercion the data type + // data type in the `list expr` may be conflict with `value expr`, + // we should not just compare data type between `value expr` with each `list expr`. _ => { let list_expr = create_physical_expr( expr, @@ -310,6 +313,8 @@ pub fn create_physical_expr( &list_expr_data_type, &value_expr_data_type, ) { + // TODO: Can't cast from list type to value type directly + // We should use the coercion rule to get the common data type expressions::cast( list_expr, input_schema, @@ -325,7 +330,7 @@ pub fn create_physical_expr( }) .collect::>>()?; - expressions::in_list(value_expr, list_exprs, negated) + expressions::in_list(value_expr, list_exprs, negated, input_schema) } }, other => Err(DataFusionError::NotImplemented(format!(