From 37ce4eb16b430f73b6abc645a3d71632b67e516e Mon Sep 17 00:00:00 2001 From: Cyprien Huet Date: Thu, 8 Jan 2026 18:12:02 +0400 Subject: [PATCH 1/3] feat(spark): implement array_repeat function --- datafusion/spark/src/function/array/mod.rs | 9 +- datafusion/spark/src/function/array/repeat.rs | 128 ++++++++++++++++++ datafusion/spark/src/function/mod.rs | 1 + datafusion/spark/src/function/null_utils.rs | 122 +++++++++++++++++ .../spark/src/function/string/concat.rs | 110 +-------------- .../test_files/spark/array/array_repeat.slt | 50 +++++-- 6 files changed, 306 insertions(+), 114 deletions(-) create mode 100644 datafusion/spark/src/function/array/repeat.rs create mode 100644 datafusion/spark/src/function/null_utils.rs diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs index 01056ba952984..7140653510e09 100644 --- a/datafusion/spark/src/function/array/mod.rs +++ b/datafusion/spark/src/function/array/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+pub mod repeat; pub mod shuffle; pub mod spark_array; @@ -24,6 +25,7 @@ use std::sync::Arc; make_udf_function!(spark_array::SparkArray, array); make_udf_function!(shuffle::SparkShuffle, shuffle); +make_udf_function!(repeat::SparkArrayRepeat, array_repeat); pub mod expr_fn { use datafusion_functions::export_functions; @@ -34,8 +36,13 @@ pub mod expr_fn { "Returns a random permutation of the given array.", args )); + export_functions!(( + array_repeat, + "returns an array containing element count times.", + element count + )); } pub fn functions() -> Vec<Arc<ScalarUDF>> { - vec![array(), shuffle()] + vec![array(), shuffle(), array_repeat()] } diff --git a/datafusion/spark/src/function/array/repeat.rs b/datafusion/spark/src/function/array/repeat.rs new file mode 100644 index 0000000000000..7543300a91078 --- /dev/null +++ b/datafusion/spark/src/function/array/repeat.rs @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::datatypes::{DataType, Field}; +use datafusion_common::utils::take_function_args; +use datafusion_common::{Result, ScalarValue, exec_err}; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use datafusion_functions_nested::repeat::ArrayRepeat; +use std::any::Any; +use std::sync::Arc; + +use crate::function::null_utils::{ + NullMaskResolution, apply_null_mask, compute_null_mask, +}; + +/// Spark-compatible `array_repeat` expression. The difference from DataFusion's `array_repeat` is the handling of NULL inputs: in Spark, if any input is NULL, the result is NULL. +/// +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkArrayRepeat { + signature: Signature, +} + +impl Default for SparkArrayRepeat { + fn default() -> Self { + Self::new() + } +} + +impl SparkArrayRepeat { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkArrayRepeat { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_repeat" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { + Ok(DataType::List(Arc::new(Field::new_list_field( + arg_types[0].clone(), + true, + )))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + spark_array_repeat(args) + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { + let [first_type, second_type] = take_function_args(self.name(), arg_types)?; + + // Coerce the second argument to Int64/UInt64 if it's an integer type + let second = match second_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + DataType::Int64 + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + DataType::UInt64 + } + _ => return exec_err!("count must be an integer type"), + }; + + Ok(vec![first_type.clone(), second]) + } +} + +/// This is a 
Spark-specific wrapper around DataFusion's array_repeat that returns NULL +/// if any argument is NULL (Spark behavior), whereas DataFusion's array_repeat ignores NULLs. +fn spark_array_repeat(args: ScalarFunctionArgs) -> Result<ColumnarValue> { + let ScalarFunctionArgs { + args: arg_values, + arg_fields, + number_rows, + return_field, + config_options, + } = args; + let return_type = return_field.data_type().clone(); + + // Step 1: Check for NULL mask in incoming args + let null_mask = compute_null_mask(&arg_values, number_rows)?; + + // If any argument is null then return NULL immediately + if matches!(null_mask, NullMaskResolution::ReturnNull) { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?)); + } + + // Step 2: Delegate to DataFusion's array_repeat + let array_repeat_func = ArrayRepeat::new(); + let func_args = ScalarFunctionArgs { + args: arg_values, + arg_fields, + number_rows, + return_field, + config_options, + }; + let result = array_repeat_func.invoke_with_args(func_args)?; + + // Step 3: Apply NULL mask to result + apply_null_mask(result, null_mask, &return_type) +} diff --git a/datafusion/spark/src/function/mod.rs b/datafusion/spark/src/function/mod.rs index 3f4f94cfaaf8c..bebf4d2efa8a4 100644 --- a/datafusion/spark/src/function/mod.rs +++ b/datafusion/spark/src/function/mod.rs @@ -33,6 +33,7 @@ pub mod lambda; pub mod map; pub mod math; pub mod misc; +pub mod null_utils; pub mod predicate; pub mod string; pub mod r#struct; diff --git a/datafusion/spark/src/function/null_utils.rs b/datafusion/spark/src/function/null_utils.rs new file mode 100644 index 0000000000000..b25dc07d0e525 --- /dev/null +++ b/datafusion/spark/src/function/null_utils.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::Array; +use arrow::buffer::NullBuffer; +use arrow::datatypes::DataType; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +pub(crate) enum NullMaskResolution { + /// Return NULL as the result (e.g., scalar inputs with at least one NULL) + ReturnNull, + /// No null mask needed (e.g., all scalar inputs are non-NULL) + NoMask, + /// Null mask to apply for arrays + Apply(NullBuffer), +} + +/// Compute NULL mask for the arguments using NullBuffer::union +pub(crate) fn compute_null_mask( + args: &[ColumnarValue], + number_rows: usize, +) -> Result<NullMaskResolution> { + // Check if all arguments are scalars + let all_scalars = args + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + + if all_scalars { + // For scalars, check if any is NULL + for arg in args { + if let ColumnarValue::Scalar(scalar) = arg + && scalar.is_null() + { + return Ok(NullMaskResolution::ReturnNull); + } + } + // No NULLs in scalars + Ok(NullMaskResolution::NoMask) + } else { + // For arrays, compute NULL mask for each row using NullBuffer::union + let array_len = args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + _ => None, + }) + .unwrap_or(number_rows); + + // Convert all scalars to arrays for uniform processing + let arrays: Result<Vec<_>> = args + .iter() + .map(|arg| match arg { + 
ColumnarValue::Array(array) => Ok(Arc::clone(array)), + ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len), + }) + .collect(); + let arrays = arrays?; + + // Use NullBuffer::union to combine all null buffers + let combined_nulls = arrays + .iter() + .map(|arr| arr.nulls()) + .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); + + match combined_nulls { + Some(nulls) => Ok(NullMaskResolution::Apply(nulls)), + None => Ok(NullMaskResolution::NoMask), + } + } +} + +/// Apply NULL mask to the result using NullBuffer::union +pub(crate) fn apply_null_mask( + result: ColumnarValue, + null_mask: NullMaskResolution, + return_type: &DataType, +) -> Result<ColumnarValue> { + match (result, null_mask) { + // Scalar with ReturnNull mask means return NULL of the correct type + (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => { + Ok(ColumnarValue::Scalar(ScalarValue::try_from(return_type)?)) + } + // Scalar without mask, return as-is + (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar), + // Array with NULL mask - use NullBuffer::union to combine nulls + (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => { + // Combine the result's existing nulls with our computed null mask + let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask)); + + // Create new array with combined nulls + let new_array = array + .into_data() + .into_builder() + .nulls(combined_nulls) + .build()?; + + Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array( + new_array, + )))) + } + // Array without NULL mask, return as-is + (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array), + // Edge cases that shouldn't happen in practice + (scalar, _) => Ok(scalar), + } +} diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs index 8e97e591fc357..f3dae22866c23 100644 --- a/datafusion/spark/src/function/string/concat.rs +++ 
b/datafusion/spark/src/function/string/concat.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::Array; -use arrow::buffer::NullBuffer; use arrow::datatypes::{DataType, Field}; use datafusion_common::arrow::datatypes::FieldRef; use datafusion_common::{Result, ScalarValue}; @@ -29,6 +27,10 @@ use datafusion_functions::string::concat::ConcatFunc; use std::any::Any; use std::sync::Arc; +use crate::function::null_utils::{ + NullMaskResolution, apply_null_mask, compute_null_mask, +}; + /// Spark-compatible `concat` expression /// /// @@ -94,16 +96,6 @@ impl ScalarUDFImpl for SparkConcat { } } -/// Represents the null state for Spark concat -enum NullMaskResolution { - /// Return NULL as the result (e.g., scalar inputs with at least one NULL) - ReturnNull, - /// No null mask needed (e.g., all scalar inputs are non-NULL) - NoMask, - /// Null mask to apply for arrays - Apply(NullBuffer), -} - /// Concatenates strings, returning NULL if any input is NULL /// This is a Spark-specific wrapper around DataFusion's concat that returns NULL /// if any argument is NULL (Spark behavior), whereas DataFusion's concat ignores NULLs. 
@@ -133,6 +125,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> { // Step 2: Delegate to DataFusion's concat let concat_func = ConcatFunc::new(); + let return_type = return_field.data_type().clone(); let func_args = ScalarFunctionArgs { args: arg_values, arg_fields, @@ -143,103 +136,14 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> { let result = concat_func.invoke_with_args(func_args)?; // Step 3: Apply NULL mask to result - apply_null_mask(result, null_mask) -} - -/// Compute NULL mask for the arguments using NullBuffer::union -fn compute_null_mask( - args: &[ColumnarValue], - number_rows: usize, -) -> Result<NullMaskResolution> { - // Check if all arguments are scalars - let all_scalars = args - .iter() - .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); - - if all_scalars { - // For scalars, check if any is NULL - for arg in args { - if let ColumnarValue::Scalar(scalar) = arg - && scalar.is_null() - { - return Ok(NullMaskResolution::ReturnNull); - } - } - // No NULLs in scalars - Ok(NullMaskResolution::NoMask) - } else { - // For arrays, compute NULL mask for each row using NullBuffer::union - let array_len = args - .iter() - .find_map(|arg| match arg { - ColumnarValue::Array(array) => Some(array.len()), - _ => None, - }) - .unwrap_or(number_rows); - - // Convert all scalars to arrays for uniform processing - let arrays: Result<Vec<_>> = args - .iter() - .map(|arg| match arg { - ColumnarValue::Array(array) => Ok(Arc::clone(array)), - ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(array_len), - }) - .collect(); - let arrays = arrays?; - - // Use NullBuffer::union to combine all null buffers - let combined_nulls = arrays - .iter() - .map(|arr| arr.nulls()) - .fold(None, |acc, nulls| NullBuffer::union(acc.as_ref(), nulls)); - - match combined_nulls { - Some(nulls) => Ok(NullMaskResolution::Apply(nulls)), - None => Ok(NullMaskResolution::NoMask), - } - } -} - -/// Apply NULL mask to the result using NullBuffer::union -fn apply_null_mask( - result: ColumnarValue, 
- null_mask: NullMaskResolution, -) -> Result<ColumnarValue> { - match (result, null_mask) { - // Scalar with ReturnNull mask means return NULL - (ColumnarValue::Scalar(_), NullMaskResolution::ReturnNull) => { - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))) - } - // Scalar without mask, return as-is - (scalar @ ColumnarValue::Scalar(_), NullMaskResolution::NoMask) => Ok(scalar), - // Array with NULL mask - use NullBuffer::union to combine nulls - (ColumnarValue::Array(array), NullMaskResolution::Apply(null_mask)) => { - // Combine the result's existing nulls with our computed null mask - let combined_nulls = NullBuffer::union(array.nulls(), Some(&null_mask)); - - // Create new array with combined nulls - let new_array = array - .into_data() - .into_builder() - .nulls(combined_nulls) - .build()?; - - Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array( - new_array, - )))) - } - // Array without NULL mask, return as-is - (array @ ColumnarValue::Array(_), NullMaskResolution::NoMask) => Ok(array), - // Edge cases that shouldn't happen in practice - (scalar, _) => Ok(scalar), - } + apply_null_mask(result, null_mask, &return_type) } #[cfg(test)] mod tests { use super::*; use crate::function::utils::test::test_scalar_function; - use arrow::array::StringArray; + use arrow::array::{Array, StringArray}; use arrow::datatypes::{DataType, Field}; use datafusion_common::Result; use datafusion_expr::ReturnFieldArgs; diff --git a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt index 544c39608f33b..292f47845cb15 100644 --- a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt +++ b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt @@ -15,13 +15,43 @@ # specific language governing permissions and limitations # under the License. 
-# This file was originally created by a porting script from: -# https://github.com/lakehq/sail/tree/43b6ed8221de5c4c4adbedbb267ae1351158b43c/crates/sail-spark-connect/tests/gold_data/function -# This file is part of the implementation of the datafusion-spark function library. -# For more information, please see: -# https://github.com/apache/datafusion/issues/15914 - -## Original Query: SELECT array_repeat('123', 2); -## PySpark 3.5.5 Result: {'array_repeat(123, 2)': ['123', '123'], 'typeof(array_repeat(123, 2))': 'array<string>', 'typeof(123)': 'string', 'typeof(2)': 'int'} -#query -#SELECT array_repeat('123'::string, 2::int); + +query ? +SELECT array_repeat('123', 2); +---- +[123, 123] + +query ? +SELECT array_repeat('123', 0); +---- +[] + +query ? +SELECT array_repeat('123', -1); +---- +[] + +query ? +SELECT array_repeat(['123'], 2); +---- +[[123], [123]] + +query ? +SELECT array_repeat(NULL, 2); +---- +NULL + +query ? +SELECT array_repeat([NULL], 2); +---- +[[NULL], [NULL]] + +query ? +SELECT array_repeat(['123', NULL], 2); +---- +[[123, NULL], [123, NULL]] + +query ? +SELECT array_repeat('123', CAST(NULL AS INT)); +---- +NULL From c8efe0ca61303bca9cee0c79737d6d3c476438bd Mon Sep 17 00:00:00 2001 From: Cyprien Huet Date: Fri, 9 Jan 2026 14:05:58 +0400 Subject: [PATCH 2/3] test: add array test --- .../test_files/spark/array/array_repeat.slt | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt index 292f47845cb15..04926e4c11907 100644 --- a/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt +++ b/datafusion/sqllogictest/test_files/spark/array/array_repeat.slt @@ -55,3 +55,30 @@ query ? SELECT array_repeat('123', CAST(NULL AS INT)); ---- NULL + +query ? 
+SELECT array_repeat(column1, column2) +FROM VALUES +('123', 2), +('123', 0), +('123', -1), +(NULL, 1), +('123', NULL); +---- +[123, 123] +[] +[] +NULL +NULL + + +query ? +SELECT array_repeat(column1, column2) +FROM VALUES +(['123'], 2), +([], 2), +([NULL], 2); +---- +[[123], [123]] +[[], []] +[[NULL], [NULL]] From b51bffb67dac32eacee0d2650baa61786d6b67d0 Mon Sep 17 00:00:00 2001 From: Cyprien Huet Date: Fri, 9 Jan 2026 14:06:15 +0400 Subject: [PATCH 3/3] fix: make null_utils non pub --- datafusion/spark/src/function/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/spark/src/function/mod.rs b/datafusion/spark/src/function/mod.rs index bebf4d2efa8a4..d5dd60c3545a5 100644 --- a/datafusion/spark/src/function/mod.rs +++ b/datafusion/spark/src/function/mod.rs @@ -33,7 +33,7 @@ pub mod lambda; pub mod map; pub mod math; pub mod misc; -pub mod null_utils; +mod null_utils; pub mod predicate; pub mod string; pub mod r#struct;