Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,7 @@ fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
(Dictionary(_, lhs_value_type), Dictionary(_, rhs_value_type)) => {
string_coercion(lhs_value_type, rhs_value_type).or(None)
}
(Binary, Binary) => Some(Utf8),
_ => None,
})
}
Expand Down
118 changes: 84 additions & 34 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ use crate::string::concat;
use crate::strings::{
ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
};
use datafusion_common::cast::{as_string_array, as_string_view_array};
use datafusion_common::{Result, ScalarValue, internal_err, plan_err};
use datafusion_common::cast::{as_binary_array, as_string_array, as_string_view_array};
use datafusion_common::{
Result, ScalarValue, exec_datafusion_err, internal_err, plan_err,
};
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyContext};
use datafusion_expr::{ColumnarValue, Documentation, Expr, Volatility, lit};
Expand Down Expand Up @@ -68,13 +70,24 @@ impl ConcatFunc {
use DataType::*;
Self {
signature: Signature::variadic(
vec![Utf8View, Utf8, LargeUtf8],
vec![Utf8View, Utf8, LargeUtf8, Binary],
Volatility::Immutable,
),
}
}
}

fn deduce_return_type(arg_types: &[DataType]) -> DataType {
use DataType::*;
if arg_types.contains(&Utf8View) {
Utf8View
} else if arg_types.contains(&LargeUtf8) {
LargeUtf8
} else {
Utf8
}
}

impl ScalarUDFImpl for ConcatFunc {
fn as_any(&self) -> &dyn Any {
self
Expand All @@ -92,29 +105,16 @@ impl ScalarUDFImpl for ConcatFunc {
/// mixed inputs, prefer Utf8View; prefer LargeUtf8 over Utf8 to avoid
/// potential overflow on LargeUtf8 input.
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
use DataType::*;
if arg_types.contains(&Utf8View) {
Ok(Utf8View)
} else if arg_types.contains(&LargeUtf8) {
Ok(LargeUtf8)
} else {
Ok(Utf8)
}
Ok(deduce_return_type(arg_types))
}

/// Concatenates the text representations of all the arguments. NULL arguments are ignored.
/// concat('abcde', 2, NULL, 22) = 'abcde222'
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let ScalarFunctionArgs { args, .. } = args;

let return_datatype = if args.iter().any(|c| c.data_type() == DataType::Utf8View)
{
DataType::Utf8View
} else if args.iter().any(|c| c.data_type() == DataType::LargeUtf8) {
DataType::LargeUtf8
} else {
DataType::Utf8
};
let arg_types: Vec<DataType> = args.iter().map(|c| c.data_type()).collect();
let return_datatype = deduce_return_type(&arg_types);

let array_len = args.iter().find_map(|x| match x {
ColumnarValue::Array(array) => Some(array.len()),
Expand All @@ -123,22 +123,28 @@ impl ScalarUDFImpl for ConcatFunc {

// Scalar
if array_len.is_none() {
let mut values = Vec::with_capacity(args.len());
let mut values: Vec<&[u8]> = Vec::with_capacity(args.len());
for arg in &args {
let ColumnarValue::Scalar(scalar) = arg else {
return internal_err!("concat expected scalar value, got {arg:?}");
};

match scalar.try_as_str() {
Some(Some(v)) => values.push(v),
Some(None) => {} // null literal
None => plan_err!(
"Concat function does not support scalar type {}",
scalar
)?,
if let ScalarValue::Binary(Some(value)) = scalar {
values.push(value);
} else {
match scalar.try_as_str() {
Some(Some(v)) => values.push(v.as_bytes()),
Some(None) => {} // null literal
None => plan_err!(
"Concat function does not support scalar type {}",
scalar
)?,
}
}
}
let result = values.concat();
let concat_bytes = values.concat();
let result = std::str::from_utf8(&concat_bytes)
.map_err(|_| exec_datafusion_err!("invalid UTF-8 in binary literal"))?
.to_string();

return match return_datatype {
DataType::Utf8View => {
Expand Down Expand Up @@ -171,6 +177,13 @@ impl ScalarUDFImpl for ConcatFunc {
columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
}
}
ColumnarValue::Scalar(ScalarValue::Binary(maybe_value)) => {
if let Some(b) = maybe_value {
// data_size is only a capacity hint, so it does not matter whether it counts chars or bytes
data_size += b.len() * len;
columns.push(ColumnarValueRef::Scalar(b.as_slice()));
}
}
ColumnarValue::Array(array) => {
match array.data_type() {
DataType::Utf8 => {
Expand Down Expand Up @@ -210,6 +223,17 @@ impl ScalarUDFImpl for ConcatFunc {
};
columns.push(column);
}
DataType::Binary => {
let string_array = as_binary_array(array)?;

data_size += string_array.values().len();
let column = if array.is_nullable() {
ColumnarValueRef::NullableBinaryArray(string_array)
} else {
ColumnarValueRef::NonNullableBinaryArray(string_array)
};
columns.push(column);
}
other => {
return plan_err!(
"Input was {other} which is not a supported datatype for concat function"
Expand All @@ -231,7 +255,7 @@ impl ScalarUDFImpl for ConcatFunc {
builder.append_offset();
}

let string_array = builder.finish(None);
let string_array = builder.finish(None)?;
Ok(ColumnarValue::Array(Arc::new(string_array)))
}
DataType::Utf8View => {
Expand All @@ -240,10 +264,10 @@ impl ScalarUDFImpl for ConcatFunc {
columns
.iter()
.for_each(|column| builder.write::<true>(column, i));
builder.append_offset();
builder.append_offset()?;
}

let string_array = builder.finish(None);
let string_array = builder.finish(None)?;
Ok(ColumnarValue::Array(Arc::new(string_array)))
}
DataType::LargeUtf8 => {
Expand All @@ -255,7 +279,7 @@ impl ScalarUDFImpl for ConcatFunc {
builder.append_offset();
}

let string_array = builder.finish(None);
let string_array = builder.finish(None)?;
Ok(ColumnarValue::Array(Arc::new(string_array)))
}
_ => unreachable!(),
Expand Down Expand Up @@ -451,7 +475,33 @@ mod tests {
Utf8View,
StringViewArray
);

test_function!(
ConcatFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Binary(Some(
"Café".as_bytes().into()
))),
ColumnarValue::Scalar(ScalarValue::Utf8(None)),
ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
],
Ok(Some("Cafécc")),
&str,
Utf8,
StringArray
);
test_function!(
ConcatFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::Binary(Some(Vec::from(
"Café".as_bytes()
)))),
ColumnarValue::Scalar(ScalarValue::Binary(Some("cc".as_bytes().into()))),
],
Ok(Some("Cafécc")),
&str,
Utf8,
StringArray
);
Ok(())
}

Expand Down
10 changes: 5 additions & 5 deletions datafusion/functions/src/string/concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
for i in 0..len {
if !sep.is_valid(i) {
builder.append_offset();
builder.append_offset()?;
continue;
}
let mut first = true;
Expand All @@ -332,9 +332,9 @@ impl ScalarUDFImpl for ConcatWsFunc {
first = false;
}
}
builder.append_offset();
builder.append_offset()?;
}
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?)))
}
DataType::LargeUtf8 => {
let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
Expand All @@ -355,7 +355,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
}
builder.append_offset();
}
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?)))
}
_ => {
let mut builder = StringArrayBuilder::with_capacity(len, data_size);
Expand All @@ -376,7 +376,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
}
builder.append_offset();
}
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls()))))
Ok(ColumnarValue::Array(Arc::new(builder.finish(sep.nulls())?)))
}
}
}
Expand Down
Loading
Loading