Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,7 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
(Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
string_coercion(lhs_value_type, rhs_value_type).or(None)
}
(Binary, Binary) => Some(Utf8),
_ => None,
})
}
Expand Down
118 changes: 84 additions & 34 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ use crate::string::concat;
use crate::strings::{
ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
};
use datafusion_common::cast::{as_string_array, as_string_view_array};
use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
use datafusion_common::cast::{as_binary_array, as_string_array, as_string_view_array};
use datafusion_common::{
Result, ScalarValue, exec_datafusion_err, internal_err, plan_err,
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
Expand Down Expand Up @@ -68,13 +70,24 @@ impl ConcatFunc {
use DataType::*;
Self {
signature: Signature::variadic(
vec![Utf8View, Utf8, LargeUtf8],
vec![Utf8View, Utf8, LargeUtf8, Binary],
Volatility::Immutable,
),
}
}
}

fn deduce_return_type(arg_types: &[DataType]) -> DataType {
use DataType::*;
if arg_types.contains(&Utf8View) {
Utf8View
} else if arg_types.contains(&LargeUtf8) {
LargeUtf8
} else {
Utf8
}
}

impl ScalarUDFImpl for ConcatFunc {
fn as_any(&self) -> &dyn Any {
self
Expand All @@ -92,29 +105,16 @@ impl ScalarUDFImpl for ConcatFunc {
/// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid
/// potential overflow on LargeUtf8 input.
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;
if arg_types.contains(&Utf8View) {
Ok(Utf8View)
} else if arg_types.contains(&LargeUtf8) {
Ok(LargeUtf8)
} else {
Ok(Utf8)
}
Ok(deduce_return_type(arg_types))
}

/// Concatenates the text representations of all the arguments. NULL arguments are ignored.
/// concat('abcde', 2, NULL, 22) = 'abcde222'
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let ScalarFunctionArgs { args, .. } = args;

let return_datatype = if args.iter().any(|c| c.data_type() == DataType::Utf8View)
{
DataType::Utf8View
} else if args.iter().any(|c| c.data_type() == DataType::LargeUtf8) {
DataType::LargeUtf8
} else {
DataType::Utf8
};
let arg_types: Vec<DataType> = args.iter().map(|c| c.data_type()).collect();
let return_datatype = deduce_return_type(&arg_types);

let array_len = args.iter().find_map(|x| match x {
ColumnarValue::Array(array) => Some(array.len()),
Expand All @@ -123,22 +123,28 @@ impl ScalarUDFImpl for ConcatFunc {

// Scalar
if array_len.is_none() {
let mut values = Vec::with_capacity(args.len());
let mut values: Vec<&[u8]> = Vec::with_capacity(args.len());
for arg in &args {
let ColumnarValue::Scalar(scalar) = arg else {
return internal_err!("concat expected scalar value, got {arg:?}");
};

match scalar.try_as_str() {
Some(Some(v)) => values.push(v),
Some(None) => {} // null literal
None => plan_err!(
"Concat function does not support scalar type {}",
scalar
)?,
if let ScalarValue::Binary(Some(value)) = scalar {
values.push(value);
} else {
match scalar.try_as_str() {
Some(Some(v)) => values.push(v.as_bytes()),
Some(None) => {} // null literal
None => plan_err!(
"Concat function does not support scalar type {}",
scalar
)?,
}
}
}
let result = values.concat();
let concat_bytes = values.concat();
let result = std::str::from_utf8(&concat_bytes)
.map_err(|_| exec_datafusion_err!("invalid UTF-8 in binary literal"))?
.to_string();

return match return_datatype {
DataType::Utf8View => {
Expand Down Expand Up @@ -171,6 +177,13 @@ impl ScalarUDFImpl for ConcatFunc {
columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
}
}
ColumnarValue::Scalar(ScalarValue::Binary(maybe_value)) => {
if let Some(b) = maybe_value {
// data_size is only a capacity hint, so it doesn't matter whether it counts chars or bytes
data_size += b.len() * len;
columns.push(ColumnarValueRef::Scalar(b.as_slice()));
}
}
ColumnarValue::Array(array) => {
match array.data_type() {
DataType::Utf8 => {
Expand Down Expand Up @@ -210,6 +223,17 @@ impl ScalarUDFImpl for ConcatFunc {
};
columns.push(column);
}
DataType::Binary => {
let string_array = as_binary_array(array)?;

data_size += string_array.values().len();
let column = if array.is_nullable() {
ColumnarValueRef::NullableBinaryArray(string_array)
} else {
ColumnarValueRef::NonNullableBinaryArray(string_array)
};
columns.push(column);
}
other => {
return plan_err!(
"Input was {other} which is not a supported datatype for concat function"
Expand All @@ -231,7 +255,7 @@ impl ScalarUDFImpl for ConcatFunc {
builder.append_offset();
}

let string_array = builder.finish(None);
let string_array = builder.finish(None)?;
Ok(ColumnarValue::Array(Arc::new(string_array)))
}
DataType::Utf8View => {
Expand All @@ -240,10 +264,10 @@ impl ScalarUDFImpl for ConcatFunc {
columns
.iter()
.for_each(|column| builder.write::<true>(column, i));
builder.append_offset();
builder.append_offset()?;
}

let string_array = builder.finish(None);
let string_array = builder.finish(None)?;
Ok(ColumnarValue::Array(Arc::new(string_array)))
}
DataType::LargeUtf8 => {
Expand All @@ -255,7 +279,7 @@ impl ScalarUDFImpl for ConcatFunc {
builder.append_offset();
}

let string_array = builder.finish(None);
let string_array = builder.finish(None)?;
Ok(ColumnarValue::Array(Arc::new(string_array)))
}
_ => unreachable!(),
Expand Down Expand Up @@ -451,7 +475,33 @@ mod tests {
Utf8View,
StringViewArray
);

test_function!(
ConcatFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Binary(Some(
"Café".as_bytes().into()
))),
ColumnarValue::Scalar(ScalarValue::Utf8(None)),
ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
],
Ok(Some("Cafécc")),
&str,
Utf8,
StringArray
);
test_function!(
ConcatFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Binary(Some(Vec::from(
"Café".as_bytes()
)))),
ColumnarValue::Scalar(ScalarValue::Binary(Some("cc".as_bytes().into()))),
],
Ok(Some("Cafécc")),
&str,
Utf8,
StringArray
);
Ok(())
}

Expand Down
10 changes: 5 additions & 5 deletions datafusion/functions/src/string/concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
for i in 0..len {
if !sep.is_valid(i) {
builder.append_offset();
builder.append_offset()?;
continue;
}
let mut first = true;
Expand All @@ -332,9 +332,9 @@ impl ScalarUDFImpl for ConcatWsFunc {
first = false;
}
}
builder.append_offset();
builder.append_offset()?;
}
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?)))
}
DataType::LargeUtf8 => {
let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
Expand All @@ -355,7 +355,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
}
builder.append_offset();
}
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?)))
}
_ => {
let mut builder = StringArrayBuilder::with_capacity(len, data_size);
Expand All @@ -376,7 +376,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
}
builder.append_offset();
}
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?)))
}
}
}
Expand Down
Loading
Loading