diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/array_funcs/array_insert.rs similarity index 54% rename from native/spark-expr/src/list.rs rename to native/spark-expr/src/array_funcs/array_insert.rs index fc31b11a0b..08fb789056 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/array_funcs/array_insert.rs @@ -21,14 +21,12 @@ use arrow::{ datatypes::ArrowNativeType, record_batch::RecordBatch, }; -use arrow_array::{ - make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, StructArray, -}; -use arrow_schema::{DataType, Field, FieldRef, Schema}; +use arrow_array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait}; +use arrow_schema::{DataType, Field, Schema}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{ - cast::{as_int32_array, as_large_list_array, as_list_array}, - internal_err, DataFusionError, Result as DataFusionResult, ScalarValue, + cast::{as_large_list_array, as_list_array}, + internal_err, DataFusionError, Result as DataFusionResult, }; use datafusion_physical_expr::PhysicalExpr; use std::hash::Hash; @@ -43,372 +41,6 @@ use std::{ // https://github.com/apache/spark/blob/master/common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java const MAX_ROUNDED_ARRAY_LENGTH: usize = 2147483632; -#[derive(Debug, Eq)] -pub struct ListExtract { - child: Arc, - ordinal: Arc, - default_value: Option>, - one_based: bool, - fail_on_error: bool, -} - -impl Hash for ListExtract { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.ordinal.hash(state); - self.default_value.hash(state); - self.one_based.hash(state); - self.fail_on_error.hash(state); - } -} -impl PartialEq for ListExtract { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) - && self.ordinal.eq(&other.ordinal) - && self.default_value.eq(&other.default_value) - && self.one_based.eq(&other.one_based) - && self.fail_on_error.eq(&other.fail_on_error) - } -} - 
-impl ListExtract { - pub fn new( - child: Arc, - ordinal: Arc, - default_value: Option>, - one_based: bool, - fail_on_error: bool, - ) -> Self { - Self { - child, - ordinal, - default_value, - one_based, - fail_on_error, - } - } - - fn child_field(&self, input_schema: &Schema) -> DataFusionResult { - match self.child.data_type(input_schema)? { - DataType::List(field) | DataType::LargeList(field) => Ok(field), - data_type => Err(DataFusionError::Internal(format!( - "Unexpected data type in ListExtract: {:?}", - data_type - ))), - } - } -} - -impl PhysicalExpr for ListExtract { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> DataFusionResult { - Ok(self.child_field(input_schema)?.data_type().clone()) - } - - fn nullable(&self, input_schema: &Schema) -> DataFusionResult { - // Only non-nullable if fail_on_error is enabled and the element is non-nullable - Ok(!self.fail_on_error || self.child_field(input_schema)?.is_nullable()) - } - - fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { - let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; - let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?; - - let default_value = self - .default_value - .as_ref() - .map(|d| { - d.evaluate(batch).map(|value| match value { - ColumnarValue::Scalar(scalar) - if !scalar.data_type().equals_datatype(child_value.data_type()) => - { - scalar.cast_to(child_value.data_type()) - } - ColumnarValue::Scalar(scalar) => Ok(scalar), - v => Err(DataFusionError::Execution(format!( - "Expected scalar default value for ListExtract, got {:?}", - v - ))), - }) - }) - .transpose()? 
- .unwrap_or(self.data_type(&batch.schema())?.try_into())?; - - let adjust_index = if self.one_based { - one_based_index - } else { - zero_based_index - }; - - match child_value.data_type() { - DataType::List(_) => { - let list_array = as_list_array(&child_value)?; - let index_array = as_int32_array(&ordinal_value)?; - - list_extract( - list_array, - index_array, - &default_value, - self.fail_on_error, - adjust_index, - ) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&child_value)?; - let index_array = as_int32_array(&ordinal_value)?; - - list_extract( - list_array, - index_array, - &default_value, - self.fail_on_error, - adjust_index, - ) - } - data_type => Err(DataFusionError::Internal(format!( - "Unexpected child type for ListExtract: {:?}", - data_type - ))), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child, &self.ordinal] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> datafusion_common::Result> { - match children.len() { - 2 => Ok(Arc::new(ListExtract::new( - Arc::clone(&children[0]), - Arc::clone(&children[1]), - self.default_value.clone(), - self.one_based, - self.fail_on_error, - ))), - _ => internal_err!("ListExtract should have exactly two children"), - } - } -} - -fn one_based_index(index: i32, len: usize) -> DataFusionResult> { - if index == 0 { - return Err(DataFusionError::Execution( - "Invalid index of 0 for one-based ListExtract".to_string(), - )); - } - - let abs_index = index.abs().as_usize(); - if abs_index <= len { - if index > 0 { - Ok(Some(abs_index - 1)) - } else { - Ok(Some(len - abs_index)) - } - } else { - Ok(None) - } -} - -fn zero_based_index(index: i32, len: usize) -> DataFusionResult> { - if index < 0 { - Ok(None) - } else { - let positive_index = index.as_usize(); - if positive_index < len { - Ok(Some(positive_index)) - } else { - Ok(None) - } - } -} - -fn list_extract( - list_array: &GenericListArray, - index_array: &Int32Array, - default_value: &ScalarValue, - 
fail_on_error: bool, - adjust_index: impl Fn(i32, usize) -> DataFusionResult>, -) -> DataFusionResult { - let values = list_array.values(); - let offsets = list_array.offsets(); - - let data = values.to_data(); - - let default_data = default_value.to_array()?.to_data(); - - let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len()); - - for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.values()).enumerate() { - let start = offset_window[0].as_usize(); - let len = offset_window[1].as_usize() - start; - - if let Some(i) = adjust_index(*index, len)? { - mutable.extend(0, start + i, start + i + 1); - } else if list_array.is_null(row) { - mutable.extend_nulls(1); - } else if fail_on_error { - return Err(DataFusionError::Execution( - "Index out of bounds for array".to_string(), - )); - } else { - mutable.extend(1, 0, 1); - } - } - - let data = mutable.freeze(); - Ok(ColumnarValue::Array(arrow::array::make_array(data))) -} - -impl Display for ListExtract { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]", - self.child, self.ordinal, self.default_value, self.one_based, self.fail_on_error - ) - } -} - -#[derive(Debug, Eq)] -pub struct GetArrayStructFields { - child: Arc, - ordinal: usize, -} - -impl Hash for GetArrayStructFields { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.ordinal.hash(state); - } -} -impl PartialEq for GetArrayStructFields { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal) - } -} - -impl GetArrayStructFields { - pub fn new(child: Arc, ordinal: usize) -> Self { - Self { child, ordinal } - } - - fn list_field(&self, input_schema: &Schema) -> DataFusionResult { - match self.child.data_type(input_schema)? 
{ - DataType::List(field) | DataType::LargeList(field) => Ok(field), - data_type => Err(DataFusionError::Internal(format!( - "Unexpected data type in GetArrayStructFields: {:?}", - data_type - ))), - } - } - - fn child_field(&self, input_schema: &Schema) -> DataFusionResult { - match self.list_field(input_schema)?.data_type() { - DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])), - data_type => Err(DataFusionError::Internal(format!( - "Unexpected data type in GetArrayStructFields: {:?}", - data_type - ))), - } - } -} - -impl PhysicalExpr for GetArrayStructFields { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> DataFusionResult { - let struct_field = self.child_field(input_schema)?; - match self.child.data_type(input_schema)? { - DataType::List(_) => Ok(DataType::List(struct_field)), - DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)), - data_type => Err(DataFusionError::Internal(format!( - "Unexpected data type in GetArrayStructFields: {:?}", - data_type - ))), - } - } - - fn nullable(&self, input_schema: &Schema) -> DataFusionResult { - Ok(self.list_field(input_schema)?.is_nullable() - || self.child_field(input_schema)?.is_nullable()) - } - - fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { - let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; - - match child_value.data_type() { - DataType::List(_) => { - let list_array = as_list_array(&child_value)?; - - get_array_struct_fields(list_array, self.ordinal) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&child_value)?; - - get_array_struct_fields(list_array, self.ordinal) - } - data_type => Err(DataFusionError::Internal(format!( - "Unexpected child type for ListExtract: {:?}", - data_type - ))), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> datafusion_common::Result> { - match children.len() 
{ - 1 => Ok(Arc::new(GetArrayStructFields::new( - Arc::clone(&children[0]), - self.ordinal, - ))), - _ => internal_err!("GetArrayStructFields should have exactly one child"), - } - } -} - -fn get_array_struct_fields( - list_array: &GenericListArray, - ordinal: usize, -) -> DataFusionResult { - let values = list_array - .values() - .as_any() - .downcast_ref::() - .expect("A struct is expected"); - - let column = Arc::clone(values.column(ordinal)); - let field = Arc::clone(&values.fields()[ordinal]); - - let offsets = list_array.offsets(); - let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned()); - - Ok(ColumnarValue::Array(Arc::new(array))) -} - -impl Display for GetArrayStructFields { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "GetArrayStructFields [child: {:?}, ordinal: {:?}]", - self.child, self.ordinal - ) - } -} - #[derive(Debug, Eq)] pub struct ArrayInsert { src_array_expr: Arc, @@ -687,51 +319,13 @@ impl Display for ArrayInsert { #[cfg(test)] mod test { - use crate::list::{array_insert, list_extract, zero_based_index}; - + use super::*; use arrow::datatypes::Int32Type; use arrow_array::{Array, ArrayRef, Int32Array, ListArray}; - use datafusion_common::{Result, ScalarValue}; + use datafusion_common::Result; use datafusion_expr::ColumnarValue; use std::sync::Arc; - #[test] - fn test_list_extract_default_value() -> Result<()> { - let list = ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1)]), - None, - Some(vec![]), - ]); - let indices = Int32Array::from(vec![0, 0, 0]); - - let null_default = ScalarValue::Int32(None); - - let ColumnarValue::Array(result) = - list_extract(&list, &indices, &null_default, false, zero_based_index)? 
- else { - unreachable!() - }; - - assert_eq!( - &result.to_data(), - &Int32Array::from(vec![Some(1), None, None]).to_data() - ); - - let zero_default = ScalarValue::Int32(Some(0)); - - let ColumnarValue::Array(result) = - list_extract(&list, &indices, &zero_default, false, zero_based_index)? - else { - unreachable!() - }; - - assert_eq!( - &result.to_data(), - &Int32Array::from(vec![Some(1), None, Some(0)]).to_data() - ); - Ok(()) - } - #[test] fn test_array_insert() -> Result<()> { // Test inserting an item into a list array diff --git a/native/spark-expr/src/array_funcs/get_array_struct_fields.rs b/native/spark-expr/src/array_funcs/get_array_struct_fields.rs new file mode 100644 index 0000000000..8b1633649c --- /dev/null +++ b/native/spark-expr/src/array_funcs/get_array_struct_fields.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use arrow::record_batch::RecordBatch;
+use arrow_array::{Array, GenericListArray, OffsetSizeTrait, StructArray};
+use arrow_schema::{DataType, FieldRef, Schema};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::{
+    cast::{as_large_list_array, as_list_array},
+    internal_err, DataFusionError, Result as DataFusionResult,
+};
+use datafusion_physical_expr::PhysicalExpr;
+use std::hash::Hash;
+use std::{
+    any::Any,
+    fmt::{Debug, Display, Formatter},
+    sync::Arc,
+};
+
+#[derive(Debug, Eq)]
+pub struct GetArrayStructFields {
+    child: Arc<dyn PhysicalExpr>,
+    ordinal: usize,
+}
+
+impl Hash for GetArrayStructFields {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.child.hash(state);
+        self.ordinal.hash(state);
+    }
+}
+impl PartialEq for GetArrayStructFields {
+    fn eq(&self, other: &Self) -> bool {
+        self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal)
+    }
+}
+
+impl GetArrayStructFields {
+    pub fn new(child: Arc<dyn PhysicalExpr>, ordinal: usize) -> Self {
+        Self { child, ordinal }
+    }
+
+    fn list_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.child.data_type(input_schema)? {
+            DataType::List(field) | DataType::LargeList(field) => Ok(field),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.list_field(input_schema)?.data_type() {
+            DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+}
+
+impl PhysicalExpr for GetArrayStructFields {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
+        let struct_field = self.child_field(input_schema)?;
+        match self.child.data_type(input_schema)? {
+            DataType::List(_) => Ok(DataType::List(struct_field)),
+            DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in GetArrayStructFields: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
+        Ok(self.list_field(input_schema)?.is_nullable()
+            || self.child_field(input_schema)?.is_nullable())
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
+        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
+
+        match child_value.data_type() {
+            DataType::List(_) => {
+                let list_array = as_list_array(&child_value)?;
+
+                get_array_struct_fields(list_array, self.ordinal)
+            }
+            DataType::LargeList(_) => {
+                let list_array = as_large_list_array(&child_value)?;
+
+                get_array_struct_fields(list_array, self.ordinal)
+            }
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected child type for ListExtract: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.child]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+        match children.len() {
+            1 => Ok(Arc::new(GetArrayStructFields::new(
+                Arc::clone(&children[0]),
+                self.ordinal,
+            ))),
+            _ => internal_err!("GetArrayStructFields should have exactly one child"),
+        }
+    }
+}
+
+fn get_array_struct_fields<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    ordinal: usize,
+) -> DataFusionResult<ColumnarValue> {
+    let values = list_array
+        .values()
+        .as_any()
+        .downcast_ref::<StructArray>()
+        .expect("A struct is expected");
+
+    let column = Arc::clone(values.column(ordinal));
+    let field = Arc::clone(&values.fields()[ordinal]);
+
+    let offsets = list_array.offsets();
+    let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned());
+
+    Ok(ColumnarValue::Array(Arc::new(array)))
+}
+
+impl Display for GetArrayStructFields {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "GetArrayStructFields [child: {:?}, ordinal: {:?}]",
+            self.child, self.ordinal
+        )
+    }
+}
diff --git a/native/spark-expr/src/array_funcs/list_extract.rs b/native/spark-expr/src/array_funcs/list_extract.rs
new file mode 100644
index 0000000000..c0f2291d9f
--- /dev/null
+++ b/native/spark-expr/src/array_funcs/list_extract.rs
@@ -0,0 +1,310 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch};
+use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait};
+use arrow_schema::{DataType, FieldRef, Schema};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::{
+    cast::{as_int32_array, as_large_list_array, as_list_array},
+    internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
+};
+use datafusion_physical_expr::PhysicalExpr;
+use std::hash::Hash;
+use std::{
+    any::Any,
+    fmt::{Debug, Display, Formatter},
+    sync::Arc,
+};
+
+#[derive(Debug, Eq)]
+pub struct ListExtract {
+    child: Arc<dyn PhysicalExpr>,
+    ordinal: Arc<dyn PhysicalExpr>,
+    default_value: Option<Arc<dyn PhysicalExpr>>,
+    one_based: bool,
+    fail_on_error: bool,
+}
+
+impl Hash for ListExtract {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.child.hash(state);
+        self.ordinal.hash(state);
+        self.default_value.hash(state);
+        self.one_based.hash(state);
+        self.fail_on_error.hash(state);
+    }
+}
+impl PartialEq for ListExtract {
+    fn eq(&self, other: &Self) -> bool {
+        self.child.eq(&other.child)
+            && self.ordinal.eq(&other.ordinal)
+            && self.default_value.eq(&other.default_value)
+            && self.one_based.eq(&other.one_based)
+            && self.fail_on_error.eq(&other.fail_on_error)
+    }
+}
+
+impl ListExtract {
+    pub fn new(
+        child: Arc<dyn PhysicalExpr>,
+        ordinal: Arc<dyn PhysicalExpr>,
+        default_value: Option<Arc<dyn PhysicalExpr>>,
+        one_based: bool,
+        fail_on_error: bool,
+    ) -> Self {
+        Self {
+            child,
+            ordinal,
+            default_value,
+            one_based,
+            fail_on_error,
+        }
+    }
+
+    fn child_field(&self, input_schema: &Schema) -> DataFusionResult<FieldRef> {
+        match self.child.data_type(input_schema)? {
+            DataType::List(field) | DataType::LargeList(field) => Ok(field),
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected data type in ListExtract: {:?}",
+                data_type
+            ))),
+        }
+    }
+}
+
+impl PhysicalExpr for ListExtract {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, input_schema: &Schema) -> DataFusionResult<DataType> {
+        Ok(self.child_field(input_schema)?.data_type().clone())
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> DataFusionResult<bool> {
+        // Only non-nullable if fail_on_error is enabled and the element is non-nullable
+        Ok(!self.fail_on_error || self.child_field(input_schema)?.is_nullable())
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult<ColumnarValue> {
+        let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?;
+        let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?;
+
+        let default_value = self
+            .default_value
+            .as_ref()
+            .map(|d| {
+                d.evaluate(batch).map(|value| match value {
+                    ColumnarValue::Scalar(scalar)
+                        if !scalar.data_type().equals_datatype(child_value.data_type()) =>
+                    {
+                        scalar.cast_to(child_value.data_type())
+                    }
+                    ColumnarValue::Scalar(scalar) => Ok(scalar),
+                    v => Err(DataFusionError::Execution(format!(
+                        "Expected scalar default value for ListExtract, got {:?}",
+                        v
+                    ))),
+                })
+            })
+            .transpose()?
+            .unwrap_or(self.data_type(&batch.schema())?.try_into())?;
+
+        let adjust_index = if self.one_based {
+            one_based_index
+        } else {
+            zero_based_index
+        };
+
+        match child_value.data_type() {
+            DataType::List(_) => {
+                let list_array = as_list_array(&child_value)?;
+                let index_array = as_int32_array(&ordinal_value)?;
+
+                list_extract(
+                    list_array,
+                    index_array,
+                    &default_value,
+                    self.fail_on_error,
+                    adjust_index,
+                )
+            }
+            DataType::LargeList(_) => {
+                let list_array = as_large_list_array(&child_value)?;
+                let index_array = as_int32_array(&ordinal_value)?;
+
+                list_extract(
+                    list_array,
+                    index_array,
+                    &default_value,
+                    self.fail_on_error,
+                    adjust_index,
+                )
+            }
+            data_type => Err(DataFusionError::Internal(format!(
+                "Unexpected child type for ListExtract: {:?}",
+                data_type
+            ))),
+        }
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.child, &self.ordinal]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> datafusion_common::Result<Arc<dyn PhysicalExpr>> {
+        match children.len() {
+            2 => Ok(Arc::new(ListExtract::new(
+                Arc::clone(&children[0]),
+                Arc::clone(&children[1]),
+                self.default_value.clone(),
+                self.one_based,
+                self.fail_on_error,
+            ))),
+            _ => internal_err!("ListExtract should have exactly two children"),
+        }
+    }
+}
+
+fn one_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
+    if index == 0 {
+        return Err(DataFusionError::Execution(
+            "Invalid index of 0 for one-based ListExtract".to_string(),
+        ));
+    }
+
+    let abs_index = index.abs().as_usize();
+    if abs_index <= len {
+        if index > 0 {
+            Ok(Some(abs_index - 1))
+        } else {
+            Ok(Some(len - abs_index))
+        }
+    } else {
+        Ok(None)
+    }
+}
+
+fn zero_based_index(index: i32, len: usize) -> DataFusionResult<Option<usize>> {
+    if index < 0 {
+        Ok(None)
+    } else {
+        let positive_index = index.as_usize();
+        if positive_index < len {
+            Ok(Some(positive_index))
+        } else {
+            Ok(None)
+        }
+    }
+}
+
+fn list_extract<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    index_array: &Int32Array,
+    default_value: &ScalarValue,
+    fail_on_error: bool,
+    adjust_index: impl Fn(i32, usize) -> DataFusionResult<Option<usize>>,
+) -> DataFusionResult<ColumnarValue> {
+    let values = list_array.values();
+    let offsets = list_array.offsets();
+
+    let data = values.to_data();
+
+    let default_data = default_value.to_array()?.to_data();
+
+    let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len());
+
+    for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.values()).enumerate() {
+        let start = offset_window[0].as_usize();
+        let len = offset_window[1].as_usize() - start;
+
+        if let Some(i) = adjust_index(*index, len)? {
+            mutable.extend(0, start + i, start + i + 1);
+        } else if list_array.is_null(row) {
+            mutable.extend_nulls(1);
+        } else if fail_on_error {
+            return Err(DataFusionError::Execution(
+                "Index out of bounds for array".to_string(),
+            ));
+        } else {
+            mutable.extend(1, 0, 1);
+        }
+    }
+
+    let data = mutable.freeze();
+    Ok(ColumnarValue::Array(arrow::array::make_array(data)))
+}
+
+impl Display for ListExtract {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(
+            f,
+            "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]",
+            self.child, self.ordinal, self.default_value, self.one_based, self.fail_on_error
+        )
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use arrow::datatypes::Int32Type;
+    use arrow_array::{Array, Int32Array, ListArray};
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::ColumnarValue;
+
+    #[test]
+    fn test_list_extract_default_value() -> Result<()> {
+        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+            Some(vec![Some(1)]),
+            None,
+            Some(vec![]),
+        ]);
+        let indices = Int32Array::from(vec![0, 0, 0]);
+
+        let null_default = ScalarValue::Int32(None);
+
+        let ColumnarValue::Array(result) =
+            list_extract(&list, &indices, &null_default, false, zero_based_index)?
+        else {
+            unreachable!()
+        };
+
+        assert_eq!(
+            &result.to_data(),
+            &Int32Array::from(vec![Some(1), None, None]).to_data()
+        );
+
+        let zero_default = ScalarValue::Int32(Some(0));
+
+        let ColumnarValue::Array(result) =
+            list_extract(&list, &indices, &zero_default, false, zero_based_index)?
+        else {
+            unreachable!()
+        };
+
+        assert_eq!(
+            &result.to_data(),
+            &Int32Array::from(vec![Some(1), None, Some(0)]).to_data()
+        );
+        Ok(())
+    }
+}
diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs
new file mode 100644
index 0000000000..0a215f96cf
--- /dev/null
+++ b/native/spark-expr/src/array_funcs/mod.rs
@@ -0,0 +1,24 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+ +mod array_insert; +mod get_array_struct_fields; +mod list_extract; + +pub use array_insert::ArrayInsert; +pub use get_array_struct_fields::GetArrayStructFields; +pub use list_extract::ListExtract; diff --git a/native/spark-expr/src/bitwise_not.rs b/native/spark-expr/src/bitwise_funcs/bitwise_not.rs similarity index 100% rename from native/spark-expr/src/bitwise_not.rs rename to native/spark-expr/src/bitwise_funcs/bitwise_not.rs diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs new file mode 100644 index 0000000000..9c26363319 --- /dev/null +++ b/native/spark-expr/src/bitwise_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod bitwise_not; + +pub use bitwise_not::{bitwise_not, BitwiseNotExpr}; diff --git a/native/spark-expr/src/if_expr.rs b/native/spark-expr/src/conditional_funcs/if_expr.rs similarity index 100% rename from native/spark-expr/src/if_expr.rs rename to native/spark-expr/src/conditional_funcs/if_expr.rs diff --git a/native/spark-expr/src/conditional_funcs/mod.rs b/native/spark-expr/src/conditional_funcs/mod.rs new file mode 100644 index 0000000000..70c459ef7c --- /dev/null +++ b/native/spark-expr/src/conditional_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod if_expr; + +pub use if_expr::IfExpr; diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs similarity index 100% rename from native/spark-expr/src/cast.rs rename to native/spark-expr/src/conversion_funcs/cast.rs diff --git a/native/spark-expr/src/conversion_funcs/mod.rs b/native/spark-expr/src/conversion_funcs/mod.rs new file mode 100644 index 0000000000..f2c6f7ca36 --- /dev/null +++ b/native/spark-expr/src/conversion_funcs/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod cast; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 22bec87ee6..14982264d1 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -19,17 +19,12 @@ // The lint makes easier for code reader/reviewer separate references clones from more heavyweight ones #![deny(clippy::clone_on_ref_ptr)] -mod cast; mod error; -mod if_expr; -mod bitwise_not; -pub use bitwise_not::{bitwise_not, BitwiseNotExpr}; mod checkoverflow; pub use checkoverflow::CheckOverflow; mod kernels; -mod list; pub mod scalar_funcs; mod schema_adapter; mod static_invoke; @@ -52,6 +47,8 @@ mod predicate_funcs; pub use predicate_funcs::{spark_isnan, RLike}; mod agg_funcs; +mod array_funcs; +mod bitwise_funcs; mod comet_scalar_funcs; pub mod hash_funcs; @@ -63,13 +60,19 @@ pub use agg_funcs::*; pub use crate::{CreateNamedStruct, GetStructField}; pub use crate::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr}; pub use cast::{spark_cast, Cast, SparkCastOptions}; +mod conditional_funcs; +mod conversion_funcs; + +pub use array_funcs::*; +pub use bitwise_funcs::*; +pub use conditional_funcs::*; +pub use conversion_funcs::*; + pub use comet_scalar_funcs::create_comet_physical_fun; pub use datetime_funcs::*; pub use error::{SparkError, 
SparkResult}; pub use hash_funcs::*; -pub use if_expr::IfExpr; pub use json_funcs::ToJson; -pub use list::{ArrayInsert, GetArrayStructFields, ListExtract}; pub use string_funcs::*; pub use struct_funcs::*;