diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml index 78a1742ceedd..75c6be9b3f36 100644 --- a/datafusion/spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -45,12 +45,12 @@ datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } datafusion-functions = { workspace = true, features = ["crypto_expressions"] } log = { workspace = true } +rand = { workspace = true } sha1 = "0.10" url = { workspace = true } [dev-dependencies] criterion = { workspace = true } -rand = { workspace = true } [[bench]] harness = false diff --git a/datafusion/spark/src/function/array/mod.rs b/datafusion/spark/src/function/array/mod.rs index fed52a494281..01056ba95298 100644 --- a/datafusion/spark/src/function/array/mod.rs +++ b/datafusion/spark/src/function/array/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub mod shuffle; pub mod spark_array; use datafusion_expr::ScalarUDF; @@ -22,13 +23,19 @@ use datafusion_functions::make_udf_function; use std::sync::Arc; make_udf_function!(spark_array::SparkArray, array); +make_udf_function!(shuffle::SparkShuffle, shuffle); pub mod expr_fn { use datafusion_functions::export_functions; export_functions!((array, "Returns an array with the given elements.", args)); + export_functions!(( + shuffle, + "Returns a random permutation of the given array.", + args + )); } pub fn functions() -> Vec> { - vec![array()] + vec![array(), shuffle()] } diff --git a/datafusion/spark/src/function/array/shuffle.rs b/datafusion/spark/src/function/array/shuffle.rs new file mode 100644 index 000000000000..abeafd3a9366 --- /dev/null +++ b/datafusion/spark/src/function/array/shuffle.rs @@ -0,0 +1,191 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::function::functions_nested_utils::make_scalar_function; +use arrow::array::{ + Array, ArrayRef, Capacities, FixedSizeListArray, GenericListArray, MutableArrayData, + OffsetSizeTrait, +}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::DataType::{FixedSizeList, LargeList, List, Null}; +use arrow::datatypes::{DataType, FieldRef}; +use datafusion_common::cast::{ + as_fixed_size_list_array, as_large_list_array, as_list_array, +}; +use datafusion_common::{exec_err, utils::take_function_args, Result}; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use rand::rng; +use rand::seq::SliceRandom; +use std::any::Any; +use std::sync::Arc; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkShuffle { + signature: Signature, +} + +impl Default for SparkShuffle { + fn default() -> Self { + Self::new() + } +} + +impl SparkShuffle { + pub fn new() -> Self { + Self { + signature: Signature::arrays(1, None, Volatility::Volatile), + } + } +} + +impl ScalarUDFImpl for SparkShuffle { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "shuffle" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args( + &self, + args: datafusion_expr::ScalarFunctionArgs, + ) -> Result { + make_scalar_function(array_shuffle_inner)(&args.args) + } +} + +/// array_shuffle SQL function +pub fn array_shuffle_inner(arg: &[ArrayRef]) -> Result { + let [input_array] = take_function_args("shuffle", arg)?; + match &input_array.data_type() { + List(field) => { + let array = as_list_array(input_array)?; + general_array_shuffle::(array, field) + } + LargeList(field) => { + let array = as_large_list_array(input_array)?; + general_array_shuffle::(array, field) + } + FixedSizeList(field, _) => { + let array = as_fixed_size_list_array(input_array)?; + fixed_size_array_shuffle(array, field) + } + Null => Ok(Arc::clone(input_array)), + array_type => exec_err!("shuffle does not support type '{array_type}'."), + } +} + +fn general_array_shuffle( + array: &GenericListArray, + field: &FieldRef, +) -> Result { + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut offsets = vec![O::usize_as(0)]; + let mut nulls = vec![]; + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + let mut rng = rng(); + + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + // skip the null value + if array.is_null(row_index) { + nulls.push(false); + offsets.push(offsets[row_index] + O::one()); + mutable.extend(0, 0, 1); + continue; + } + nulls.push(true); + let start = offset_window[0]; + let end = offset_window[1]; + let length = (end - start).to_usize().unwrap(); + + // Create indices and shuffle them + let mut indices: Vec = + (start.to_usize().unwrap()..end.to_usize().unwrap()).collect(); + indices.shuffle(&mut rng); + + // Add shuffled elements + for &index in &indices { + mutable.extend(0, index, index + 1); + } + + offsets.push(offsets[row_index] + O::usize_as(length)); + } + + let data = mutable.freeze(); + Ok(Arc::new(GenericListArray::::try_new( + Arc::clone(field), + OffsetBuffer::::new(offsets.into()), + arrow::array::make_array(data), + Some(nulls.into()), + )?)) +} + +fn fixed_size_array_shuffle( + array: &FixedSizeListArray, + field: &FieldRef, +) -> Result { + let values = array.values(); + let original_data = values.to_data(); + let capacity = Capacities::Array(original_data.len()); + let mut nulls = vec![]; + let mut mutable = + MutableArrayData::with_capacities(vec![&original_data], false, capacity); + let value_length = array.value_length() as usize; + let mut rng = rng(); + + for row_index in 0..array.len() { + // skip the null value + if array.is_null(row_index) { + nulls.push(false); + mutable.extend(0, 0, value_length); + continue; + } + nulls.push(true); + + let start = row_index * value_length; + let end = start + value_length; + + // Create indices and shuffle them + let mut indices: Vec = (start..end).collect(); + indices.shuffle(&mut rng); + + // Add shuffled elements + for &index in &indices { + mutable.extend(0, index, index + 1); + } + } + + let data = mutable.freeze(); + Ok(Arc::new(FixedSizeListArray::try_new( + Arc::clone(field), + array.value_length(), + arrow::array::make_array(data), + Some(nulls.into()), + )?)) +} diff --git a/datafusion/sqllogictest/test_files/spark/array/shuffle.slt b/datafusion/sqllogictest/test_files/spark/array/shuffle.slt new file mode 100644 index 000000000000..cb3c77cac8fb --- /dev/null +++ b/datafusion/sqllogictest/test_files/spark/array/shuffle.slt @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Test shuffle function with simple arrays +query B +SELECT array_sort(shuffle([1, 2, 3, 4, 5, NULL])) = [NULL,1, 2, 3, 4, 5]; +---- +true + +query B +SELECT shuffle([1, 2, 3, 4, 5, NULL]) != [1, 2, 3, 4, 5, NULL]; +---- +true + +# Test shuffle function with string arrays + +query B +SELECT array_sort(shuffle(['a', 'b', 'c', 'd', 'e', 'f'])) = ['a', 'b', 'c', 'd', 'e', 'f']; +---- +true + +query B +SELECT shuffle(['a', 'b', 'c', 'd', 'e', 'f']) != ['a', 'b', 'c', 'd', 'e', 'f'];; +---- +true + +# Test shuffle function with empty array +query ? +SELECT shuffle([]); +---- +[] + +# Test shuffle function with single element +query ? +SELECT shuffle([42]); +---- +[42] + +# Test shuffle function with null array +query ? +SELECT shuffle(NULL); +---- +NULL + +# Test shuffle function with fixed size list arrays +query B +SELECT array_sort(shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)'))) = [NULL, 1, 2, 3, 4, 5]; +---- +true + +query B +SELECT shuffle(arrow_cast([1, 2, NULL, 3, 4, 5], 'FixedSizeList(6, Int64)')) != [1, 2, NULL, 3, 4, 5]; +---- +true + +# Test shuffle on table data with different list types +statement ok +CREATE TABLE test_shuffle_list_types AS VALUES + ([1, 2, 3, 4]), + ([5, 6, 7, 8, 9]), + ([10]), + (NULL), + ([]); + +# Test shuffle with large list from table +query ? +SELECT array_sort(shuffle(column1)) FROM test_shuffle_list_types; +---- +[1, 2, 3, 4] +[5, 6, 7, 8, 9] +[10] +NULL +[] + +# Test fixed size list table +statement ok +CREATE TABLE test_shuffle_fixed_size AS VALUES + (arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)')), + (arrow_cast([4, 5, 6], 'FixedSizeList(3, Int64)')), + (arrow_cast([NULL, 8, 9], 'FixedSizeList(3, Int64)')), + (NULL); + +# Test shuffle with fixed size list from table +query ? +SELECT array_sort(shuffle(column1)) FROM test_shuffle_fixed_size; +---- +[1, 2, 3] +[4, 5, 6] +[NULL, 8, 9] +NULL + +# Clean up +statement ok +DROP TABLE test_shuffle_list_types; + +statement ok +DROP TABLE test_shuffle_fixed_size; + +