diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 24b831e7c575..f9feb2c9a114 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -76,6 +76,7 @@ sha2 = { version = "^0.10.1", optional = true } unicode-segmentation = { version = "^1.7.1", optional = true } [dev-dependencies] +arrow = { workspace = true, features = ["test_utils"] } criterion = "0.5" rand = { workspace = true } rstest = { workspace = true } @@ -84,3 +85,7 @@ tokio = { workspace = true, features = ["rt-multi-thread"] } [[bench]] harness = false name = "in_list" + +[[bench]] +harness = false +name = "concat" diff --git a/datafusion/physical-expr/benches/concat.rs b/datafusion/physical-expr/benches/concat.rs new file mode 100644 index 000000000000..cdd54d767f1f --- /dev/null +++ b/datafusion/physical-expr/benches/concat.rs @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::util::bench_util::create_string_array_with_len; +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_common::ScalarValue; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr::string_expressions::concat; +use std::sync::Arc; + +fn create_args(size: usize, str_len: usize) -> Vec { + let array = Arc::new(create_string_array_with_len::(size, 0.2, str_len)); + let scalar = ScalarValue::Utf8(Some(", ".to_string())); + vec![ + ColumnarValue::Array(array.clone()), + ColumnarValue::Scalar(scalar), + ColumnarValue::Array(array), + ] +} + +fn criterion_benchmark(c: &mut Criterion) { + for size in [1024, 4096, 8192] { + let args = create_args(size, 32); + let mut group = c.benchmark_group("concat function"); + group.bench_function(BenchmarkId::new("concat", size), |b| { + b.iter(|| criterion::black_box(concat(&args).unwrap())) + }); + group.finish(); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 515511b15fbb..f609a6e9f01c 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -253,9 +253,9 @@ pub fn create_physical_fun( // string functions BuiltinScalarFunction::Coalesce => Arc::new(conditional_expressions::coalesce), BuiltinScalarFunction::Concat => Arc::new(string_expressions::concat), - BuiltinScalarFunction::ConcatWithSeparator => Arc::new(|args| { - make_scalar_function_inner(string_expressions::concat_ws)(args) - }), + BuiltinScalarFunction::ConcatWithSeparator => { + Arc::new(string_expressions::concat_ws) + } BuiltinScalarFunction::InitCap => Arc::new(|args| match args[0].data_type() { DataType::Utf8 => { make_scalar_function_inner(string_expressions::initcap::)(args) diff --git a/datafusion/physical-expr/src/string_expressions.rs b/datafusion/physical-expr/src/string_expressions.rs index 2185b7c5b4a1..fd6c8eb6b1d9 100644 --- a/datafusion/physical-expr/src/string_expressions.rs +++ b/datafusion/physical-expr/src/string_expressions.rs @@ -23,6 +23,7 @@ use std::sync::Arc; +use arrow::array::ArrayDataBuilder; use arrow::{ array::{ Array, ArrayRef, GenericStringArray, Int32Array, Int64Array, OffsetSizeTrait, @@ -30,6 +31,7 @@ use arrow::{ }, datatypes::DataType, }; +use arrow_buffer::{MutableBuffer, NullBuffer}; use datafusion_common::Result; use datafusion_common::{ @@ -38,75 +40,153 @@ use datafusion_common::{ }; use datafusion_expr::ColumnarValue; +enum ColumnarValueRef<'a> { + Scalar(&'a [u8]), + NullableArray(&'a StringArray), + NonNullableArray(&'a StringArray), +} + +impl<'a> ColumnarValueRef<'a> { + #[inline] + fn is_valid(&self, i: usize) -> bool { + match &self { + Self::Scalar(_) | Self::NonNullableArray(_) => true, + Self::NullableArray(array) => array.is_valid(i), + } + } + + #[inline] + fn nulls(&self) -> Option { + match &self { + Self::Scalar(_) | Self::NonNullableArray(_) => None, + Self::NullableArray(array) => array.nulls().cloned(), + } + } +} + +/// Optimized version of the StringBuilder in Arrow that: +/// 1. Precalculating the expected length of the result, avoiding reallocations. +/// 2. Avoids creating / incrementally creating a `NullBufferBuilder` +struct StringArrayBuilder { + offsets_buffer: MutableBuffer, + value_buffer: MutableBuffer, +} + +impl StringArrayBuilder { + fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { + let mut offsets_buffer = MutableBuffer::with_capacity( + (item_capacity + 1) * std::mem::size_of::(), + ); + // SAFETY: the first offset value is definitely not going to exceed the bounds. + unsafe { offsets_buffer.push_unchecked(0_i32) }; + Self { + offsets_buffer, + value_buffer: MutableBuffer::with_capacity(data_capacity), + } + } + + fn write(&mut self, column: &ColumnarValueRef, i: usize) { + match column { + ColumnarValueRef::Scalar(s) => { + self.value_buffer.extend_from_slice(s); + } + ColumnarValueRef::NullableArray(array) => { + if !CHECK_VALID || array.is_valid(i) { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + ColumnarValueRef::NonNullableArray(array) => { + self.value_buffer + .extend_from_slice(array.value(i).as_bytes()); + } + } + } + + fn append_offset(&mut self) { + let next_offset: i32 = self + .value_buffer + .len() + .try_into() + .expect("byte array offset overflow"); + unsafe { self.offsets_buffer.push_unchecked(next_offset) }; + } + + fn finish(self, null_buffer: Option) -> StringArray { + let array_builder = ArrayDataBuilder::new(DataType::Utf8) + .len(self.offsets_buffer.len() / std::mem::size_of::() - 1) + .add_buffer(self.offsets_buffer.into()) + .add_buffer(self.value_buffer.into()) + .nulls(null_buffer); + // SAFETY: all data that was appended was valid UTF8 and the values + // and offsets were created correctly + let array_data = unsafe { array_builder.build_unchecked() }; + StringArray::from(array_data) + } +} + /// Concatenates the text representations of all the arguments. NULL arguments are ignored. /// concat('abcde', 2, NULL, 22) = 'abcde222' pub fn concat(args: &[ColumnarValue]) -> Result { - // do not accept 0 arguments. - if args.is_empty() { - return exec_err!( - "concat was called with {} arguments. It requires at least 1.", - args.len() - ); + let array_len = args + .iter() + .filter_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .next(); + + // Scalar + if array_len.is_none() { + let mut result = String::new(); + for arg in args { + if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg { + result.push_str(v); + } + } + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); } - // first, decide whether to return a scalar or a vector. - let mut return_array = args.iter().filter_map(|x| match x { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }); - if let Some(size) = return_array.next() { - let result = (0..size) - .map(|index| { - let mut owned_string: String = "".to_owned(); - for arg in args { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { - if let Some(value) = maybe_value { - owned_string.push_str(value); - } - } - ColumnarValue::Array(v) => { - if v.is_valid(index) { - let v = as_string_array(v).unwrap(); - owned_string.push_str(v.value(index)); - } - } - _ => unreachable!(), - } + // Array + let len = array_len.unwrap(); + let mut data_size = 0; + let mut columns = Vec::with_capacity(args.len()); + + for arg in args { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + if let Some(s) = maybe_value { + data_size += s.len() * len; + columns.push(ColumnarValueRef::Scalar(s.as_bytes())); } - Some(owned_string) - }) - .collect::(); - - Ok(ColumnarValue::Array(Arc::new(result))) - } else { - // short avenue with only scalars - let initial = Some("".to_string()); - let result = args.iter().fold(initial, |mut acc, rhs| { - if let Some(ref mut inner) = acc { - match rhs { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) => { - inner.push_str(v); - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} - _ => unreachable!(""), + } + ColumnarValue::Array(array) => { + let string_array = as_string_array(array)?; + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) }; - }; - acc - }); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + columns.push(column); + } + _ => unreachable!(), + } + } + + let mut builder = StringArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + columns + .iter() + .for_each(|column| builder.write::(column, i)); + builder.append_offset(); } + Ok(ColumnarValue::Array(Arc::new(builder.finish(None)))) } /// Concatenates all but the first argument, with separators. The first argument is used as the separator string, and should not be NULL. Other NULL arguments are ignored. /// concat_ws(',', 'abcde', 2, NULL, 22) = 'abcde,2,22' -pub fn concat_ws(args: &[ArrayRef]) -> Result { - // downcast all arguments to strings - let args = args - .iter() - .map(|e| as_string_array(e)) - .collect::>>()?; - +pub fn concat_ws(args: &[ColumnarValue]) -> Result { // do not accept 0 or 1 arguments. if args.len() < 2 { return exec_err!( @@ -115,28 +195,126 @@ pub fn concat_ws(args: &[ArrayRef]) -> Result { ); } - // first map is the iterator, second is for the `Option<_>` - let result = args[0] + let array_len = args .iter() - .enumerate() - .map(|(index, x)| { - x.map(|sep: &str| { - let string_vec = args[1..] - .iter() - .flat_map(|arg| { - if !arg.is_null(index) { - Some(arg.value(index)) - } else { - None - } - }) - .collect::>(); - string_vec.join(sep) - }) + .filter_map(|x| match x { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, }) - .collect::(); + .next(); - Ok(Arc::new(result) as ArrayRef) + // Scalar + if array_len.is_none() { + let sep = match &args[0] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s, + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); + } + _ => unreachable!(), + }; + + let mut result = String::new(); + let iter = &mut args[1..].iter(); + + for arg in iter.by_ref() { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + result.push_str(s); + break; + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + _ => unreachable!(), + } + } + + for arg in iter.by_ref() { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + result.push_str(sep); + result.push_str(s); + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {} + _ => unreachable!(), + } + } + + return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result)))); + } + + // Array + let len = array_len.unwrap(); + let mut data_size = 0; + + // parse sep + let sep = match &args[0] { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => { + data_size += s.len() * len * (args.len() - 2); // estimate + ColumnarValueRef::Scalar(s.as_bytes()) + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + return Ok(ColumnarValue::Array(Arc::new(StringArray::new_null(len)))); + } + ColumnarValue::Array(array) => { + let string_array = as_string_array(array)?; + data_size += string_array.values().len() * (args.len() - 2); // estimate + if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + } + } + _ => unreachable!(), + }; + + let mut columns = Vec::with_capacity(args.len() - 1); + for arg in &args[1..] { + match arg { + ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value)) => { + if let Some(s) = maybe_value { + data_size += s.len() * len; + columns.push(ColumnarValueRef::Scalar(s.as_bytes())); + } + } + ColumnarValue::Array(array) => { + let string_array = as_string_array(array)?; + data_size += string_array.values().len(); + let column = if array.is_nullable() { + ColumnarValueRef::NullableArray(string_array) + } else { + ColumnarValueRef::NonNullableArray(string_array) + }; + columns.push(column); + } + _ => unreachable!(), + } + } + + let mut builder = StringArrayBuilder::with_capacity(len, data_size); + for i in 0..len { + if !sep.is_valid(i) { + builder.append_offset(); + continue; + } + + let mut iter = columns.iter(); + for column in iter.by_ref() { + if column.is_valid(i) { + builder.write::(column, i); + break; + } + } + + for column in iter { + if column.is_valid(i) { + builder.write::(&sep, i); + builder.write::(column, i); + } + } + + builder.append_offset(); + } + + Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())))) } /// Converts the first letter of each word to upper case and the rest to lower case. Words are sequences of alphanumeric characters separated by non-alphanumeric characters. @@ -234,3 +412,84 @@ pub fn ends_with(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn concat() -> Result<()> { + let c0 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + let args = &[c0, c1, c2]; + + let result = super::concat(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!(), + } + Ok(()) + } + + #[test] + fn concat_ws() -> Result<()> { + // sep is scalar + let c0 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string()))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + None, + Some("z"), + ]))); + let args = &[c0, c1, c2]; + + let result = super::concat_ws(args)?; + let expected = + Arc::new(StringArray::from(vec!["foo,x", "bar", "baz,z"])) as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!(), + } + + // sep is nullable array + let c0 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some(","), + None, + Some("+"), + ]))); + let c1 = + ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"]))); + let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("x"), + Some("y"), + Some("z"), + ]))); + let args = &[c0, c1, c2]; + + let result = super::concat_ws(args)?; + let expected = + Arc::new(StringArray::from(vec![Some("foo,x"), None, Some("baz+z")])) + as ArrayRef; + match &result { + ColumnarValue::Array(array) => { + assert_eq!(&expected, array); + } + _ => panic!(), + } + + Ok(()) + } +}