Skip to content
48 changes: 38 additions & 10 deletions core/src/execution/datafusion/expressions/scalar_funcs/chr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::{
};

use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
use datafusion_common::{cast::as_int64_array, exec_err, DataFusionError, Result};
use datafusion_common::{cast::as_int64_array, exec_err, DataFusionError, Result, ScalarValue};

/// Returns the ASCII character having the binary equivalent to the input expression.
/// E.g., chr(65) = 'A'.
Expand Down Expand Up @@ -94,15 +94,43 @@ impl ScalarUDFImpl for ChrFunc {
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let array = args[0].clone();
match array {
ColumnarValue::Array(array) => {
let array = chr(&[array])?;
Ok(ColumnarValue::Array(array))
}
_ => {
exec_err!("The first argument must be an array, but got: {:?}", array)
}
make_scalar_function(chr)(args)
}
}

/// The make_scalar_function function is a higher-order function that:
/// - Takes a function inner designed to operate on arrays.
/// - Wraps this function in a closure that can accept a mix of scalar and array inputs.
/// - Converts scalar inputs to arrays, calls the inner function, and then converts the result back to a scalar if the original inputs were all scalars.
///
/// taken from datafusion utils

fn make_scalar_function<F>(inner: F) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
where
F: Fn(&[ArrayRef]) -> Result<ArrayRef> + Sync + Send + 'static,
{
move |args: &[ColumnarValue]| {
// first, identify if any of the arguments is an Array. If yes, store its `len`,
// as any scalar will need to be converted to an array of len `len`.
let len = args
.iter()
.fold(Option::<usize>::None, |acc, arg| match arg {
ColumnarValue::Scalar(_) => acc,
ColumnarValue::Array(a) => Some(a.len()),
});

let is_scalar = len.is_none();

let args = ColumnarValue::values_to_arrays(args)?;

let result = (inner)(&args);

if is_scalar {
// If all inputs are scalar, keeps output as scalar
let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
result.map(ColumnarValue::Scalar)
} else {
result.map(ColumnarValue::Array)
}
}
}