From 69d168e37a7b3551bb8a2e1307435d37f80aa109 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Sat, 31 Aug 2024 11:14:56 +0800 Subject: [PATCH 1/4] feat: scalar regex match physical expr --- .../physical-expr/src/expressions/mod.rs | 2 + .../src/expressions/scalar_regex_match.rs | 670 ++++++++++++++++++ datafusion/physical-expr/src/planner.rs | 28 +- datafusion/proto/proto/datafusion.proto | 9 + datafusion/proto/src/generated/pbjson.rs | 157 ++++ datafusion/proto/src/generated/prost.rs | 18 +- .../proto/src/physical_plan/from_proto.rs | 22 +- .../proto/src/physical_plan/to_proto.rs | 21 +- 8 files changed, 923 insertions(+), 4 deletions(-) create mode 100644 datafusion/physical-expr/src/expressions/scalar_regex_match.rs diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 7d71bd9ff17b..a5a59399191f 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -30,6 +30,7 @@ mod literal; mod negative; mod no_op; mod not; +mod scalar_regex_match; mod try_cast; mod unknown_column; @@ -51,5 +52,6 @@ pub use literal::{lit, Literal}; pub use negative::{negative, NegativeExpr}; pub use no_op::NoOp; pub use not::{not, NotExpr}; +pub use scalar_regex_match::{scalar_regex_match, ScalarRegexMatchExpr}; pub use try_cast::{try_cast, TryCastExpr}; pub use unknown_column::UnKnownColumn; diff --git a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs new file mode 100644 index 000000000000..badb00659576 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs @@ -0,0 +1,670 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::Literal; +use arrow::array::ArrayData; +use arrow_array::{ + Array, ArrayAccessor, BooleanArray, LargeStringArray, StringArray, StringViewArray, +}; +use arrow_buffer::BooleanBufferBuilder; +use arrow_schema::{DataType, Schema}; +use datafusion_common::ScalarValue; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr_common::physical_expr::{down_cast_any_ref, PhysicalExpr}; +use regex::Regex; +use std::{any::Any, hash::Hash, sync::Arc}; + +/// ScalarRegexMatchExpr +/// Only used when evaluating regexp matching with literal pattern. +/// Example regex expression: c1 ~ '^a' / c1 !~ '^a' / c1 ~* '^a' / c1 !~* '^a'. +/// Literal regexp pattern will be compiled once and cached to be reused in execution. +/// It's will save compile time of pre execution and speed up execution. +#[derive(Clone)] +pub struct ScalarRegexMatchExpr { + negated: bool, + case_insensitive: bool, + expr: Arc, + pattern: Arc, + compiled: Option, +} + +impl ScalarRegexMatchExpr { + pub fn new( + negated: bool, + case_insensitive: bool, + expr: Arc, + pattern: Arc, + ) -> Self { + let mut res = Self { + negated, + case_insensitive, + expr, + pattern, + compiled: None, + }; + res.compile().unwrap(); + res + } + + /// Is negated + pub fn negated(&self) -> bool { + self.negated + } + + /// Is case insensitive + pub fn case_insensitive(&self) -> bool { + self.case_insensitive + } + + /// Input expression + pub fn expr(&self) -> &Arc { + &self.expr + } + + /// Pattern expression + pub fn pattern(&self) -> &Arc { + &self.pattern + } + + /// Compile regex pattern + fn compile(&mut self) -> datafusion_common::Result<()> { + let scalar_pattern = + self.pattern + .as_any() + .downcast_ref::() + .and_then(|pattern| match pattern.value() { + ScalarValue::Null + | ScalarValue::Utf8(None) + | ScalarValue::Utf8View(None) + | ScalarValue::LargeUtf8(None) => Some(None), + ScalarValue::Utf8(Some(pattern)) + | ScalarValue::Utf8View(Some(pattern)) + | ScalarValue::LargeUtf8(Some(pattern)) => { + let mut pattern = pattern.to_string(); + if self.case_insensitive { + pattern = format!("(?i){}", pattern); + } + Some(Some(pattern)) + } + _ => None, + }); + match scalar_pattern { + Some(Some(scalar_pattern)) => Regex::new(scalar_pattern.as_str()) + .map(|compiled| { + self.compiled = Some(compiled); + }) + .map_err(|err| { + datafusion_common::DataFusionError::Internal(format!( + "Failed to compile regex: {}", + err + )) + }), + Some(None) => { + self.compiled = None; + Ok(()) + } + None => Err(datafusion_common::DataFusionError::Internal(format!( + "Regex pattern({}) isn't literal string", + self.pattern + ))), + } + } + + /// Operator name + fn op_name(&self) -> &str { + match (self.negated, self.case_insensitive) { + (false, false) => "MATCH", + (true, false) => "NOT MATCH", + (false, true) => "IMATCH", + (true, true) => "NOT IMATCH", + } + } +} + +impl ScalarRegexMatchExpr { + /// Evaluate the scalar regex match expression match array value + fn evaluate_array( + &self, + array: &Arc, + ) -> datafusion_common::Result { + macro_rules! downcast_string_array { + ($ARRAY:expr, $ARRAY_TYPE:ident, $ERR_MSG:expr) => { + &($ARRAY + .as_any() + .downcast_ref::<$ARRAY_TYPE>() + .expect($ERR_MSG)) + }; + } + match array.data_type() { + DataType::Null => { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))) + }, + DataType::Utf8 => array_regexp_match( + downcast_string_array!(array, StringArray, "Failed to downcast StringArray"), + self.compiled.as_ref().unwrap(), + self.negated, + ), + DataType::Utf8View => array_regexp_match( + downcast_string_array!(array, StringViewArray, "Failed to downcast StringViewArray"), + self.compiled.as_ref().unwrap(), + self.negated, + ), + DataType::LargeUtf8 => array_regexp_match( + downcast_string_array!(array, LargeStringArray, "Failed to downcast LargeStringArray"), + self.compiled.as_ref().unwrap(), + self.negated, + ), + other=> datafusion_common::internal_err!( + "Data type {:?} not supported for ScalarRegexMatchExpr, expect Utf8|Utf8View|LargeUtf8", other + ), + } + } + + /// Evaluate the scalar regex match expression match scalar value + fn evaluate_scalar( + &self, + scalar: &ScalarValue, + ) -> datafusion_common::Result { + match scalar { + ScalarValue::Null + | ScalarValue::Utf8(None) + | ScalarValue::Utf8View(None) + | ScalarValue::LargeUtf8(None) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))), + ScalarValue::Utf8(Some(scalar)) + | ScalarValue::Utf8View(Some(scalar)) + | ScalarValue::LargeUtf8(Some(scalar)) => { + let mut result = self.compiled.as_ref().unwrap().is_match(scalar); + if self.negated { + result = !result; + } + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(result)))) + }, + other=> datafusion_common::internal_err!( + "Data type {:?} not supported for ScalarRegexMatchExpr, expect Utf8|Utf8View|LargeUtf8", other + ), + } + } +} + +impl std::hash::Hash for ScalarRegexMatchExpr { + fn hash(&self, state: &mut H) { + self.negated.hash(state); + self.case_insensitive.hash(state); + self.expr.hash(state); + self.pattern.hash(state); + } +} + +impl std::cmp::PartialEq for ScalarRegexMatchExpr { + fn eq(&self, other: &Self) -> bool { + self.negated.eq(&other.negated) + && self.case_insensitive.eq(&self.case_insensitive) + && self.expr.eq(&other.expr) + && self.pattern.eq(&other.pattern) + } +} + +impl std::fmt::Debug for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ScalarRegexMatchExpr") + .field("negated", &self.negated) + .field("case_insensitive", &self.case_insensitive) + .field("expr", &self.expr) + .field("pattern", &self.pattern) + .finish() + } +} + +impl std::fmt::Display for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{} {} {}", self.expr, self.op_name(), self.pattern) + } +} + +impl PhysicalExpr for ScalarRegexMatchExpr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn data_type( + &self, + _: &arrow_schema::Schema, + ) -> datafusion_common::Result { + Ok(DataType::Boolean) + } + + fn nullable( + &self, + input_schema: &arrow_schema::Schema, + ) -> datafusion_common::Result { + Ok(self.expr.nullable(input_schema)? || self.pattern.nullable(input_schema)?) + } + + fn evaluate( + &self, + batch: &arrow_array::RecordBatch, + ) -> datafusion_common::Result { + self.expr + .evaluate(batch) + .and_then(|lhs| { + if self.compiled.is_some() { + match &lhs { + ColumnarValue::Array(array) => self.evaluate_array(array), + ColumnarValue::Scalar(scalar) => self.evaluate_scalar(scalar), + } + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(None))) + } + }) + .and_then(|result| result.into_array(batch.num_rows())) + .map(ColumnarValue::Array) + } + + fn children(&self) -> Vec<&std::sync::Arc> { + vec![&self.expr, &self.pattern] + } + + fn with_new_children( + self: std::sync::Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(ScalarRegexMatchExpr::new( + self.negated, + self.case_insensitive, + Arc::clone(&children[0]), + Arc::clone(&children[1]), + ))) + } + + fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { + let mut s = state; + self.hash(&mut s); + } +} + +impl PartialEq for ScalarRegexMatchExpr { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self == x) + .unwrap_or(false) + } +} + +/// It is used for scalar regexp matching and copy from arrow-rs +fn array_regexp_match( + array: &dyn ArrayAccessor, + regex: &Regex, + negated: bool, +) -> datafusion_common::Result { + let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); + let mut buffer_builder = BooleanBufferBuilder::new(array.len()); + + if regex.as_str().is_empty() { + buffer_builder.append_n(array.len(), true); + } else { + for i in 0..array.len() { + let value = array.value(i); + buffer_builder.append(regex.is_match(value)); + } + } + + let buffer = buffer_builder.into(); + let bool_array = BooleanArray::from(unsafe { + ArrayData::new_unchecked( + DataType::Boolean, + array.len(), + None, + null_bit_buffer, + 0, + vec![buffer], + vec![], + ) + }); + + let bool_array = if negated { + arrow::compute::kernels::boolean::not(&bool_array) + } else { + Ok(bool_array) + }; + + bool_array + .map_err(|err| { + datafusion_common::DataFusionError::Execution(format!( + "Failed to evaluate regex: {}", + err + )) + }) + .map(|bool_array| ColumnarValue::Array(Arc::new(bool_array))) +} + +/// Create a scalar regex match expression, erroring if the argument types are not compatible. +pub fn scalar_regex_match( + negated: bool, + case_insensitive: bool, + expr: Arc, + pattern: Arc, + input_schema: &Schema, +) -> datafusion_common::Result> { + let valid_data_type = |data_type: &DataType| { + if !matches!( + data_type, + DataType::Null | DataType::Utf8 | DataType::Utf8View | DataType::LargeUtf8 + ) { + return datafusion_common::internal_err!( + "The type {data_type} not supported for scalar_regex_match, expect Null|Utf8|Utf8View|LargeUtf8" + ); + } + Ok(()) + }; + + for arg_expr in [&expr, &pattern] { + arg_expr + .data_type(input_schema) + .and_then(|data_type| valid_data_type(&data_type))?; + } + + Ok(Arc::new(ScalarRegexMatchExpr::new( + negated, + case_insensitive, + expr, + pattern, + ))) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::expressions::{col, lit}; + use arrow::record_batch::RecordBatch; + use arrow_array::ArrayRef; + use arrow_array::NullArray; + use arrow_schema::Field; + use arrow_schema::Schema; + use rstest::rstest; + use std::sync::Arc; + + fn test_schema(typ: DataType) -> Schema { + Schema::new(vec![Field::new("c1", typ, false)]) + } + + #[rstest( + negated, case_insensitive, typ, a_vec, b_lit, c_vec, + case( + false, false, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![true, false, false, false, false])), + ), + case( + false, true, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![true, false, true, false, false])), + ), + case( + true, false, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![false, true, true, true, true])), + ), + case( + true, true, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + "^a", + Arc::new(BooleanArray::from(vec![false, true, false, true, true])), + ), + case( + true, true, DataType::Utf8, + Arc::new(StringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8(None), + Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + ), + case( + false, false, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, false, false, false])), + ), + case( + false, true, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, true, false, false])), + ), + case( + true, false, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, true, true, true])), + ), + case( + true, true, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, false, true, true])), + ), + case( + true, true, DataType::LargeUtf8, + Arc::new(LargeStringArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::LargeUtf8(None), + Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + ), + case( + false, false, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, false, false, false])), + ), + case( + false, true, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![true, false, true, false, false])), + ), + case( + true, false, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, true, true, true])), + ), + case( + true, true, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![false, true, false, true, true])), + ), + case( + true, true, DataType::Utf8View, + Arc::new(StringViewArray::from(vec!["abc", "bbb", "ABC", "ba", "cba"])), + ScalarValue::Utf8View(None), + Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + ), + case( + true, true, DataType::Null, + Arc::new(NullArray::new(5)), + ScalarValue::Utf8View(Some("^a".to_string())), + Arc::new(BooleanArray::from(vec![None, None, None, None, None])), + ), + )] + fn test_scalar_regex_match_array( + negated: bool, + case_insensitive: bool, + typ: DataType, + a_vec: ArrayRef, + b_lit: impl datafusion_expr::Literal, + c_vec: ArrayRef, + ) { + let schema = test_schema(typ); + let left = col("c1", &schema).unwrap(); + let right = lit(b_lit); + + // verify that we can construct the expression + let expression = + scalar_regex_match(negated, case_insensitive, left, right, &schema).unwrap(); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![a_vec]).unwrap(); + + // verify that the expression's type is correct + assert_eq!(expression.data_type(&schema).unwrap(), DataType::Boolean); + + // compute + let result = expression + .evaluate(&batch) + .expect("Error evaluating expression"); + + if let ColumnarValue::Array(array) = result { + let array = array + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + + let c_vec = c_vec + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + // verify that the result is correct + assert_eq!(array, c_vec); + } else { + panic!("result was not an array"); + } + } + + #[rstest( + negated, case_insensitive, typ, a_lit, b_lit, flag, + case( + false, false, DataType::Utf8, "abc", "^a", Some(true), + ), + case( + false, true, DataType::Utf8, "Abc", "^a", Some(true), + ), + case( + true, false, DataType::Utf8, "abc", "^a", Some(false), + ), + case( + true, true, DataType::Utf8, "Abc", "^a", Some(false), + ), + case( + true, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::Utf8(None), + None, + ), + case( + false, false, DataType::Utf8, + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(true), + ), + case( + false, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(true), + ), + case( + true, false, DataType::Utf8, + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(false), + ), + case( + true, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::LargeUtf8(Some("^a".to_string())), + Some(false), + ), + case( + true, true, DataType::Utf8, + ScalarValue::Utf8(Some("Abc".to_string())), + ScalarValue::LargeUtf8(None), + None, + ), + )] + fn test_scalar_regex_match_scalar( + negated: bool, + case_insensitive: bool, + typ: DataType, + a_lit: impl datafusion_expr::Literal, + b_lit: impl datafusion_expr::Literal, + flag: Option, + ) { + let left = lit(a_lit); + let right = lit(b_lit); + let schema = test_schema(typ); + let expression = + scalar_regex_match(negated, case_insensitive, left, right, &schema).unwrap(); + let num_rows: usize = 3; + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(StringArray::from([""].repeat(num_rows)))], + ) + .unwrap(); + + // verify that the expression's type is correct + assert_eq!(expression.data_type(&schema).unwrap(), DataType::Boolean); + + // compute + let result = expression + .evaluate(&batch) + .expect("Error evaluating expression"); + + if let ColumnarValue::Array(array) = result { + let array = array + .as_any() + .downcast_ref::() + .expect("failed to downcast to BooleanArray"); + + // verify that the result is correct + let c_vec = [flag].repeat(batch.num_rows()); + assert_eq!(array, &BooleanArray::from(c_vec)); + } else { + panic!("result was not an array"); + } + } + + #[rstest( + expr, pattern, + case( + col("c1", &test_schema(DataType::Utf8)).unwrap(), + lit(1), + ), + case( + lit(1), + col("c1", &test_schema(DataType::Utf8)).unwrap(), + ), + )] + #[should_panic] + fn test_scalar_regex_match_panic( + expr: Arc, + pattern: Arc, + ) { + let _ = + scalar_regex_match(false, false, expr, pattern, &test_schema(DataType::Utf8)) + .unwrap(); + } + + #[rstest( + pattern, + case(col("c1", &test_schema(DataType::Utf8)).unwrap()), // not literal + case(lit(1)), // not literal string + case(lit("\\x{202e")), // wrong regex pattern + )] + #[should_panic] + fn test_scalar_regex_match_compile_error(pattern: Arc) { + let _ = ScalarRegexMatchExpr::new(false, false, lit("a"), pattern); + } +} diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index bffc2c46fc1e..3a60e0cfeb24 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -17,6 +17,7 @@ use std::sync::Arc; +use crate::expressions::scalar_regex_match; use crate::scalar_function; use crate::{ expressions::{self, binary, like, similar_to, Column, Literal}, @@ -191,7 +192,32 @@ pub fn create_physical_expr( // // There should be no coercion during physical // planning. - binary(lhs, *op, rhs, input_schema) + if let Expr::Literal( + ScalarValue::Null + | ScalarValue::Utf8(_) + | ScalarValue::Utf8View(_) + | ScalarValue::LargeUtf8(_), + ) = right.as_ref() + { + // handle literal regexp pattern case to `ScalarRegexMatchExpr` + match *op { + Operator::RegexMatch => { + scalar_regex_match(false, false, lhs, rhs, input_schema) + } + Operator::RegexNotMatch => { + scalar_regex_match(true, false, lhs, rhs, input_schema) + } + Operator::RegexIMatch => { + scalar_regex_match(false, true, lhs, rhs, input_schema) + } + Operator::RegexNotIMatch => { + scalar_regex_match(true, true, lhs, rhs, input_schema) + } + _ => binary(lhs, *op, rhs, input_schema), + } + } else { + binary(lhs, *op, rhs, input_schema) + } } Expr::Like(Like { negated, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d6fa129edc3f..d9c13d837491 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -841,6 +841,8 @@ message PhysicalExprNode { PhysicalLikeExprNode like_expr = 18; PhysicalExtensionExprNode extension = 19; + + PhysicalScalarRegexMatchExprNode scalar_regex_match_expr = 20; } } @@ -953,6 +955,13 @@ message PhysicalExtensionExprNode { repeated PhysicalExprNode inputs = 2; } +message PhysicalScalarRegexMatchExprNode { + bool negated = 1; + bool case_insensitive = 2; + PhysicalExprNode expr = 3; + PhysicalExprNode pattern = 4; +} + message FilterExecNode { PhysicalPlanNode input = 1; PhysicalExprNode expr = 2; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 16f14d9ddf61..2cda720f5d34 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -13932,6 +13932,9 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::Extension(v) => { struct_ser.serialize_field("extension", v)?; } + physical_expr_node::ExprType::ScalarRegexMatchExpr(v) => { + struct_ser.serialize_field("scalarRegexMatchExpr", v)?; + } } } struct_ser.end() @@ -13972,6 +13975,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "like_expr", "likeExpr", "extension", + "scalar_regex_match_expr", + "scalarRegexMatchExpr", ]; #[allow(clippy::enum_variant_names)] @@ -13993,6 +13998,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { ScalarUdf, LikeExpr, Extension, + ScalarRegexMatchExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -14031,6 +14037,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "scalarUdf" | "scalar_udf" => Ok(GeneratedField::ScalarUdf), "likeExpr" | "like_expr" => Ok(GeneratedField::LikeExpr), "extension" => Ok(GeneratedField::Extension), + "scalarRegexMatchExpr" | "scalar_regex_match_expr" => Ok(GeneratedField::ScalarRegexMatchExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -14170,6 +14177,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("extension")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::Extension) +; + } + GeneratedField::ScalarRegexMatchExpr => { + if expr_type__.is_some() { + return Err(serde::de::Error::duplicate_field("scalarRegexMatchExpr")); + } + expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::ScalarRegexMatchExpr) ; } } @@ -15627,6 +15641,149 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode { deserializer.deserialize_struct("datafusion.PhysicalPlanNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for PhysicalScalarRegexMatchExprNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.negated { + len += 1; + } + if self.case_insensitive { + len += 1; + } + if self.expr.is_some() { + len += 1; + } + if self.pattern.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalScalarRegexMatchExprNode", len)?; + if self.negated { + struct_ser.serialize_field("negated", &self.negated)?; + } + if self.case_insensitive { + struct_ser.serialize_field("caseInsensitive", &self.case_insensitive)?; + } + if let Some(v) = self.expr.as_ref() { + struct_ser.serialize_field("expr", v)?; + } + if let Some(v) = self.pattern.as_ref() { + struct_ser.serialize_field("pattern", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for PhysicalScalarRegexMatchExprNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "negated", + "case_insensitive", + "caseInsensitive", + "expr", + "pattern", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Negated, + CaseInsensitive, + Expr, + Pattern, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "negated" => Ok(GeneratedField::Negated), + "caseInsensitive" | "case_insensitive" => Ok(GeneratedField::CaseInsensitive), + "expr" => Ok(GeneratedField::Expr), + "pattern" => Ok(GeneratedField::Pattern), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalScalarRegexMatchExprNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalScalarRegexMatchExprNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut negated__ = None; + let mut case_insensitive__ = None; + let mut expr__ = None; + let mut pattern__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Negated => { + if negated__.is_some() { + return Err(serde::de::Error::duplicate_field("negated")); + } + negated__ = Some(map_.next_value()?); + } + GeneratedField::CaseInsensitive => { + if case_insensitive__.is_some() { + return Err(serde::de::Error::duplicate_field("caseInsensitive")); + } + case_insensitive__ = Some(map_.next_value()?); + } + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + GeneratedField::Pattern => { + if pattern__.is_some() { + return Err(serde::de::Error::duplicate_field("pattern")); + } + pattern__ = map_.next_value()?; + } + } + } + Ok(PhysicalScalarRegexMatchExprNode { + negated: negated__.unwrap_or_default(), + case_insensitive: case_insensitive__.unwrap_or_default(), + expr: expr__, + pattern: pattern__, + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalScalarRegexMatchExprNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalScalarUdfNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 59a90eb31ade..fe871b387367 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1162,7 +1162,7 @@ pub struct PhysicalExtensionNode { pub struct PhysicalExprNode { #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20" )] pub expr_type: ::core::option::Option, } @@ -1211,6 +1211,10 @@ pub mod physical_expr_node { LikeExpr(::prost::alloc::boxed::Box), #[prost(message, tag = "19")] Extension(super::PhysicalExtensionExprNode), + #[prost(message, tag = "20")] + ScalarRegexMatchExpr( + ::prost::alloc::boxed::Box, + ), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1390,6 +1394,18 @@ pub struct PhysicalExtensionExprNode { pub inputs: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct PhysicalScalarRegexMatchExprNode { + #[prost(bool, tag = "1")] + pub negated: bool, + #[prost(bool, tag = "2")] + pub case_insensitive: bool, + #[prost(message, optional, boxed, tag = "3")] + pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, boxed, tag = "4")] + pub pattern: ::core::option::Option<::prost::alloc::boxed::Box>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct FilterExecNode { #[prost(message, optional, boxed, tag = "1")] pub input: ::core::option::Option<::prost::alloc::boxed::Box>, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 31b59c2a9457..ff086f4b215e 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -38,7 +38,7 @@ use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, - Literal, NegativeExpr, NotExpr, TryCastExpr, + Literal, NegativeExpr, NotExpr, ScalarRegexMatchExpr, TryCastExpr, }; use datafusion::physical_plan::windows::{create_window_expr, schema_add_window_field}; use datafusion::physical_plan::{Partitioning, PhysicalExpr, WindowExpr}; @@ -393,6 +393,26 @@ pub fn parse_physical_expr( .collect::>()?; (codec.try_decode_expr(extension.expr.as_slice(), &inputs)?) as _ } + ExprType::ScalarRegexMatchExpr(scalar_match_expr) => { + Arc::new(ScalarRegexMatchExpr::new( + scalar_match_expr.negated, + scalar_match_expr.case_insensitive, + parse_required_physical_expr( + scalar_match_expr.expr.as_deref(), + registry, + "expr", + input_schema, + codec, + )?, + parse_required_physical_expr( + scalar_match_expr.pattern.as_deref(), + registry, + "pattern", + input_schema, + codec, + )?, + )) + } }; Ok(pexpr) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e9bae11bad2c..c2f4cad9ebf9 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,7 +23,7 @@ use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr use datafusion::physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr, - Literal, NegativeExpr, NotExpr, NthValue, TryCastExpr, + Literal, NegativeExpr, NotExpr, NthValue, ScalarRegexMatchExpr, TryCastExpr, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -366,6 +366,25 @@ pub fn serialize_physical_expr( }, ))), }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::ScalarRegexMatchExpr(Box::new( + protobuf::PhysicalScalarRegexMatchExprNode { + negated: expr.negated(), + case_insensitive: expr.case_insensitive(), + expr: Some(Box::new(serialize_physical_expr( + expr.expr(), + codec, + )?)), + pattern: Some(Box::new(serialize_physical_expr( + expr.pattern(), + codec, + )?)), + }, + )), + ), + }) } else { let mut buf: Vec = vec![]; match codec.try_encode_expr(value, &mut buf) { From 8e92a1c15c28ddb10a767a7430ce098bf2f94fe8 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Wed, 18 Sep 2024 00:36:33 +0800 Subject: [PATCH 2/4] bench: add scalar regex match benchmarks --- datafusion/physical-expr/Cargo.toml | 4 + .../benches/scalar_regex_match.rs | 121 ++++++++++++++++++ 2 files changed, 125 insertions(+) create mode 100644 datafusion/physical-expr/benches/scalar_regex_match.rs diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 4195e684381f..26b30029f63d 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -75,3 +75,7 @@ name = "case_when" [[bench]] harness = false name = "is_null" + +[[bench]] +harness = false +name = "scalar_regex_match" diff --git a/datafusion/physical-expr/benches/scalar_regex_match.rs b/datafusion/physical-expr/benches/scalar_regex_match.rs new file mode 100644 index 000000000000..680843c0cb56 --- /dev/null +++ b/datafusion/physical-expr/benches/scalar_regex_match.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::{RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_expr_common::operator::Operator; +use datafusion_physical_expr::expressions::{binary, col, lit, scalar_regex_match}; +use hashbrown::HashMap; +use rand::distributions::{Alphanumeric, DistString}; + +/// make a record batch with one column and n rows +/// this record batch is single string column is used for +/// scalar regex match benchmarks +fn make_record_batch(rows: usize, string_length: usize, schema: Schema) -> RecordBatch { + let mut rng = rand::thread_rng(); + let mut array = Vec::with_capacity(rows); + for _ in 0..rows { + let data_line = Alphanumeric.sample_string(&mut rng, string_length); + array.push(Some(data_line)); + } + let array = StringArray::from(array); + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).unwrap() +} + +fn scalar_regex_match_benchmark(c: &mut Criterion) { + // make common schema + let column = "string"; + let schema = Schema::new(vec![Field::new(column, DataType::Utf8, true)]); + + // meke test record batch + let test_batch = [ + (10, make_record_batch(10, 100, schema.clone())), + (100, make_record_batch(100, 100, schema.clone())), + (1000, make_record_batch(1000, 100, schema.clone())), + (2000, make_record_batch(2000, 100, schema.clone())), + ] + .iter() + .map(|(k, v)| (*k, v.clone())) + .collect::>(); + + // string column + let string_col = col(column, &schema).unwrap(); + + // some pattern literal + let pattern_lit = [ + ("email".to_string(), lit(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")), + ("url".to_string(), lit(r"^(https?|ftp)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]$")), + ("ip".to_string(), lit(r"^((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$")), + ("phone".to_string(), lit(r"^(\+\d{1,2}\s?)?\(?\d{3}\)?[\s.-]?\d{3}[\s.-]?\d{4}$")), + ("zip_code".to_string(), lit(r"^\d{5}(?:[-\s]\d{4})?$")), + ].iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>(); + + for (name, regexp_lit) in pattern_lit.iter() { + for (rows, batch) in test_batch.iter() { + for iter in [10, 20, 50, 100] { + // scalar regex match benchmarks + let bench_name = format!( + "scalar_regex_match_pattern_{}_rows_{}_iter_{}", + name, rows, iter + ); + c.bench_function(bench_name.as_str(), |b| { + let expr = scalar_regex_match( + false, + false, + string_col.clone(), + regexp_lit.clone(), + &schema, + ) + .unwrap(); + b.iter(|| { + for _ in 0..iter { + expr.evaluate(black_box(batch)).unwrap(); + } + }); + }); + + // binary regex match benchmarks + let bench_name = format!( + "binary_regex_match_pattern_{}_rows_{}_iter_{}", + name, rows, iter + ); + c.bench_function(bench_name.as_str(), |b| { + let expr = binary( + string_col.clone(), + Operator::RegexMatch, + regexp_lit.clone(), + &schema, + ) + .unwrap(); + b.iter(|| { + for _ in 0..iter { + expr.evaluate(black_box(batch)).unwrap(); + } + }); + }); + } + } + } +} + +criterion_group!(benches, scalar_regex_match_benchmark); +criterion_main!(benches); From 0e939ed5c3d48d95b95ad99efee894341018a868 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Wed, 18 Sep 2024 20:52:15 +0800 Subject: [PATCH 3/4] feat: apply scalar_regex_match optimize to similar_to case --- datafusion/physical-expr/src/planner.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 3a60e0cfeb24..2bc9332d908e 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -255,6 +255,23 @@ pub fn create_physical_expr( create_physical_expr(expr, input_dfschema, execution_props)?; let physical_pattern = create_physical_expr(pattern, input_dfschema, execution_props)?; + + if let Expr::Literal( + ScalarValue::Null + | ScalarValue::Utf8(_) + | ScalarValue::Utf8View(_) + | ScalarValue::LargeUtf8(_), + ) = pattern.as_ref() + { + // handle literal regexp pattern case to `ScalarRegexMatchExpr` + return scalar_regex_match( + *negated, + *case_insensitive, + physical_expr, + physical_pattern, + input_schema, + ); + } similar_to(*negated, *case_insensitive, physical_expr, physical_pattern) } Expr::Case(case) => { From 493a47afede9363335a0ab66b6805ebed4bc7e67 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Wed, 18 Sep 2024 22:17:21 +0800 Subject: [PATCH 4/4] minor: regen datafusion protobuf --- datafusion/physical-expr/Cargo.toml | 1 + .../src/expressions/scalar_regex_match.rs | 100 +++++++++--------- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/prost.rs | 1 - 4 files changed, 53 insertions(+), 51 deletions(-) diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 26b30029f63d..079e7d42e93e 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -56,6 +56,7 @@ itertools = { workspace = true, features = ["use_std"] } log = { workspace = true } paste = "^1.0" petgraph = "0.6.2" +regex = { workspace = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs index badb00659576..cc446f3328d5 100644 --- a/datafusion/physical-expr/src/expressions/scalar_regex_match.rs +++ b/datafusion/physical-expr/src/expressions/scalar_regex_match.rs @@ -22,11 +22,16 @@ use arrow_array::{ }; use arrow_buffer::BooleanBufferBuilder; use arrow_schema::{DataType, Schema}; -use datafusion_common::ScalarValue; +use datafusion_common::{Result as DFResult, ScalarValue}; use datafusion_expr::ColumnarValue; -use datafusion_physical_expr_common::physical_expr::{down_cast_any_ref, PhysicalExpr}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use regex::Regex; -use std::{any::Any, hash::Hash, sync::Arc}; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter, Result as FmtResult}, + hash::Hash, + sync::Arc, +}; /// ScalarRegexMatchExpr /// Only used when evaluating regexp matching with literal pattern. @@ -133,9 +138,7 @@ impl ScalarRegexMatchExpr { (true, true) => "NOT IMATCH", } } -} -impl ScalarRegexMatchExpr { /// Evaluate the scalar regex match expression match array value fn evaluate_array( &self, @@ -200,16 +203,9 @@ impl ScalarRegexMatchExpr { } } -impl std::hash::Hash for ScalarRegexMatchExpr { - fn hash(&self, state: &mut H) { - self.negated.hash(state); - self.case_insensitive.hash(state); - self.expr.hash(state); - self.pattern.hash(state); - } -} +impl Eq for ScalarRegexMatchExpr {} -impl std::cmp::PartialEq for ScalarRegexMatchExpr { +impl PartialEq for ScalarRegexMatchExpr { fn eq(&self, other: &Self) -> bool { self.negated.eq(&other.negated) && self.case_insensitive.eq(&self.case_insensitive) @@ -218,8 +214,17 @@ impl std::cmp::PartialEq for ScalarRegexMatchExpr { } } -impl std::fmt::Debug for ScalarRegexMatchExpr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Hash for ScalarRegexMatchExpr { + fn hash(&self, state: &mut H) { + self.negated.hash(state); + self.case_insensitive.hash(state); + self.expr.hash(state); + self.pattern.hash(state); + } +} + +impl Debug for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { f.debug_struct("ScalarRegexMatchExpr") .field("negated", &self.negated) .field("case_insensitive", &self.case_insensitive) @@ -229,35 +234,26 @@ impl std::fmt::Debug for ScalarRegexMatchExpr { } } -impl std::fmt::Display for ScalarRegexMatchExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +impl Display for ScalarRegexMatchExpr { + fn fmt(&self, f: &mut Formatter) -> FmtResult { write!(f, "{} {} {}", self.expr, self.op_name(), self.pattern) } } impl PhysicalExpr for ScalarRegexMatchExpr { - fn as_any(&self) -> &dyn std::any::Any { + fn as_any(&self) -> &dyn Any { self } - fn data_type( - &self, - _: &arrow_schema::Schema, - ) -> datafusion_common::Result { + fn data_type(&self, _: &Schema) -> DFResult { Ok(DataType::Boolean) } - fn nullable( - &self, - input_schema: &arrow_schema::Schema, - ) -> datafusion_common::Result { + fn nullable(&self, input_schema: &Schema) -> DFResult { Ok(self.expr.nullable(input_schema)? || self.pattern.nullable(input_schema)?) } - fn evaluate( - &self, - batch: &arrow_array::RecordBatch, - ) -> datafusion_common::Result { + fn evaluate(&self, batch: &arrow_array::RecordBatch) -> DFResult { self.expr .evaluate(batch) .and_then(|lhs| { @@ -274,14 +270,14 @@ impl PhysicalExpr for ScalarRegexMatchExpr { .map(ColumnarValue::Array) } - fn children(&self) -> Vec<&std::sync::Arc> { + fn children(&self) -> Vec<&Arc> { vec![&self.expr, &self.pattern] } fn with_new_children( - self: std::sync::Arc, - children: Vec>, - ) -> datafusion_common::Result> { + self: Arc, + children: Vec>, + ) -> DFResult> { Ok(Arc::new(ScalarRegexMatchExpr::new( self.negated, self.case_insensitive, @@ -290,18 +286,24 @@ impl PhysicalExpr for ScalarRegexMatchExpr { ))) } - fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for ScalarRegexMatchExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self == x) - .unwrap_or(false) + fn evaluate_selection( + &self, + batch: &arrow_array::RecordBatch, + selection: &BooleanArray, + ) -> DFResult { + let tmp_batch = arrow::compute::filter_record_batch(batch, selection)?; + + let tmp_result = self.evaluate(&tmp_batch)?; + + if batch.num_rows() == tmp_batch.num_rows() { + // All values from the `selection` filter are true. + Ok(tmp_result) + } else if let ColumnarValue::Array(a) = tmp_result { + datafusion_physical_expr_common::utils::scatter(selection, a.as_ref()) + .map(ColumnarValue::Array) + } else { + Ok(tmp_result) + } } } @@ -310,7 +312,7 @@ fn array_regexp_match( array: &dyn ArrayAccessor, regex: &Regex, negated: bool, -) -> datafusion_common::Result { +) -> DFResult { let null_bit_buffer = array.nulls().map(|x| x.inner().sliced()); let mut buffer_builder = BooleanBufferBuilder::new(array.len()); @@ -359,7 +361,7 @@ pub fn scalar_regex_match( expr: Arc, pattern: Arc, input_schema: &Schema, -) -> datafusion_common::Result> { +) -> DFResult> { let valid_data_type = |data_type: &DataType| { if !matches!( data_type, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d9c13d837491..1fefce16c789 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -841,7 +841,7 @@ message PhysicalExprNode { PhysicalLikeExprNode like_expr = 18; PhysicalExtensionExprNode extension = 19; - + PhysicalScalarRegexMatchExprNode scalar_regex_match_expr = 20; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index fe871b387367..894ce5c5a525 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1404,7 +1404,6 @@ pub struct PhysicalScalarRegexMatchExprNode { #[prost(message, optional, boxed, tag = "4")] pub pattern: ::core::option::Option<::prost::alloc::boxed::Box>, } -#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FilterExecNode { #[prost(message, optional, boxed, tag = "1")]