From c8b2ac5b5a6fb696aa98d7c952f90bf547869574 Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Wed, 10 Apr 2024 22:54:06 +0800 Subject: [PATCH 1/4] feat: add create_adaptive_array_iter macro --- datafusion/functions/Cargo.toml | 5 +++ datafusion/functions/benches/overlay.rs | 51 ++++++++++++++++++++++ datafusion/functions/src/string/overlay.rs | 40 ++++++++++++++--- datafusion/functions/src/utils.rs | 12 +++++ 4 files changed, 101 insertions(+), 7 deletions(-) create mode 100644 datafusion/functions/benches/overlay.rs diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index f9985069413b..48de37a0084b 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -133,3 +133,8 @@ required-features = ["string_expressions"] harness = false name = "upper" required-features = ["string_expressions"] + +[[bench]] +harness = false +name = "overlay" +required-features = ["string_expressions"] diff --git a/datafusion/functions/benches/overlay.rs b/datafusion/functions/benches/overlay.rs new file mode 100644 index 000000000000..5d8dbaee5889 --- /dev/null +++ b/datafusion/functions/benches/overlay.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::StringArray; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_common::ScalarValue; +use datafusion_expr::ColumnarValue; +use datafusion_functions::string; +use std::sync::Arc; + +fn create_4args(size: usize) -> Vec { + let array: StringArray = std::iter::repeat(Some("Txxxxas")).take(size).collect(); + let characters = ScalarValue::Utf8(Some("hom".to_string())); + let pos = ScalarValue::Int64(Some(2)); + let len = ScalarValue::Int64(Some(4)); + vec![ + ColumnarValue::Array(Arc::new(array)), + ColumnarValue::Scalar(characters), + ColumnarValue::Scalar(pos), + ColumnarValue::Scalar(len), + ] +} + +fn criterion_benchmark(c: &mut Criterion) { + let overlay = string::overlay(); + for size in [1024, 4096, 8192] { + let args = create_4args(size); + let mut group = c.benchmark_group("overlay_with_4args"); + group.bench_function(BenchmarkId::new("overlay", size), |b| { + b.iter(|| criterion::black_box(overlay.invoke(&args).unwrap())) + }); + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index 3f92a73c1af9..b5c2ef1721b9 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -18,6 +18,7 @@ use std::any::Any; use std::sync::Arc; +use arrow::array::Array; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::DataType; @@ -26,7 +27,9 @@ use datafusion_common::{exec_err, Result}; use datafusion_expr::TypeSignature::*; use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; +use datafusion_physical_expr::functions::Hint; +use crate::create_adaptive_array_iter; use crate::utils::{make_scalar_function, utf8_to_str_type}; #[derive(Debug)] @@ -70,8 +73,24 @@ impl ScalarUDFImpl for OverlayFunc { fn invoke(&self, args: &[ColumnarValue]) -> Result { match args[0].data_type() { - DataType::Utf8 => make_scalar_function(overlay::, vec![])(args), - DataType::LargeUtf8 => make_scalar_function(overlay::, vec![])(args), + DataType::Utf8 => make_scalar_function( + overlay::, + vec![ + Hint::Pad, + Hint::AcceptsSingular, + Hint::AcceptsSingular, + Hint::AcceptsSingular, + ], + )(args), + DataType::LargeUtf8 => make_scalar_function( + overlay::, + vec![ + Hint::Pad, + Hint::AcceptsSingular, + Hint::AcceptsSingular, + Hint::AcceptsSingular, + ], + )(args), other => exec_err!("Unsupported data type {other:?} for function overlay"), } } @@ -88,10 +107,13 @@ pub fn overlay(args: &[ArrayRef]) -> Result { let characters_array = as_generic_string_array::(&args[1])?; let pos_num = as_int64_array(&args[2])?; + let characters_array_iter = create_adaptive_array_iter!(characters_array); + let pos_num_iter = create_adaptive_array_iter!(pos_num); + let result = string_array .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) + .zip(characters_array_iter) + .zip(pos_num_iter) .map(|((string, characters), start_pos)| { match (string, characters, start_pos) { (Some(string), Some(characters), Some(start_pos)) => { @@ -126,11 +148,15 @@ pub fn overlay(args: &[ArrayRef]) -> Result { let pos_num = as_int64_array(&args[2])?; let len_num = as_int64_array(&args[3])?; + let characters_array_iter = create_adaptive_array_iter!(characters_array); + let pos_num_iter = create_adaptive_array_iter!(pos_num); + let len_num_iter = create_adaptive_array_iter!(len_num); + let result = string_array .iter() - .zip(characters_array.iter()) - .zip(pos_num.iter()) - .zip(len_num.iter()) + .zip(characters_array_iter) + .zip(pos_num_iter) + .zip(len_num_iter) .map(|(((string, characters), start_pos), len)| { match (string, characters, start_pos, len) { (Some(string), Some(characters), Some(start_pos), Some(len)) => { diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 9b7144b483bd..e8703b3e069d 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -116,6 +116,18 @@ where }) } +#[macro_export] +macro_rules! create_adaptive_array_iter { + ($ARRAY:expr) => {{ + let first_value = if $ARRAY.is_null(0) { + None + } else { + Some($ARRAY.value(0)) + }; + $ARRAY.iter().chain(std::iter::repeat(first_value)) + }}; +} + #[cfg(test)] pub mod test { /// $FUNC ScalarUDFImpl to test From f7e6102d8dd4f7d21b6cc8eb64cea11f78fae08d Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Fri, 12 Apr 2024 15:22:24 +0800 Subject: [PATCH 2/4] feat: add adaptive_array_iter function --- datafusion/functions/benches/overlay.rs | 39 +++++++++++++++++++--- datafusion/functions/src/string/overlay.rs | 14 ++++---- datafusion/functions/src/utils.rs | 35 +++++++++++++------ 3 files changed, 64 insertions(+), 24 deletions(-) diff --git a/datafusion/functions/benches/overlay.rs b/datafusion/functions/benches/overlay.rs index 5d8dbaee5889..b98ff4e8ae9c 100644 --- a/datafusion/functions/benches/overlay.rs +++ b/datafusion/functions/benches/overlay.rs @@ -15,15 +15,18 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::StringArray; +use arrow::array::{Array, Int64Array, StringArray}; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; use datafusion_functions::string; use std::sync::Arc; -fn create_4args(size: usize) -> Vec { +/// Create four args, three of which are Scalars and one is a StringArray. +/// The `size` represents the length of the StringArray. +fn create_4args_with_3scalars(size: usize) -> Vec { let array: StringArray = std::iter::repeat(Some("Txxxxas")).take(size).collect(); + assert_eq!(array.len(), size); let characters = ScalarValue::Utf8(Some("hom".to_string())); let pos = ScalarValue::Int64(Some(2)); let len = ScalarValue::Int64(Some(4)); @@ -35,11 +38,37 @@ fn create_4args(size: usize) -> Vec { ] } +/// Create four args, all of which are Arrays. +/// The `size` represents the length of Array. +fn create_4args_without_scalar(size: usize) -> Vec { + let array: StringArray = std::iter::repeat(Some("Txxxxas")).take(size).collect(); + let characters: StringArray = std::iter::repeat(Some("hom")).take(size).collect(); + let pos: Int64Array = std::iter::repeat(Some(2)).take(size).collect(); + let len: Int64Array = std::iter::repeat(Some(4)).take(size).collect(); + vec![ + ColumnarValue::Array(Arc::new(array)), + ColumnarValue::Array(Arc::new(characters)), + ColumnarValue::Array(Arc::new(pos)), + ColumnarValue::Array(Arc::new(len)), + ] +} + fn criterion_benchmark(c: &mut Criterion) { let overlay = string::overlay(); - for size in [1024, 4096, 8192] { - let args = create_4args(size); - let mut group = c.benchmark_group("overlay_with_4args"); + let sizes: Vec = vec![1024, 4096, 8192]; + + for size in &sizes { + let args = create_4args_with_3scalars(*size); + let mut group = c.benchmark_group("4args_with_3scalars"); + group.bench_function(BenchmarkId::new("overlay", size), |b| { + b.iter(|| criterion::black_box(overlay.invoke(&args).unwrap())) + }); + group.finish(); + } + + for size in &sizes { + let args = create_4args_without_scalar(*size); + let mut group = c.benchmark_group("4args_without_scalar"); group.bench_function(BenchmarkId::new("overlay", size), |b| { b.iter(|| criterion::black_box(overlay.invoke(&args).unwrap())) }); diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index b5c2ef1721b9..65233ebb34e2 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -18,7 +18,6 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::Array; use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::DataType; @@ -29,8 +28,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use datafusion_physical_expr::functions::Hint; -use crate::create_adaptive_array_iter; -use crate::utils::{make_scalar_function, utf8_to_str_type}; +use crate::utils::{adaptive_array_iter, make_scalar_function, utf8_to_str_type}; #[derive(Debug)] pub struct OverlayFunc { @@ -107,8 +105,8 @@ pub fn overlay(args: &[ArrayRef]) -> Result { let characters_array = as_generic_string_array::(&args[1])?; let pos_num = as_int64_array(&args[2])?; - let characters_array_iter = create_adaptive_array_iter!(characters_array); - let pos_num_iter = create_adaptive_array_iter!(pos_num); + let characters_array_iter = adaptive_array_iter(characters_array.iter()); + let pos_num_iter = adaptive_array_iter(pos_num.iter()); let result = string_array .iter() @@ -148,9 +146,9 @@ pub fn overlay(args: &[ArrayRef]) -> Result { let pos_num = as_int64_array(&args[2])?; let len_num = as_int64_array(&args[3])?; - let characters_array_iter = create_adaptive_array_iter!(characters_array); - let pos_num_iter = create_adaptive_array_iter!(pos_num); - let len_num_iter = create_adaptive_array_iter!(len_num); + let characters_array_iter = adaptive_array_iter(characters_array.iter()); + let pos_num_iter = adaptive_array_iter(pos_num.iter()); + let len_num_iter = adaptive_array_iter(len_num.iter()); let result = string_array .iter() diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index e8703b3e069d..2e8f22d49ad4 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::ArrayRef; +use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef}; use arrow::datatypes::DataType; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; use datafusion_physical_expr::functions::Hint; +use itertools::Either; use std::sync::Arc; /// Creates a function to identify the optimal return type of a string function given @@ -116,16 +117,28 @@ where }) } -#[macro_export] -macro_rules! create_adaptive_array_iter { - ($ARRAY:expr) => {{ - let first_value = if $ARRAY.is_null(0) { - None - } else { - Some($ARRAY.value(0)) - }; - $ARRAY.iter().chain(std::iter::repeat(first_value)) - }}; +/// Create an adaptive iterator. When the `hints` args of `make_scalar_function` +/// includes `Hint::AcceptsSingular`, this function can be used to wrap an `ArrayIter` +/// that contains only one value into an `Iterator` that contains multiple values, +/// facilitating `zip` operations. +/// NOTE: +/// 1. When using this function, be sure to ensure that the corresponding `Hint` for +/// `array_iter` must be `Hint::AcceptsSingular`. +/// 2. You cannot call this function on all `args` of `inner` at the same time; there +/// is a risk of never being able to exit the iteration! +pub(super) fn adaptive_array_iter<'a, T>( + mut array_iter: ArrayIter, +) -> impl Iterator> + 'a +where + T: ArrayAccessor + 'a, + T::Item: Copy, +{ + if array_iter.len() == 1 { + let value = array_iter.next().expect("Contains a value"); + Either::Left(std::iter::repeat(value).into_iter()) + } else { + Either::Right(array_iter.into_iter()) + } } #[cfg(test)] From d1a17ba465c8bdaa7ca758d0599f29106cce83dc Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Fri, 12 Apr 2024 15:33:48 +0800 Subject: [PATCH 3/4] chore: pass clippy --- datafusion/functions/src/utils.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 2e8f22d49ad4..2431d77c564d 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -124,8 +124,8 @@ where /// NOTE: /// 1. When using this function, be sure to ensure that the corresponding `Hint` for /// `array_iter` must be `Hint::AcceptsSingular`. -/// 2. You cannot call this function on all `args` of `inner` at the same time; there -/// is a risk of never being able to exit the iteration! +/// 2. You cannot call this function on all `args` of `inner` of `make_scalar_function` +/// at the same time; there is a risk of never being able to exit the iteration! pub(super) fn adaptive_array_iter<'a, T>( mut array_iter: ArrayIter, ) -> impl Iterator> + 'a @@ -135,9 +135,9 @@ where { if array_iter.len() == 1 { let value = array_iter.next().expect("Contains a value"); - Either::Left(std::iter::repeat(value).into_iter()) + Either::Left(std::iter::repeat(value)) } else { - Either::Right(array_iter.into_iter()) + Either::Right(array_iter) } } From ad1a80d04272c8215821dd9897ffe8050557db2c Mon Sep 17 00:00:00 2001 From: JasonLi-cn Date: Sun, 21 Apr 2024 16:13:12 +0800 Subject: [PATCH 4/4] feat: update overlay --- datafusion/functions/Cargo.toml | 1 + datafusion/functions/src/macros.rs | 234 +++++++++++++++++++++ datafusion/functions/src/string/overlay.rs | 169 ++++++++------- datafusion/functions/src/utils.rs | 76 ++++--- 4 files changed, 379 insertions(+), 101 deletions(-) diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index 48de37a0084b..e74f116f02ef 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -63,6 +63,7 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } +bitflags = "2.5.0" base64 = { version = "0.22", optional = true } blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs index 5ee47bd3e8eb..63d2dc936cd0 100644 --- a/datafusion/functions/src/macros.rs +++ b/datafusion/functions/src/macros.rs @@ -396,3 +396,237 @@ macro_rules! make_function_inputs2 { .collect::<$ARRAY_TYPE1>() }}; } + +macro_rules! array_iter { + ($ARRAY:expr) => {{ + $ARRAY.iter() + }}; +} + +macro_rules! scalar_iter { + ($ARRAY:expr) => {{ + let value = if $ARRAY.is_null(0) { + None + } else { + Some($ARRAY.value(0)) + }; + std::iter::repeat(value) + }}; +} + +macro_rules! make_function_3args { + (ScalarFlags::None, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr) => { + $FUNC(array_iter!($ARG0), array_iter!($ARG1), array_iter!($ARG2)) + }; + (ScalarFlags::A, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr) => { + $FUNC(scalar_iter!($ARG0), array_iter!($ARG1), array_iter!($ARG2)) + }; + (ScalarFlags::B, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr) => { + $FUNC(array_iter!($ARG0), scalar_iter!($ARG1), array_iter!($ARG2)) + }; + (ScalarFlags::C, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr) => { + $FUNC(array_iter!($ARG0), array_iter!($ARG1), scalar_iter!($ARG2)) + }; + (ScalarFlags::AB, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr) => { + $FUNC(scalar_iter!($ARG0), scalar_iter!($ARG1), array_iter!($ARG2)) + }; + (ScalarFlags::AC, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr) => { + $FUNC(scalar_iter!($ARG0), array_iter!($ARG1), scalar_iter!($ARG2)) + }; + (ScalarFlags::BC, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr) => { + $FUNC(array_iter!($ARG0), scalar_iter!($ARG1), scalar_iter!($ARG2)) + }; + (ScalarFlags::ABC, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr) => { + $FUNC(array_iter!($ARG0), array_iter!($ARG1), array_iter!($ARG2)) + }; +} + +macro_rules! make_function_4args { + (ScalarFlags::None, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + array_iter!($ARG0), + array_iter!($ARG1), + array_iter!($ARG2), + array_iter!($ARG3), + ) + }; + (ScalarFlags::A, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + scalar_iter!($ARG0), + array_iter!($ARG1), + array_iter!($ARG2), + array_iter!($ARG3), + ) + }; + (ScalarFlags::B, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + array_iter!($ARG0), + scalar_iter!($ARG1), + array_iter!($ARG2), + array_iter!($ARG3), + ) + }; + (ScalarFlags::C, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + array_iter!($ARG0), + array_iter!($ARG1), + scalar_iter!($ARG2), + array_iter!($ARG3), + ) + }; + (ScalarFlags::D, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + array_iter!($ARG0), + array_iter!($ARG1), + array_iter!($ARG2), + scalar_iter!($ARG3), + ) + }; + (ScalarFlags::AB, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + scalar_iter!($ARG0), + scalar_iter!($ARG1), + array_iter!($ARG2), + array_iter!($ARG3), + ) + }; + (ScalarFlags::AC, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + scalar_iter!($ARG0), + array_iter!($ARG1), + scalar_iter!($ARG2), + array_iter!($ARG3), + ) + }; + (ScalarFlags::AD, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + scalar_iter!($ARG0), + array_iter!($ARG1), + array_iter!($ARG2), + scalar_iter!($ARG3), + ) + }; + (ScalarFlags::BC, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + array_iter!($ARG0), + scalar_iter!($ARG1), + scalar_iter!($ARG2), + array_iter!($ARG3), + ) + }; + (ScalarFlags::BD, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + array_iter!($ARG0), + scalar_iter!($ARG1), + array_iter!($ARG2), + scalar_iter!($ARG3), + ) + }; + (ScalarFlags::CD, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + array_iter!($ARG0), + array_iter!($ARG1), + scalar_iter!($ARG2), + scalar_iter!($ARG3), + ) + }; + (ScalarFlags::ABC, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + scalar_iter!($ARG0), + scalar_iter!($ARG1), + scalar_iter!($ARG2), + array_iter!($ARG3), + ) + }; + (ScalarFlags::ABD, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + scalar_iter!($ARG0), + scalar_iter!($ARG1), + array_iter!($ARG2), + scalar_iter!($ARG3), + ) + }; + (ScalarFlags::ACD, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + scalar_iter!($ARG0), + array_iter!($ARG1), + scalar_iter!($ARG2), + scalar_iter!($ARG3), + ) + }; + (ScalarFlags::BCD, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + $FUNC( + array_iter!($ARG0), + scalar_iter!($ARG1), + scalar_iter!($ARG2), + scalar_iter!($ARG3), + ) + }; + (ScalarFlags::ABCD, $FUNC:ident, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + // all args use array_iter + $FUNC( + array_iter!($ARG0), + array_iter!($ARG1), + array_iter!($ARG2), + array_iter!($ARG3), + ) + }; +} + +macro_rules! invoke_function_impl { + ($FLAG:expr, $FUNC:expr, $ARG0:expr, $ARG1:expr, $ARG2:expr, [$($FLAG_ITEM:tt),+]) => { + match $FLAG { + $(ScalarFlags::$FLAG_ITEM => { + let func = $FUNC; + make_function_3args!( + ScalarFlags::$FLAG_ITEM, + func, + $ARG0, + $ARG1, + $ARG2 + ) + }),+ + _ => unreachable!("{:?}", $FLAG), + } + }; + ($FLAG:expr, $FUNC:expr, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr, [$($FLAG_ITEM:tt),+]) => { + match $FLAG { + $(ScalarFlags::$FLAG_ITEM => { + let func = $FUNC; + make_function_4args!( + ScalarFlags::$FLAG_ITEM, + func, + $ARG0, + $ARG1, + $ARG2, + $ARG3 + ) + }),+ + _ => unreachable!("{:?}", $FLAG), + } + }; +} + +macro_rules! invoke_function { + ($FLAG:expr, $FUNC:expr, $ARG0:expr, $ARG1:expr, $ARG2:expr) => { + invoke_function_impl!( + $FLAG, + $FUNC, + $ARG0, + $ARG1, + $ARG2, + [None, A, B, C, AB, AC, BC, ABC] + ) + }; + ($FLAG:expr, $FUNC:expr, $ARG0:expr, $ARG1:expr, $ARG2:expr, $ARG3:expr) => { + invoke_function_impl!( + $FLAG, + $FUNC, + $ARG0, + $ARG1, + $ARG2, + $ARG3, + [None, A, B, C, D, AB, AC, AD, BC, BD, CD, ABC, ABD, ACD, BCD, ABCD] + ) + }; +} diff --git a/datafusion/functions/src/string/overlay.rs b/datafusion/functions/src/string/overlay.rs index 65233ebb34e2..13280f665e6f 100644 --- a/datafusion/functions/src/string/overlay.rs +++ b/datafusion/functions/src/string/overlay.rs @@ -18,7 +18,7 @@ use std::any::Any; use std::sync::Arc; -use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow::array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; use arrow::datatypes::DataType; use datafusion_common::cast::{as_generic_string_array, as_int64_array}; @@ -28,7 +28,7 @@ use datafusion_expr::{ColumnarValue, Volatility}; use datafusion_expr::{ScalarUDFImpl, Signature}; use datafusion_physical_expr::functions::Hint; -use crate::utils::{adaptive_array_iter, make_scalar_function, utf8_to_str_type}; +use crate::utils::{make_scalar_function, utf8_to_str_type, ScalarFlags}; #[derive(Debug)] pub struct OverlayFunc { @@ -94,6 +94,80 @@ impl ScalarUDFImpl for OverlayFunc { } } +pub fn overlay3<'a, T: OffsetSizeTrait>( + string_array_iter: impl Iterator>, + characters_array_iter: impl Iterator>, + pos_num_iter: impl Iterator>, +) -> Result { + let result = string_array_iter + .zip(characters_array_iter) + .zip(pos_num_iter) + .map(|((string, characters), start_pos)| { + match (string, characters, start_pos) { + (Some(string), Some(characters), Some(start_pos)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = characters_len as i64; + let mut res = String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) +} + +pub fn overlay4<'a, T: OffsetSizeTrait>( + string_array_iter: impl Iterator>, + characters_array_iter: impl Iterator>, + pos_num_iter: impl Iterator>, + len_num_iter: impl Iterator>, +) -> Result { + let result = string_array_iter + .zip(characters_array_iter) + .zip(pos_num_iter) + .zip(len_num_iter) + .map(|(((string, characters), start_pos), len)| { + match (string, characters, start_pos, len) { + (Some(string), Some(characters), Some(start_pos), Some(len)) => { + let string_len = string.chars().count(); + let characters_len = characters.chars().count(); + let replace_len = len.min(string_len as i64); + let mut res = String::with_capacity(string_len.max(characters_len)); + + //as sql replace index start from 1 while string index start from 0 + if start_pos > 1 && start_pos - 1 < string_len as i64 { + let start = (start_pos - 1) as usize; + res.push_str(&string[..start]); + } + res.push_str(characters); + // if start + replace_len - 1 >= string_length, just to string end + if start_pos + replace_len - 1 < string_len as i64 { + let end = (start_pos + replace_len - 1) as usize; + res.push_str(&string[end..]); + } + Ok(Some(res)) + } + _ => Ok(None), + } + }) + .collect::>>()?; + Ok(Arc::new(result) as ArrayRef) +} + /// OVERLAY(string1 PLACING string2 FROM integer FOR integer2) /// Replaces a substring of string1 with string2 starting at the integer bit /// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas @@ -104,84 +178,31 @@ pub fn overlay(args: &[ArrayRef]) -> Result { let string_array = as_generic_string_array::(&args[0])?; let characters_array = as_generic_string_array::(&args[1])?; let pos_num = as_int64_array(&args[2])?; - - let characters_array_iter = adaptive_array_iter(characters_array.iter()); - let pos_num_iter = adaptive_array_iter(pos_num.iter()); - - let result = string_array - .iter() - .zip(characters_array_iter) - .zip(pos_num_iter) - .map(|((string, characters), start_pos)| { - match (string, characters, start_pos) { - (Some(string), Some(characters), Some(start_pos)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = characters_len as i64; - let mut res = - String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) + let flag = ScalarFlags::try_create(args)?; + + invoke_function!( + flag, + |arg0, arg1, arg2| overlay3::(arg0, arg1, arg2), + string_array, + characters_array, + pos_num + ) } 4 => { let string_array = as_generic_string_array::(&args[0])?; let characters_array = as_generic_string_array::(&args[1])?; let pos_num = as_int64_array(&args[2])?; let len_num = as_int64_array(&args[3])?; - - let characters_array_iter = adaptive_array_iter(characters_array.iter()); - let pos_num_iter = adaptive_array_iter(pos_num.iter()); - let len_num_iter = adaptive_array_iter(len_num.iter()); - - let result = string_array - .iter() - .zip(characters_array_iter) - .zip(pos_num_iter) - .zip(len_num_iter) - .map(|(((string, characters), start_pos), len)| { - match (string, characters, start_pos, len) { - (Some(string), Some(characters), Some(start_pos), Some(len)) => { - let string_len = string.chars().count(); - let characters_len = characters.chars().count(); - let replace_len = len.min(string_len as i64); - let mut res = - String::with_capacity(string_len.max(characters_len)); - - //as sql replace index start from 1 while string index start from 0 - if start_pos > 1 && start_pos - 1 < string_len as i64 { - let start = (start_pos - 1) as usize; - res.push_str(&string[..start]); - } - res.push_str(characters); - // if start + replace_len - 1 >= string_length, just to string end - if start_pos + replace_len - 1 < string_len as i64 { - let end = (start_pos + replace_len - 1) as usize; - res.push_str(&string[end..]); - } - Ok(Some(res)) - } - _ => Ok(None), - } - }) - .collect::>>()?; - Ok(Arc::new(result) as ArrayRef) + let flag = ScalarFlags::try_create(args)?; + + invoke_function!( + flag, + |arg0, arg1, arg2, arg3| overlay4::(arg0, arg1, arg2, arg3), + string_array, + characters_array, + pos_num, + len_num + ) } other => { exec_err!("overlay was called with {other} arguments. It requires 3 or 4.") diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 2431d77c564d..fcdfecde23e8 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ArrayAccessor, ArrayIter, ArrayRef}; +use arrow::array::ArrayRef; use arrow::datatypes::DataType; -use datafusion_common::{Result, ScalarValue}; +use bitflags::bitflags; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; use datafusion_physical_expr::functions::Hint; -use itertools::Either; use std::sync::Arc; /// Creates a function to identify the optimal return type of a string function given @@ -117,30 +117,6 @@ where }) } -/// Create an adaptive iterator. When the `hints` args of `make_scalar_function` -/// includes `Hint::AcceptsSingular`, this function can be used to wrap an `ArrayIter` -/// that contains only one value into an `Iterator` that contains multiple values, -/// facilitating `zip` operations. -/// NOTE: -/// 1. When using this function, be sure to ensure that the corresponding `Hint` for -/// `array_iter` must be `Hint::AcceptsSingular`. -/// 2. You cannot call this function on all `args` of `inner` of `make_scalar_function` -/// at the same time; there is a risk of never being able to exit the iteration! -pub(super) fn adaptive_array_iter<'a, T>( - mut array_iter: ArrayIter, -) -> impl Iterator> + 'a -where - T: ArrayAccessor + 'a, - T::Item: Copy, -{ - if array_iter.len() == 1 { - let value = array_iter.next().expect("Contains a value"); - Either::Left(std::iter::repeat(value)) - } else { - Either::Right(array_iter) - } -} - #[cfg(test)] pub mod test { /// $FUNC ScalarUDFImpl to test @@ -204,3 +180,49 @@ pub mod test { pub(crate) use test_function; } + +bitflags! { + /// Represents the position of the Scalar in args. + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] + pub(crate) struct ScalarFlags: u8 { + // There are no Scalars in the args + const None = 0b00000000; + + // Represents the 1th arg is Scalar + const A = 0b00000001; + // Represents the 2th arg is Scalar + const B = 0b00000010; + // Represents the 3th arg is Scalar + const C = 0b00000100; + // Represents the 4th arg is Scalar + const D = 0b00001000; + + const AB = Self::A.bits() | Self::B.bits(); + const AC = Self::A.bits() | Self::C.bits(); + const AD = Self::A.bits() | Self::D.bits(); + const BC = Self::B.bits() | Self::C.bits(); + const BD = Self::B.bits() | Self::D.bits(); + const CD = Self::C.bits() | Self::D.bits(); + + const ABC = Self::A.bits() | Self::B.bits() | Self::C.bits(); + const ABD = Self::A.bits() | Self::B.bits() | Self::D.bits(); + const ACD = Self::A.bits() | Self::C.bits() | Self::D.bits(); + const BCD = Self::B.bits() | Self::C.bits() | Self::D.bits(); + + const ABCD = Self::A.bits() | Self::B.bits() | Self::C.bits() | Self::D.bits(); + } +} + +impl ScalarFlags { + pub fn try_create(args: &[ArrayRef]) -> Result { + let mut flag: u8 = 0; + args.iter().enumerate().for_each(|(i, arg)| { + if arg.len() == 1 { + flag |= 1 << i; + } + }); + Self::from_bits(flag).ok_or_else(|| { + DataFusionError::Execution(format!("Unsupported ScalarFlags: {}", flag)) + }) + } +}