diff --git a/native/core/src/execution/shuffle/shuffle_writer.rs b/native/core/src/execution/shuffle/shuffle_writer.rs index f3fa685b88..38748e9243 100644 --- a/native/core/src/execution/shuffle/shuffle_writer.rs +++ b/native/core/src/execution/shuffle/shuffle_writer.rs @@ -49,7 +49,7 @@ use datafusion::{ RecordBatchStream, SendableRecordBatchStream, Statistics, }, }; -use datafusion_comet_spark_expr::spark_hash::create_murmur3_hashes; +use datafusion_comet_spark_expr::hash_funcs::murmur3::create_murmur3_hashes; use datafusion_physical_expr::EquivalenceProperties; use futures::executor::block_on; use futures::{lock::Mutex, Stream, StreamExt, TryFutureExt, TryStreamExt}; diff --git a/native/core/src/execution/util/spark_bloom_filter.rs b/native/core/src/execution/util/spark_bloom_filter.rs index 61245757cf..7eb04f7b0f 100644 --- a/native/core/src/execution/util/spark_bloom_filter.rs +++ b/native/core/src/execution/util/spark_bloom_filter.rs @@ -19,7 +19,7 @@ use crate::execution::util::spark_bit_array; use crate::execution::util::spark_bit_array::SparkBitArray; use arrow_array::{ArrowNativeTypeOp, BooleanArray, Int64Array}; use arrow_buffer::ToByteSlice; -use datafusion_comet_spark_expr::spark_hash::spark_compatible_murmur3_hash; +use datafusion_comet_spark_expr::hash_funcs::murmur3::spark_compatible_murmur3_hash; use std::cmp; const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1; diff --git a/native/spark-expr/README.md b/native/spark-expr/README.md index a7ee753632..0db452c890 100644 --- a/native/spark-expr/README.md +++ b/native/spark-expr/README.md @@ -20,4 +20,37 @@ under the License. # datafusion-comet-spark-expr: Spark-compatible Expressions This crate provides Apache Spark-compatible expressions for use with DataFusion and is maintained as part of the -[Apache DataFusion Comet](https://github.com/apache/datafusion-comet/) subproject. \ No newline at end of file +[Apache DataFusion Comet](https://github.com/apache/datafusion-comet/) subproject. + +## Expression location + +The files are organized in the same way that Spark groups its expressions. +You can see the grouping in the [Spark Docs](https://spark.apache.org/docs/3.5.3/sql-ref-functions-builtin.html) +or in the Spark source code, where expressions are marked with the `ExpressionDescription` annotation. + +For example, for the following expression (taken from the Spark source code): +```scala +@ExpressionDescription( + usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.", + examples = """ + Examples: + > SELECT _FUNC_('Spark', array(123), 2); + -1321691492 + """, + since = "2.0.0", + group = "hash_funcs") +case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] { + // ... +} +``` + +the native implementation will be in the `hash_funcs/murmur3.rs` file. + +Some expressions are not under a specific group, such as the `UnscaledValue` expression (taken from the Spark source code): +```scala +case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant { + // ... +} +``` + +In that case, we will do our best to find a suitable group for it (for the example above, it would be under `math_funcs/internal/unscaled_value.rs`).
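To make the grouping convention concrete, here is a minimal sketch (not part of this patch) of how a grouped expression is laid out and consumed. It uses inline modules to stand in for the real `src/hash_funcs/murmur3.rs` and `src/hash_funcs/mod.rs` files, the hash body is a placeholder rather than the actual Spark-compatible Murmur3, and the signature is simplified; only the module names and the group-path import style are taken from this diff.

```rust
// Sketch only: inline modules stand in for the hash_funcs/ directory layout.
mod hash_funcs {
    // In the crate this is `pub mod murmur3;`, backed by hash_funcs/murmur3.rs.
    pub mod murmur3 {
        // Placeholder body and simplified signature; the real function implements
        // Spark-compatible Murmur3 hashing.
        pub fn spark_compatible_murmur3_hash(data: &[u8], seed: u32) -> u32 {
            data.iter()
                .fold(seed, |h, b| h.wrapping_mul(31).wrapping_add(*b as u32))
        }
    }

    // The group's mod.rs re-exports the expression so callers can also use the
    // shorter `hash_funcs::spark_compatible_murmur3_hash` path.
    pub use murmur3::spark_compatible_murmur3_hash;
}

fn main() {
    // Downstream code addresses the expression through its group path, mirroring
    // the new imports in shuffle_writer.rs and spark_bloom_filter.rs.
    let h = hash_funcs::murmur3::spark_compatible_murmur3_hash(b"Spark", 42);
    // ...or through the group-level re-export.
    let h2 = hash_funcs::spark_compatible_murmur3_hash(b"Spark", 42);
    assert_eq!(h, h2);
    println!("hash = {h}");
}
```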
diff --git a/native/spark-expr/benches/decimal_div.rs b/native/spark-expr/benches/decimal_div.rs index 89f06e5053..ad527fecba 100644 --- a/native/spark-expr/benches/decimal_div.rs +++ b/native/spark-expr/benches/decimal_div.rs @@ -19,7 +19,7 @@ use arrow::compute::cast; use arrow_array::builder::Decimal128Builder; use arrow_schema::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_comet_spark_expr::scalar_funcs::spark_decimal_div; +use datafusion_comet_spark_expr::spark_decimal_div; use datafusion_expr::ColumnarValue; use std::sync::Arc; diff --git a/native/spark-expr/src/avg.rs b/native/spark-expr/src/agg_funcs/avg.rs similarity index 100% rename from native/spark-expr/src/avg.rs rename to native/spark-expr/src/agg_funcs/avg.rs diff --git a/native/spark-expr/src/avg_decimal.rs b/native/spark-expr/src/agg_funcs/avg_decimal.rs similarity index 100% rename from native/spark-expr/src/avg_decimal.rs rename to native/spark-expr/src/agg_funcs/avg_decimal.rs diff --git a/native/spark-expr/src/correlation.rs b/native/spark-expr/src/agg_funcs/correlation.rs similarity index 98% rename from native/spark-expr/src/correlation.rs rename to native/spark-expr/src/agg_funcs/correlation.rs index e4ddab95de..5d6f9e0b43 100644 --- a/native/spark-expr/src/correlation.rs +++ b/native/spark-expr/src/agg_funcs/correlation.rs @@ -19,8 +19,8 @@ use arrow::compute::{and, filter, is_not_null}; use std::{any::Any, sync::Arc}; -use crate::covariance::CovarianceAccumulator; -use crate::stddev::StddevAccumulator; +use crate::agg_funcs::covariance::CovarianceAccumulator; +use crate::agg_funcs::stddev::StddevAccumulator; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, diff --git a/native/spark-expr/src/covariance.rs b/native/spark-expr/src/agg_funcs/covariance.rs similarity index 100% rename from native/spark-expr/src/covariance.rs rename to native/spark-expr/src/agg_funcs/covariance.rs diff --git a/native/spark-expr/src/agg_funcs/mod.rs b/native/spark-expr/src/agg_funcs/mod.rs new file mode 100644 index 0000000000..252da78890 --- /dev/null +++ b/native/spark-expr/src/agg_funcs/mod.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod avg; +mod avg_decimal; +mod correlation; +mod covariance; +mod stddev; +mod sum_decimal; +mod variance; + +pub use avg::Avg; +pub use avg_decimal::AvgDecimal; +pub use correlation::Correlation; +pub use covariance::Covariance; +pub use stddev::Stddev; +pub use sum_decimal::SumDecimal; +pub use variance::Variance; diff --git a/native/spark-expr/src/stddev.rs b/native/spark-expr/src/agg_funcs/stddev.rs similarity index 99% rename from native/spark-expr/src/stddev.rs rename to native/spark-expr/src/agg_funcs/stddev.rs index 1ec5ffb69a..39dffa1c8e 100644 --- a/native/spark-expr/src/stddev.rs +++ b/native/spark-expr/src/agg_funcs/stddev.rs @@ -17,7 +17,7 @@ use std::{any::Any, sync::Arc}; -use crate::variance::VarianceAccumulator; +use crate::agg_funcs::variance::VarianceAccumulator; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, diff --git a/native/spark-expr/src/sum_decimal.rs b/native/spark-expr/src/agg_funcs/sum_decimal.rs similarity index 100% rename from native/spark-expr/src/sum_decimal.rs rename to native/spark-expr/src/agg_funcs/sum_decimal.rs diff --git a/native/spark-expr/src/variance.rs b/native/spark-expr/src/agg_funcs/variance.rs similarity index 100% rename from native/spark-expr/src/variance.rs rename to native/spark-expr/src/agg_funcs/variance.rs diff --git a/native/spark-expr/src/list.rs b/native/spark-expr/src/array_funcs/array_insert.rs similarity index 54% rename from native/spark-expr/src/list.rs rename to native/spark-expr/src/array_funcs/array_insert.rs index fc31b11a0b..08fb789056 100644 --- a/native/spark-expr/src/list.rs +++ b/native/spark-expr/src/array_funcs/array_insert.rs @@ -21,14 +21,12 @@ use arrow::{ datatypes::ArrowNativeType, record_batch::RecordBatch, }; -use arrow_array::{ - make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, StructArray, -}; -use arrow_schema::{DataType, Field, FieldRef, Schema}; +use arrow_array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait}; +use arrow_schema::{DataType, Field, Schema}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{ - cast::{as_int32_array, as_large_list_array, as_list_array}, - internal_err, DataFusionError, Result as DataFusionResult, ScalarValue, + cast::{as_large_list_array, as_list_array}, + internal_err, DataFusionError, Result as DataFusionResult, }; use datafusion_physical_expr::PhysicalExpr; use std::hash::Hash; @@ -43,372 +41,6 @@ use std::{ // https://github.com/apache/spark/blob/master/common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java const MAX_ROUNDED_ARRAY_LENGTH: usize = 2147483632; -#[derive(Debug, Eq)] -pub struct ListExtract { - child: Arc, - ordinal: Arc, - default_value: Option>, - one_based: bool, - fail_on_error: bool, -} - -impl Hash for ListExtract { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.ordinal.hash(state); - self.default_value.hash(state); - self.one_based.hash(state); - self.fail_on_error.hash(state); - } -} -impl PartialEq for ListExtract { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) - && self.ordinal.eq(&other.ordinal) - && self.default_value.eq(&other.default_value) - && self.one_based.eq(&other.one_based) - && self.fail_on_error.eq(&other.fail_on_error) - } -} - -impl ListExtract { - pub fn new( - child: Arc, - ordinal: Arc, - default_value: Option>, - one_based: bool, - fail_on_error: bool, - ) -> Self { - Self { - child, - ordinal, - default_value, - one_based, - fail_on_error, - } - } - - fn child_field(&self, 
input_schema: &Schema) -> DataFusionResult { - match self.child.data_type(input_schema)? { - DataType::List(field) | DataType::LargeList(field) => Ok(field), - data_type => Err(DataFusionError::Internal(format!( - "Unexpected data type in ListExtract: {:?}", - data_type - ))), - } - } -} - -impl PhysicalExpr for ListExtract { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> DataFusionResult { - Ok(self.child_field(input_schema)?.data_type().clone()) - } - - fn nullable(&self, input_schema: &Schema) -> DataFusionResult { - // Only non-nullable if fail_on_error is enabled and the element is non-nullable - Ok(!self.fail_on_error || self.child_field(input_schema)?.is_nullable()) - } - - fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { - let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; - let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?; - - let default_value = self - .default_value - .as_ref() - .map(|d| { - d.evaluate(batch).map(|value| match value { - ColumnarValue::Scalar(scalar) - if !scalar.data_type().equals_datatype(child_value.data_type()) => - { - scalar.cast_to(child_value.data_type()) - } - ColumnarValue::Scalar(scalar) => Ok(scalar), - v => Err(DataFusionError::Execution(format!( - "Expected scalar default value for ListExtract, got {:?}", - v - ))), - }) - }) - .transpose()? - .unwrap_or(self.data_type(&batch.schema())?.try_into())?; - - let adjust_index = if self.one_based { - one_based_index - } else { - zero_based_index - }; - - match child_value.data_type() { - DataType::List(_) => { - let list_array = as_list_array(&child_value)?; - let index_array = as_int32_array(&ordinal_value)?; - - list_extract( - list_array, - index_array, - &default_value, - self.fail_on_error, - adjust_index, - ) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&child_value)?; - let index_array = as_int32_array(&ordinal_value)?; - - list_extract( - list_array, - index_array, - &default_value, - self.fail_on_error, - adjust_index, - ) - } - data_type => Err(DataFusionError::Internal(format!( - "Unexpected child type for ListExtract: {:?}", - data_type - ))), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child, &self.ordinal] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> datafusion_common::Result> { - match children.len() { - 2 => Ok(Arc::new(ListExtract::new( - Arc::clone(&children[0]), - Arc::clone(&children[1]), - self.default_value.clone(), - self.one_based, - self.fail_on_error, - ))), - _ => internal_err!("ListExtract should have exactly two children"), - } - } -} - -fn one_based_index(index: i32, len: usize) -> DataFusionResult> { - if index == 0 { - return Err(DataFusionError::Execution( - "Invalid index of 0 for one-based ListExtract".to_string(), - )); - } - - let abs_index = index.abs().as_usize(); - if abs_index <= len { - if index > 0 { - Ok(Some(abs_index - 1)) - } else { - Ok(Some(len - abs_index)) - } - } else { - Ok(None) - } -} - -fn zero_based_index(index: i32, len: usize) -> DataFusionResult> { - if index < 0 { - Ok(None) - } else { - let positive_index = index.as_usize(); - if positive_index < len { - Ok(Some(positive_index)) - } else { - Ok(None) - } - } -} - -fn list_extract( - list_array: &GenericListArray, - index_array: &Int32Array, - default_value: &ScalarValue, - fail_on_error: bool, - adjust_index: impl Fn(i32, usize) -> DataFusionResult>, -) -> DataFusionResult { - let values = list_array.values(); - 
let offsets = list_array.offsets(); - - let data = values.to_data(); - - let default_data = default_value.to_array()?.to_data(); - - let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len()); - - for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.values()).enumerate() { - let start = offset_window[0].as_usize(); - let len = offset_window[1].as_usize() - start; - - if let Some(i) = adjust_index(*index, len)? { - mutable.extend(0, start + i, start + i + 1); - } else if list_array.is_null(row) { - mutable.extend_nulls(1); - } else if fail_on_error { - return Err(DataFusionError::Execution( - "Index out of bounds for array".to_string(), - )); - } else { - mutable.extend(1, 0, 1); - } - } - - let data = mutable.freeze(); - Ok(ColumnarValue::Array(arrow::array::make_array(data))) -} - -impl Display for ListExtract { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]", - self.child, self.ordinal, self.default_value, self.one_based, self.fail_on_error - ) - } -} - -#[derive(Debug, Eq)] -pub struct GetArrayStructFields { - child: Arc, - ordinal: usize, -} - -impl Hash for GetArrayStructFields { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.ordinal.hash(state); - } -} -impl PartialEq for GetArrayStructFields { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal) - } -} - -impl GetArrayStructFields { - pub fn new(child: Arc, ordinal: usize) -> Self { - Self { child, ordinal } - } - - fn list_field(&self, input_schema: &Schema) -> DataFusionResult { - match self.child.data_type(input_schema)? { - DataType::List(field) | DataType::LargeList(field) => Ok(field), - data_type => Err(DataFusionError::Internal(format!( - "Unexpected data type in GetArrayStructFields: {:?}", - data_type - ))), - } - } - - fn child_field(&self, input_schema: &Schema) -> DataFusionResult { - match self.list_field(input_schema)?.data_type() { - DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])), - data_type => Err(DataFusionError::Internal(format!( - "Unexpected data type in GetArrayStructFields: {:?}", - data_type - ))), - } - } -} - -impl PhysicalExpr for GetArrayStructFields { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> DataFusionResult { - let struct_field = self.child_field(input_schema)?; - match self.child.data_type(input_schema)? 
{ - DataType::List(_) => Ok(DataType::List(struct_field)), - DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)), - data_type => Err(DataFusionError::Internal(format!( - "Unexpected data type in GetArrayStructFields: {:?}", - data_type - ))), - } - } - - fn nullable(&self, input_schema: &Schema) -> DataFusionResult { - Ok(self.list_field(input_schema)?.is_nullable() - || self.child_field(input_schema)?.is_nullable()) - } - - fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { - let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; - - match child_value.data_type() { - DataType::List(_) => { - let list_array = as_list_array(&child_value)?; - - get_array_struct_fields(list_array, self.ordinal) - } - DataType::LargeList(_) => { - let list_array = as_large_list_array(&child_value)?; - - get_array_struct_fields(list_array, self.ordinal) - } - data_type => Err(DataFusionError::Internal(format!( - "Unexpected child type for ListExtract: {:?}", - data_type - ))), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> datafusion_common::Result> { - match children.len() { - 1 => Ok(Arc::new(GetArrayStructFields::new( - Arc::clone(&children[0]), - self.ordinal, - ))), - _ => internal_err!("GetArrayStructFields should have exactly one child"), - } - } -} - -fn get_array_struct_fields( - list_array: &GenericListArray, - ordinal: usize, -) -> DataFusionResult { - let values = list_array - .values() - .as_any() - .downcast_ref::() - .expect("A struct is expected"); - - let column = Arc::clone(values.column(ordinal)); - let field = Arc::clone(&values.fields()[ordinal]); - - let offsets = list_array.offsets(); - let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned()); - - Ok(ColumnarValue::Array(Arc::new(array))) -} - -impl Display for GetArrayStructFields { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "GetArrayStructFields [child: {:?}, ordinal: {:?}]", - self.child, self.ordinal - ) - } -} - #[derive(Debug, Eq)] pub struct ArrayInsert { src_array_expr: Arc, @@ -687,51 +319,13 @@ impl Display for ArrayInsert { #[cfg(test)] mod test { - use crate::list::{array_insert, list_extract, zero_based_index}; - + use super::*; use arrow::datatypes::Int32Type; use arrow_array::{Array, ArrayRef, Int32Array, ListArray}; - use datafusion_common::{Result, ScalarValue}; + use datafusion_common::Result; use datafusion_expr::ColumnarValue; use std::sync::Arc; - #[test] - fn test_list_extract_default_value() -> Result<()> { - let list = ListArray::from_iter_primitive::(vec![ - Some(vec![Some(1)]), - None, - Some(vec![]), - ]); - let indices = Int32Array::from(vec![0, 0, 0]); - - let null_default = ScalarValue::Int32(None); - - let ColumnarValue::Array(result) = - list_extract(&list, &indices, &null_default, false, zero_based_index)? - else { - unreachable!() - }; - - assert_eq!( - &result.to_data(), - &Int32Array::from(vec![Some(1), None, None]).to_data() - ); - - let zero_default = ScalarValue::Int32(Some(0)); - - let ColumnarValue::Array(result) = - list_extract(&list, &indices, &zero_default, false, zero_based_index)? 
- else { - unreachable!() - }; - - assert_eq!( - &result.to_data(), - &Int32Array::from(vec![Some(1), None, Some(0)]).to_data() - ); - Ok(()) - } - #[test] fn test_array_insert() -> Result<()> { // Test inserting an item into a list array diff --git a/native/spark-expr/src/array_funcs/get_array_struct_fields.rs b/native/spark-expr/src/array_funcs/get_array_struct_fields.rs new file mode 100644 index 0000000000..8b1633649c --- /dev/null +++ b/native/spark-expr/src/array_funcs/get_array_struct_fields.rs @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::record_batch::RecordBatch; +use arrow_array::{Array, GenericListArray, OffsetSizeTrait, StructArray}; +use arrow_schema::{DataType, FieldRef, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::{as_large_list_array, as_list_array}, + internal_err, DataFusionError, Result as DataFusionResult, +}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct GetArrayStructFields { + child: Arc, + ordinal: usize, +} + +impl Hash for GetArrayStructFields { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.ordinal.hash(state); + } +} +impl PartialEq for GetArrayStructFields { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal) + } +} + +impl GetArrayStructFields { + pub fn new(child: Arc, ordinal: usize) -> Self { + Self { child, ordinal } + } + + fn list_field(&self, input_schema: &Schema) -> DataFusionResult { + match self.child.data_type(input_schema)? { + DataType::List(field) | DataType::LargeList(field) => Ok(field), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected data type in GetArrayStructFields: {:?}", + data_type + ))), + } + } + + fn child_field(&self, input_schema: &Schema) -> DataFusionResult { + match self.list_field(input_schema)?.data_type() { + DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected data type in GetArrayStructFields: {:?}", + data_type + ))), + } + } +} + +impl PhysicalExpr for GetArrayStructFields { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> DataFusionResult { + let struct_field = self.child_field(input_schema)?; + match self.child.data_type(input_schema)? 
{ + DataType::List(_) => Ok(DataType::List(struct_field)), + DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected data type in GetArrayStructFields: {:?}", + data_type + ))), + } + } + + fn nullable(&self, input_schema: &Schema) -> DataFusionResult { + Ok(self.list_field(input_schema)?.is_nullable() + || self.child_field(input_schema)?.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; + + match child_value.data_type() { + DataType::List(_) => { + let list_array = as_list_array(&child_value)?; + + get_array_struct_fields(list_array, self.ordinal) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&child_value)?; + + get_array_struct_fields(list_array, self.ordinal) + } + data_type => Err(DataFusionError::Internal(format!( + "Unexpected child type for ListExtract: {:?}", + data_type + ))), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + match children.len() { + 1 => Ok(Arc::new(GetArrayStructFields::new( + Arc::clone(&children[0]), + self.ordinal, + ))), + _ => internal_err!("GetArrayStructFields should have exactly one child"), + } + } +} + +fn get_array_struct_fields( + list_array: &GenericListArray, + ordinal: usize, +) -> DataFusionResult { + let values = list_array + .values() + .as_any() + .downcast_ref::() + .expect("A struct is expected"); + + let column = Arc::clone(values.column(ordinal)); + let field = Arc::clone(&values.fields()[ordinal]); + + let offsets = list_array.offsets(); + let array = GenericListArray::new(field, offsets.clone(), column, list_array.nulls().cloned()); + + Ok(ColumnarValue::Array(Arc::new(array))) +} + +impl Display for GetArrayStructFields { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "GetArrayStructFields [child: {:?}, ordinal: {:?}]", + self.child, self.ordinal + ) + } +} diff --git a/native/spark-expr/src/array_funcs/list_extract.rs b/native/spark-expr/src/array_funcs/list_extract.rs new file mode 100644 index 0000000000..c0f2291d9f --- /dev/null +++ b/native/spark-expr/src/array_funcs/list_extract.rs @@ -0,0 +1,310 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::{array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch}; +use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait}; +use arrow_schema::{DataType, FieldRef, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::{as_int32_array, as_large_list_array, as_list_array}, + internal_err, DataFusionError, Result as DataFusionResult, ScalarValue, +}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct ListExtract { + child: Arc, + ordinal: Arc, + default_value: Option>, + one_based: bool, + fail_on_error: bool, +} + +impl Hash for ListExtract { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.ordinal.hash(state); + self.default_value.hash(state); + self.one_based.hash(state); + self.fail_on_error.hash(state); + } +} +impl PartialEq for ListExtract { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + && self.ordinal.eq(&other.ordinal) + && self.default_value.eq(&other.default_value) + && self.one_based.eq(&other.one_based) + && self.fail_on_error.eq(&other.fail_on_error) + } +} + +impl ListExtract { + pub fn new( + child: Arc, + ordinal: Arc, + default_value: Option>, + one_based: bool, + fail_on_error: bool, + ) -> Self { + Self { + child, + ordinal, + default_value, + one_based, + fail_on_error, + } + } + + fn child_field(&self, input_schema: &Schema) -> DataFusionResult { + match self.child.data_type(input_schema)? { + DataType::List(field) | DataType::LargeList(field) => Ok(field), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected data type in ListExtract: {:?}", + data_type + ))), + } + } +} + +impl PhysicalExpr for ListExtract { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> DataFusionResult { + Ok(self.child_field(input_schema)?.data_type().clone()) + } + + fn nullable(&self, input_schema: &Schema) -> DataFusionResult { + // Only non-nullable if fail_on_error is enabled and the element is non-nullable + Ok(!self.fail_on_error || self.child_field(input_schema)?.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; + let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?; + + let default_value = self + .default_value + .as_ref() + .map(|d| { + d.evaluate(batch).map(|value| match value { + ColumnarValue::Scalar(scalar) + if !scalar.data_type().equals_datatype(child_value.data_type()) => + { + scalar.cast_to(child_value.data_type()) + } + ColumnarValue::Scalar(scalar) => Ok(scalar), + v => Err(DataFusionError::Execution(format!( + "Expected scalar default value for ListExtract, got {:?}", + v + ))), + }) + }) + .transpose()? 
+ .unwrap_or(self.data_type(&batch.schema())?.try_into())?; + + let adjust_index = if self.one_based { + one_based_index + } else { + zero_based_index + }; + + match child_value.data_type() { + DataType::List(_) => { + let list_array = as_list_array(&child_value)?; + let index_array = as_int32_array(&ordinal_value)?; + + list_extract( + list_array, + index_array, + &default_value, + self.fail_on_error, + adjust_index, + ) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&child_value)?; + let index_array = as_int32_array(&ordinal_value)?; + + list_extract( + list_array, + index_array, + &default_value, + self.fail_on_error, + adjust_index, + ) + } + data_type => Err(DataFusionError::Internal(format!( + "Unexpected child type for ListExtract: {:?}", + data_type + ))), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child, &self.ordinal] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + match children.len() { + 2 => Ok(Arc::new(ListExtract::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.default_value.clone(), + self.one_based, + self.fail_on_error, + ))), + _ => internal_err!("ListExtract should have exactly two children"), + } + } +} + +fn one_based_index(index: i32, len: usize) -> DataFusionResult> { + if index == 0 { + return Err(DataFusionError::Execution( + "Invalid index of 0 for one-based ListExtract".to_string(), + )); + } + + let abs_index = index.abs().as_usize(); + if abs_index <= len { + if index > 0 { + Ok(Some(abs_index - 1)) + } else { + Ok(Some(len - abs_index)) + } + } else { + Ok(None) + } +} + +fn zero_based_index(index: i32, len: usize) -> DataFusionResult> { + if index < 0 { + Ok(None) + } else { + let positive_index = index.as_usize(); + if positive_index < len { + Ok(Some(positive_index)) + } else { + Ok(None) + } + } +} + +fn list_extract( + list_array: &GenericListArray, + index_array: &Int32Array, + default_value: &ScalarValue, + fail_on_error: bool, + adjust_index: impl Fn(i32, usize) -> DataFusionResult>, +) -> DataFusionResult { + let values = list_array.values(); + let offsets = list_array.offsets(); + + let data = values.to_data(); + + let default_data = default_value.to_array()?.to_data(); + + let mut mutable = MutableArrayData::new(vec![&data, &default_data], true, index_array.len()); + + for (row, (offset_window, index)) in offsets.windows(2).zip(index_array.values()).enumerate() { + let start = offset_window[0].as_usize(); + let len = offset_window[1].as_usize() - start; + + if let Some(i) = adjust_index(*index, len)? 
{ + mutable.extend(0, start + i, start + i + 1); + } else if list_array.is_null(row) { + mutable.extend_nulls(1); + } else if fail_on_error { + return Err(DataFusionError::Execution( + "Index out of bounds for array".to_string(), + )); + } else { + mutable.extend(1, 0, 1); + } + } + + let data = mutable.freeze(); + Ok(ColumnarValue::Array(arrow::array::make_array(data))) +} + +impl Display for ListExtract { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]", + self.child, self.ordinal, self.default_value, self.one_based, self.fail_on_error + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow::datatypes::Int32Type; + use arrow_array::{Array, Int32Array, ListArray}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::ColumnarValue; + + #[test] + fn test_list_extract_default_value() -> Result<()> { + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1)]), + None, + Some(vec![]), + ]); + let indices = Int32Array::from(vec![0, 0, 0]); + + let null_default = ScalarValue::Int32(None); + + let ColumnarValue::Array(result) = + list_extract(&list, &indices, &null_default, false, zero_based_index)? + else { + unreachable!() + }; + + assert_eq!( + &result.to_data(), + &Int32Array::from(vec![Some(1), None, None]).to_data() + ); + + let zero_default = ScalarValue::Int32(Some(0)); + + let ColumnarValue::Array(result) = + list_extract(&list, &indices, &zero_default, false, zero_based_index)? + else { + unreachable!() + }; + + assert_eq!( + &result.to_data(), + &Int32Array::from(vec![Some(1), None, Some(0)]).to_data() + ); + Ok(()) + } +} diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs new file mode 100644 index 0000000000..0a215f96cf --- /dev/null +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod array_insert; +mod get_array_struct_fields; +mod list_extract; + +pub use array_insert::ArrayInsert; +pub use get_array_struct_fields::GetArrayStructFields; +pub use list_extract::ListExtract; diff --git a/native/spark-expr/src/bitwise_not.rs b/native/spark-expr/src/bitwise_funcs/bitwise_not.rs similarity index 100% rename from native/spark-expr/src/bitwise_not.rs rename to native/spark-expr/src/bitwise_funcs/bitwise_not.rs diff --git a/native/spark-expr/src/bitwise_funcs/mod.rs b/native/spark-expr/src/bitwise_funcs/mod.rs new file mode 100644 index 0000000000..9c26363319 --- /dev/null +++ b/native/spark-expr/src/bitwise_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod bitwise_not; + +pub use bitwise_not::{bitwise_not, BitwiseNotExpr}; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 71ff0e9dcc..a6fb13d008 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -15,14 +15,17 @@ // specific language governing permissions and limitations // under the License. -use crate::scalar_funcs::hash_expressions::{ - spark_sha224, spark_sha256, spark_sha384, spark_sha512, +use crate::datetime_funcs::{spark_date_add, spark_date_sub}; +use crate::hash_funcs::{ + spark_murmur3_hash, spark_sha224, spark_sha256, spark_sha384, spark_sha512, spark_xxhash64, }; -use crate::scalar_funcs::{ - spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex, - spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_read_side_padding, spark_round, - spark_unhex, spark_unscaled_value, spark_xxhash64, SparkChrFunc, +use crate::math_funcs::internal::{spark_make_decimal, spark_unscaled_value}; +use crate::math_funcs::{ + spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_round, spark_unhex, }; +use crate::predicate_funcs::spark_isnan; +use crate::static_invoke::spark_read_side_padding; +use crate::string_funcs::SparkChrFunc; use arrow_schema::DataType; use datafusion_common::{DataFusionError, Result as DataFusionResult}; use datafusion_expr::registry::FunctionRegistry; diff --git a/native/spark-expr/src/if_expr.rs b/native/spark-expr/src/conditional_funcs/if_expr.rs similarity index 100% rename from native/spark-expr/src/if_expr.rs rename to native/spark-expr/src/conditional_funcs/if_expr.rs diff --git a/native/spark-expr/src/conditional_funcs/mod.rs b/native/spark-expr/src/conditional_funcs/mod.rs new file mode 100644 index 0000000000..70c459ef7c --- /dev/null +++ b/native/spark-expr/src/conditional_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod if_expr; + +pub use if_expr::IfExpr; diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs similarity index 100% rename from native/spark-expr/src/cast.rs rename to native/spark-expr/src/conversion_funcs/cast.rs diff --git a/native/spark-expr/src/conversion_funcs/mod.rs b/native/spark-expr/src/conversion_funcs/mod.rs new file mode 100644 index 0000000000..4c14434f56 --- /dev/null +++ b/native/spark-expr/src/conversion_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod cast; + +pub use cast::{spark_cast, Cast, SparkCastOptions}; diff --git a/native/spark-expr/src/datetime_funcs/date_arithmetic.rs b/native/spark-expr/src/datetime_funcs/date_arithmetic.rs new file mode 100644 index 0000000000..cc4da9af70 --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/date_arithmetic.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray}; +use arrow::compute::kernels::numeric::{add, sub}; +use arrow::datatypes::IntervalDayTime; +use arrow_array::builder::IntervalDayTimeBuilder; +use arrow_array::types::{Int16Type, Int32Type, Int8Type}; +use arrow_array::{Array, Datum}; +use arrow_schema::{ArrowError, DataType}; +use datafusion::physical_expr_common::datum; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use std::sync::Arc; + +macro_rules! scalar_date_arithmetic { + ($start:expr, $days:expr, $op:expr) => {{ + let interval = IntervalDayTime::new(*$days as i32, 0); + let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval))); + datum::apply($start, &interval_cv, $op) + }}; +} +macro_rules! 
array_date_arithmetic { + ($days:expr, $interval_builder:expr, $intType:ty) => {{ + for day in $days.as_primitive::<$intType>().into_iter() { + if let Some(non_null_day) = day { + $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0)); + } else { + $interval_builder.append_null(); + } + } + }}; +} + +/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second +/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the +/// second argument and use DataFusion's interface to apply Arrow's operators. +fn spark_date_arithmetic( + args: &[ColumnarValue], + op: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + let start = &args[0]; + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Array(days) => { + let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len()); + match days.data_type() { + DataType::Int8 => { + array_date_arithmetic!(days, interval_builder, Int8Type) + } + DataType::Int16 => { + array_date_arithmetic!(days, interval_builder, Int16Type) + } + DataType::Int32 => { + array_date_arithmetic!(days, interval_builder, Int32Type) + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported data types {:?} for date arithmetic.", + args, + ))) + } + } + let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish())); + datum::apply(start, &interval_cv, op) + } + _ => Err(DataFusionError::Internal(format!( + "Unsupported data types {:?} for date arithmetic.", + args, + ))), + } +} + +pub fn spark_date_add(args: &[ColumnarValue]) -> Result { + spark_date_arithmetic(args, add) +} + +pub fn spark_date_sub(args: &[ColumnarValue]) -> Result { + spark_date_arithmetic(args, sub) +} diff --git a/native/spark-expr/src/datetime_funcs/date_trunc.rs b/native/spark-expr/src/datetime_funcs/date_trunc.rs new file mode 100644 index 0000000000..5c044945d0 --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/date_trunc.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue::Utf8}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn}; + +#[derive(Debug, Eq)] +pub struct DateTruncExpr { + /// An array with DataType::Date32 + child: Arc, + /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc + format: Arc, +} + +impl Hash for DateTruncExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.format.hash(state); + } +} +impl PartialEq for DateTruncExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.format.eq(&other.format) + } +} + +impl DateTruncExpr { + pub fn new(child: Arc, format: Arc) -> Self { + DateTruncExpr { child, format } + } +} + +impl Display for DateTruncExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "DateTrunc [child:{}, format: {}]", + self.child, self.format + ) + } +} + +impl PhysicalExpr for DateTruncExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + self.child.data_type(input_schema) + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let date = self.child.evaluate(batch)?; + let format = self.format.evaluate(batch)?; + match (date, format) { + (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => { + let result = date_trunc_dyn(&date, format)?; + Ok(ColumnarValue::Array(result)) + } + (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => { + let result = date_trunc_array_fmt_dyn(&date, &formats)?; + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar) or \ + (PrimitiveArray, StringArray)".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(DateTruncExpr::new( + Arc::clone(&children[0]), + Arc::clone(&self.format), + ))) + } +} diff --git a/native/spark-expr/src/datetime_funcs/hour.rs b/native/spark-expr/src/datetime_funcs/hour.rs new file mode 100644 index 0000000000..faf9529a51 --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/hour.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::utils::array_with_timezone; +use arrow::{ + compute::{date_part, DatePart}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct HourExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl Hash for HourExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for HourExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.timezone.eq(&other.timezone) + } +} + +impl HourExpr { + pub fn new(child: Arc, timezone: String) -> Self { + HourExpr { child, timezone } + } +} + +impl Display for HourExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Hour [timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PhysicalExpr for HourExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Hour)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Hour(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(HourExpr::new( + Arc::clone(&children[0]), + self.timezone.clone(), + ))) + } +} diff --git a/native/spark-expr/src/datetime_funcs/minute.rs b/native/spark-expr/src/datetime_funcs/minute.rs new file mode 100644 index 0000000000..b7facc1673 --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/minute.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::utils::array_with_timezone; +use arrow::{ + compute::{date_part, DatePart}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct MinuteExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl Hash for MinuteExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for MinuteExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.timezone.eq(&other.timezone) + } +} + +impl MinuteExpr { + pub fn new(child: Arc, timezone: String) -> Self { + MinuteExpr { child, timezone } + } +} + +impl Display for MinuteExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Minute [timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PhysicalExpr for MinuteExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Minute)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Minute(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(MinuteExpr::new( + Arc::clone(&children[0]), + self.timezone.clone(), + ))) + } +} diff --git a/native/spark-expr/src/datetime_funcs/mod.rs b/native/spark-expr/src/datetime_funcs/mod.rs new file mode 100644 index 0000000000..1f4d427282 --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/mod.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod date_arithmetic; +mod date_trunc; +mod hour; +mod minute; +mod second; +mod timestamp_trunc; + +pub use date_arithmetic::{spark_date_add, spark_date_sub}; +pub use date_trunc::DateTruncExpr; +pub use hour::HourExpr; +pub use minute::MinuteExpr; +pub use second::SecondExpr; +pub use timestamp_trunc::TimestampTruncExpr; diff --git a/native/spark-expr/src/datetime_funcs/second.rs b/native/spark-expr/src/datetime_funcs/second.rs new file mode 100644 index 0000000000..76a4dd9a2c --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/second.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::array_with_timezone; +use arrow::{ + compute::{date_part, DatePart}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct SecondExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl Hash for SecondExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for SecondExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.timezone.eq(&other.timezone) + } +} + +impl SecondExpr { + pub fn new(child: Arc, timezone: String) -> Self { + SecondExpr { child, timezone } + } +} + +impl Display for SecondExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Second (timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PhysicalExpr for SecondExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Second)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Second(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + 
vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(SecondExpr::new( + Arc::clone(&children[0]), + self.timezone.clone(), + ))) + } +} diff --git a/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs b/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs new file mode 100644 index 0000000000..349992322f --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::array_with_timezone; +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue::Utf8}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +use crate::kernels::temporal::{timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn}; + +#[derive(Debug, Eq)] +pub struct TimestampTruncExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc + format: Arc, + /// String containing a timezone name. The name must be found in the standard timezone + /// database (https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). The string is + /// later parsed into a chrono::TimeZone. + /// Timestamp arrays in this implementation are kept in arrays of UTC timestamps (in micros) + /// along with a single value for the associated TimeZone. The timezone offset is applied + /// just before any operations on the timestamp + timezone: String, +} + +impl Hash for TimestampTruncExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.format.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for TimestampTruncExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + && self.format.eq(&other.format) + && self.timezone.eq(&other.timezone) + } +} + +impl TimestampTruncExpr { + pub fn new( + child: Arc, + format: Arc, + timezone: String, + ) -> Self { + TimestampTruncExpr { + child, + format, + timezone, + } + } +} + +impl Display for TimestampTruncExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "TimestampTrunc [child:{}, format:{}, timezone: {}]", + self.child, self.format, self.timezone + ) + } +} + +impl PhysicalExpr for TimestampTruncExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema)? 
{ + DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary( + key_type, + Box::new(DataType::Timestamp(Microsecond, None)), + )), + _ => Ok(DataType::Timestamp(Microsecond, None)), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let timestamp = self.child.evaluate(batch)?; + let format = self.format.evaluate(batch)?; + let tz = self.timezone.clone(); + match (timestamp, format) { + (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => { + let ts = array_with_timezone( + ts, + tz.clone(), + Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), + )?; + let result = timestamp_trunc_dyn(&ts, format)?; + Ok(ColumnarValue::Array(result)) + } + (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => { + let ts = array_with_timezone( + ts, + tz.clone(), + Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), + )?; + let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?; + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Invalid input to function TimestampTrunc. \ + Expected (PrimitiveArray, Scalar, String) or \ + (PrimitiveArray, StringArray, String)" + .to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(TimestampTruncExpr::new( + Arc::clone(&children[0]), + Arc::clone(&self.format), + self.timezone.clone(), + ))) + } +} diff --git a/native/spark-expr/src/hash_funcs/mod.rs b/native/spark-expr/src/hash_funcs/mod.rs new file mode 100644 index 0000000000..7649c4c547 --- /dev/null +++ b/native/spark-expr/src/hash_funcs/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod murmur3; +mod sha2; +pub(super) mod utils; +mod xxhash64; + +pub use murmur3::spark_murmur3_hash; +pub use sha2::{spark_sha224, spark_sha256, spark_sha384, spark_sha512}; +pub use xxhash64::spark_xxhash64; diff --git a/native/spark-expr/src/hash_funcs/murmur3.rs b/native/spark-expr/src/hash_funcs/murmur3.rs new file mode 100644 index 0000000000..3ed70ba741 --- /dev/null +++ b/native/spark-expr/src/hash_funcs/murmur3.rs @@ -0,0 +1,280 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::create_hashes_internal; +use arrow::compute::take; +use arrow_array::types::ArrowDictionaryKeyType; +use arrow_array::{Array, ArrayRef, ArrowNativeTypeOp, DictionaryArray, Int32Array}; +use arrow_buffer::ArrowNativeType; +use datafusion_common::{internal_err, DataFusionError, ScalarValue}; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +/// Spark compatible murmur3 hash (just `hash` in Spark) in vectorized execution fashion +pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u32; num_rows]; + hashes.fill(*seed as u32); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => Arc::clone(array), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_murmur3_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some( + hashes[0] as i32, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i32).collect(); + Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.", + seed + ) + } + } +} + +/// Spark-compatible murmur3 hash function +#[inline] +pub fn spark_compatible_murmur3_hash>(data: T, seed: u32) -> u32 { + #[inline] + fn mix_k1(mut k1: i32) -> i32 { + k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32); + k1 = k1.rotate_left(15); + k1 = k1.mul_wrapping(0x1b873593u32 as i32); + k1 + } + + #[inline] + fn mix_h1(mut h1: i32, k1: i32) -> i32 { + h1 ^= k1; + h1 = h1.rotate_left(13); + h1 = h1.mul_wrapping(5).add_wrapping(0xe6546b64u32 as i32); + h1 + } + + #[inline] + fn fmix(mut h1: i32, len: i32) -> i32 { + h1 ^= len; + h1 ^= (h1 as u32 >> 16) as i32; + h1 = h1.mul_wrapping(0x85ebca6bu32 as i32); + h1 ^= (h1 as u32 >> 13) as i32; + h1 = h1.mul_wrapping(0xc2b2ae35u32 as i32); + h1 ^= (h1 as u32 >> 16) as i32; + h1 + } + + #[inline] + unsafe fn hash_bytes_by_int(data: &[u8], seed: u32) -> i32 { + // safety: data length must be aligned to 4 bytes + let mut h1 = seed as i32; + for i in (0..data.len()).step_by(4) { + let ints = data.as_ptr().add(i) as *const i32; + let mut half_word = ints.read_unaligned(); + if cfg!(target_endian = "big") { + half_word = half_word.reverse_bits(); + } + h1 = mix_h1(h1, mix_k1(half_word)); + } + h1 + } + let data = data.as_ref(); + let len = data.len(); + let len_aligned = len - len % 4; + + // safety: + // avoid boundary checking in performance critical codes. 
+ // all operations are guaranteed to be safe + // data is &[u8] so we do not need to check for proper alignment + unsafe { + let mut h1 = if len_aligned > 0 { + hash_bytes_by_int(&data[0..len_aligned], seed) + } else { + seed as i32 + }; + + for i in len_aligned..len { + let half_word = *data.get_unchecked(i) as i8 as i32; + h1 = mix_h1(h1, mix_k1(half_word)); + } + fmix(h1, len as i32) as u32 + } +} + +/// Hash the values in a dictionary array +fn create_hashes_dictionary( + array: &ArrayRef, + hashes_buffer: &mut [u32], + first_col: bool, +) -> datafusion_common::Result<()> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + if !first_col { + // unpack the dictionary array as each row may have a different hash input + let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + create_murmur3_hashes(&[unpacked], hashes_buffer)?; + } else { + // For the first column, hash each dictionary value once, and then use + // that computed hash for each key value to avoid a potentially + // expensive redundant hashing for large dictionary elements (e.g. strings) + let dict_values = Arc::clone(dict_array.values()); + // same initial seed as Spark + let mut dict_hashes = vec![42; dict_values.len()]; + create_murmur3_hashes(&[dict_values], &mut dict_hashes)?; + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key.to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, + dict_array.data_type() + )) + })?; + *hash = dict_hashes[idx] + } // no update for Null, consistent with other hashes + } + } + Ok(()) +} + +/// Creates hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. 
+/// `hashes_buffer` should be pre-sized appropriately +pub fn create_murmur3_hashes<'a>( + arrays: &[ArrayRef], + hashes_buffer: &'a mut [u32], +) -> datafusion_common::Result<&'a mut [u32]> { + create_hashes_internal!( + arrays, + hashes_buffer, + spark_compatible_murmur3_hash, + create_hashes_dictionary + ); + Ok(hashes_buffer) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Float32Array, Float64Array}; + use std::sync::Arc; + + use crate::murmur3::create_murmur3_hashes; + use crate::test_hashes_with_nulls; + use datafusion::arrow::array::{ArrayRef, Int32Array, Int64Array, Int8Array, StringArray}; + + fn test_murmur3_hash>> + 'static>( + values: Vec>, + expected: Vec, + ) { + test_hashes_with_nulls!(create_murmur3_hashes, T, values, expected, u32); + } + + #[test] + fn test_i8() { + test_murmur3_hash::( + vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], + vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365], + ); + } + + #[test] + fn test_i32() { + test_murmur3_hash::( + vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], + vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6], + ); + } + + #[test] + fn test_i64() { + test_murmur3_hash::( + vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], + vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb], + ); + } + + #[test] + fn test_f32() { + test_murmur3_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 0xc0361c86, + ], + ); + } + + #[test] + fn test_f64() { + test_murmur3_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 0xa0eef9f9, + ], + ); + } + + #[test] + fn test_str() { + let input = [ + "hello", "bar", "", "😁", "天地", "a", "ab", "abc", "abcd", "abcde", + ] + .iter() + .map(|s| Some(s.to_string())) + .collect::>>(); + let expected: Vec = vec![ + 3286402344, 2486176763, 142593372, 885025535, 2395000894, 1485273170, 0xfa37157b, + 1322437556, 0xe860e5cc, 814637928, + ]; + + test_murmur3_hash::(input.clone(), expected); + } +} diff --git a/native/spark-expr/src/hash_funcs/sha2.rs b/native/spark-expr/src/hash_funcs/sha2.rs new file mode 100644 index 0000000000..09422ea9b5 --- /dev/null +++ b/native/spark-expr/src/hash_funcs/sha2.rs @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
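For reference (illustrative, not part of the patch), the relocated `create_murmur3_hashes` is driven with one pre-seeded slot per row; Spark initializes every hash with seed 42 and the columns are folded in left to right, as the tests above do internally:

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, StringArray};
use datafusion_comet_spark_expr::hash_funcs::murmur3::create_murmur3_hashes;

fn main() -> datafusion_common::Result<()> {
    let ints: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let names: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c"]));
    // One slot per row, pre-filled with Spark's initial seed.
    let mut hashes = vec![42_u32; 3];
    create_murmur3_hashes(&[ints, names], &mut hashes)?;
    println!("{hashes:?}");
    Ok(())
}
```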
+ +use crate::math_funcs::hex::hex_strings; + +use arrow_array::{Array, StringArray}; +use datafusion::functions::crypto::{sha224, sha256, sha384, sha512}; +use datafusion_common::cast::as_binary_array; +use datafusion_common::{exec_err, DataFusionError, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarUDF}; +use std::sync::Arc; + +/// `sha224` function that simulates Spark's `sha2` expression with bit width 224 +pub fn spark_sha224(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha224()) +} + +/// `sha256` function that simulates Spark's `sha2` expression with bit width 0 or 256 +pub fn spark_sha256(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha256()) +} + +/// `sha384` function that simulates Spark's `sha2` expression with bit width 384 +pub fn spark_sha384(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha384()) +} + +/// `sha512` function that simulates Spark's `sha2` expression with bit width 512 +pub fn spark_sha512(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha512()) +} + +// Spark requires hex string as the result of sha2 functions, we have to wrap the +// result of digest functions as hex string +fn wrap_digest_result_as_hex_string( + args: &[ColumnarValue], + digest: Arc, +) -> Result { + let row_count = match &args[0] { + ColumnarValue::Array(array) => array.len(), + ColumnarValue::Scalar(_) => 1, + }; + let value = digest.invoke_batch(args, row_count)?; + match value { + ColumnarValue::Array(array) => { + let binary_array = as_binary_array(&array)?; + let string_array: StringArray = binary_array + .iter() + .map(|opt| opt.map(hex_strings::<_>)) + .collect(); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(opt.map(hex_strings::<_>)), + )), + _ => { + exec_err!( + "digest function should return binary value, but got: {:?}", + value.data_type() + ) + } + } +} diff --git a/native/spark-expr/src/hash_funcs/utils.rs b/native/spark-expr/src/hash_funcs/utils.rs new file mode 100644 index 0000000000..07ba1952d7 --- /dev/null +++ b/native/spark-expr/src/hash_funcs/utils.rs @@ -0,0 +1,393 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This includes utilities for hashing and murmur3 hashing. + +#[macro_export] +macro_rules! 
hash_array { + ($array_type: ident, $column: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $hash_method(&array.value(i), *hash); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $hash_method(&array.value(i), *hash); + } + } + } + }; +} + +#[macro_export] +macro_rules! hash_array_boolean { + ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = + $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash); + } + } + } + }; +} + +#[macro_export] +macro_rules! hash_array_primitive { + ($array_type: ident, $column: ident, $ty: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); + + if array.null_count() == 0 { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } else { + for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { + if !array.is_null(i) { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } + } + }; +} + +#[macro_export] +macro_rules! hash_array_primitive_float { + ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); + + if array.null_count() == 0 { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. + if *value == 0.0 && value.is_sign_negative() { + *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash); + } else { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } + } else { + for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { + if !array.is_null(i) { + // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. + if *value == 0.0 && value.is_sign_negative() { + *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash); + } else { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } + } + } + }; +} + +#[macro_export] +macro_rules! hash_array_decimal { + ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $hash_method(array.value(i).to_le_bytes(), *hash); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $hash_method(array.value(i).to_le_bytes(), *hash); + } + } + } + }; +} + +/// Creates hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately +/// +/// `hash_method` is the hash function to use. 
+/// `create_dictionary_hash_method` is the function to create hashes for dictionary arrays input. +#[macro_export] +macro_rules! create_hashes_internal { + ($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident) => { + use arrow::datatypes::{DataType, TimeUnit}; + use arrow_array::{types::*, *}; + + for (i, col) in $arrays.iter().enumerate() { + let first_col = i == 0; + match col.data_type() { + DataType::Boolean => { + $crate::hash_array_boolean!( + BooleanArray, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int8 => { + $crate::hash_array_primitive!( + Int8Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int16 => { + $crate::hash_array_primitive!( + Int16Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int32 => { + $crate::hash_array_primitive!( + Int32Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int64 => { + $crate::hash_array_primitive!( + Int64Array, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Float32 => { + $crate::hash_array_primitive_float!( + Float32Array, + col, + f32, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Float64 => { + $crate::hash_array_primitive_float!( + Float64Array, + col, + f64, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Second, _) => { + $crate::hash_array_primitive!( + TimestampSecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + $crate::hash_array_primitive!( + TimestampMillisecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + $crate::hash_array_primitive!( + TimestampMicrosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + $crate::hash_array_primitive!( + TimestampNanosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Date32 => { + $crate::hash_array_primitive!( + Date32Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Date64 => { + $crate::hash_array_primitive!( + Date64Array, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Utf8 => { + $crate::hash_array!(StringArray, col, $hashes_buffer, $hash_method); + } + DataType::LargeUtf8 => { + $crate::hash_array!(LargeStringArray, col, $hashes_buffer, $hash_method); + } + DataType::Binary => { + $crate::hash_array!(BinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::LargeBinary => { + $crate::hash_array!(LargeBinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::FixedSizeBinary(_) => { + $crate::hash_array!(FixedSizeBinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::Decimal128(_, _) => { + $crate::hash_array_decimal!(Decimal128Array, col, $hashes_buffer, $hash_method); + } + DataType::Dictionary(index_type, _) => match **index_type { + DataType::Int8 => { + $create_dictionary_hash_method::(col, $hashes_buffer, first_col)?; + } + DataType::Int16 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::Int32 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::Int64 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt8 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt16 => { + 
$create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt32 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt64 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported dictionary type in hasher hashing: {}", + col.data_type(), + ))) + } + }, + _ => { + // This is internal because we should have caught this before. + return Err(DataFusionError::Internal(format!( + "Unsupported data type in hasher: {}", + col.data_type() + ))); + } + } + } + }; +} + +pub(crate) mod test_utils { + + #[macro_export] + macro_rules! test_hashes_internal { + ($hash_method: ident, $input: expr, $initial_seeds: expr, $expected: expr) => { + let i = $input; + let mut hashes = $initial_seeds.clone(); + $hash_method(&[i], &mut hashes).unwrap(); + assert_eq!(hashes, $expected); + }; + } + + #[macro_export] + macro_rules! test_hashes_with_nulls { + ($method: ident, $t: ty, $values: ident, $expected: ident, $seed_type: ty) => { + // copied before inserting nulls + let mut input_with_nulls = $values.clone(); + let mut expected_with_nulls = $expected.clone(); + // test before inserting nulls + let len = $values.len(); + let initial_seeds = vec![42 as $seed_type; len]; + let i = Arc::new(<$t>::from($values)) as ArrayRef; + $crate::test_hashes_internal!($method, i, initial_seeds, $expected); + + // test with nulls + let median = len / 2; + input_with_nulls.insert(0, None); + input_with_nulls.insert(median, None); + expected_with_nulls.insert(0, 42 as $seed_type); + expected_with_nulls.insert(median, 42 as $seed_type); + let len_with_nulls = len + 2; + let initial_seeds_with_nulls = vec![42 as $seed_type; len_with_nulls]; + let nullable_input = Arc::new(<$t>::from(input_with_nulls)) as ArrayRef; + $crate::test_hashes_internal!( + $method, + nullable_input, + initial_seeds_with_nulls, + expected_with_nulls + ); + }; + } +} diff --git a/native/spark-expr/src/hash_funcs/xxhash64.rs b/native/spark-expr/src/hash_funcs/xxhash64.rs new file mode 100644 index 0000000000..e96f178d83 --- /dev/null +++ b/native/spark-expr/src/hash_funcs/xxhash64.rs @@ -0,0 +1,264 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
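Not part of the patch: every `hash_array*` macro above follows the same per-column shape. Hand-expanded for an `Int32Array` column it looks roughly like the sketch below; `fake_hash` is a stand-in for the `$hash_method` argument (e.g. `spark_compatible_murmur3_hash`), not a real Spark hash:

```rust
use std::sync::Arc;

use arrow_array::{Array, ArrayRef, Int32Array};

// Stand-in for `$hash_method`; not a Spark-compatible hash.
fn fake_hash(data: impl AsRef<[u8]>, seed: u32) -> u32 {
    data.as_ref()
        .iter()
        .fold(seed, |h, b| h.rotate_left(5) ^ (*b as u32))
}

// Roughly what `hash_array_primitive!(Int32Array, col, i32, hashes, fake_hash)` expands to.
fn hash_int32_column(col: &ArrayRef, hashes: &mut [u32]) {
    let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
    for (i, hash) in hashes.iter_mut().enumerate() {
        // Null slots keep their previous hash, matching the macros above.
        if !array.is_null(i) {
            *hash = fake_hash(array.value(i).to_le_bytes(), *hash);
        }
    }
}

fn main() {
    let col: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let mut hashes = vec![42_u32; 3];
    hash_int32_column(&col, &mut hashes);
    println!("{hashes:?}");
}
```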
+ +use arrow::compute::take; +use twox_hash::XxHash64; + +use datafusion::{ + arrow::{ + array::*, + datatypes::{ArrowDictionaryKeyType, ArrowNativeType}, + }, + common::{internal_err, ScalarValue}, + error::{DataFusionError, Result}, +}; + +use crate::create_hashes_internal; +use arrow_array::{Array, ArrayRef, Int64Array}; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +/// Spark compatible xxhash64 in vectorized execution fashion +pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u64; num_rows]; + hashes.fill(*seed as u64); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => Arc::clone(array), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_xxhash64_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some( + hashes[0] as i64, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i64).collect(); + Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.", + seed + ) + } + } +} + +#[inline] +fn spark_compatible_xxhash64>(data: T, seed: u64) -> u64 { + XxHash64::oneshot(seed, data.as_ref()) +} + +// Hash the values in a dictionary array using xxhash64 +fn create_xxhash64_hashes_dictionary( + array: &ArrayRef, + hashes_buffer: &mut [u64], + first_col: bool, +) -> Result<()> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + if !first_col { + let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + create_xxhash64_hashes(&[unpacked], hashes_buffer)?; + } else { + // Hash each dictionary value once, and then use that computed + // hash for each key value to avoid a potentially expensive + // redundant hashing for large dictionary elements (e.g. strings) + let dict_values = Arc::clone(dict_array.values()); + // same initial seed as Spark + let mut dict_hashes = vec![42u64; dict_values.len()]; + create_xxhash64_hashes(&[dict_values], &mut dict_hashes)?; + + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key.to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, + dict_array.data_type() + )) + })?; + *hash = dict_hashes[idx] + } // no update for Null, consistent with other hashes + } + } + Ok(()) +} + +/// Creates xxhash64 hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. 
+/// `hashes_buffer` should be pre-sized appropriately +fn create_xxhash64_hashes<'a>( + arrays: &[ArrayRef], + hashes_buffer: &'a mut [u64], +) -> Result<&'a mut [u64]> { + create_hashes_internal!( + arrays, + hashes_buffer, + spark_compatible_xxhash64, + create_xxhash64_hashes_dictionary + ); + Ok(hashes_buffer) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Float32Array, Float64Array}; + use std::sync::Arc; + + use super::create_xxhash64_hashes; + use crate::test_hashes_with_nulls; + use datafusion::arrow::array::{ArrayRef, Int32Array, Int64Array, Int8Array, StringArray}; + + fn test_xxhash64_hash>> + 'static>( + values: Vec>, + expected: Vec, + ) { + test_hashes_with_nulls!(create_xxhash64_hashes, T, values, expected, u64); + } + + #[test] + fn test_i8() { + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], + vec![ + 0xa309b38455455929, + 0x3229fbc4681e48f3, + 0x1bfdda8861c06e45, + 0x77cc15d9f9f2cdc2, + 0x39bc22b9e94d81d0, + ], + ); + } + + #[test] + fn test_i32() { + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], + vec![ + 0xa309b38455455929, + 0x3229fbc4681e48f3, + 0x1bfdda8861c06e45, + 0x14f0ac009c21721c, + 0x1cc7cb8d034769cd, + ], + ); + } + + #[test] + fn test_i64() { + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], + vec![ + 0x9ed50fd59358d232, + 0xb71b47ebda15746c, + 0x358ae035bfb46fd2, + 0xd2f1c616ae7eb306, + 0x88608019c494c1f4, + ], + ); + } + + #[test] + fn test_f32() { + test_xxhash64_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0x9b92689757fcdbd, + 0x3229fbc4681e48f3, + 0x3229fbc4681e48f3, + 0xa2becc0e61bb3823, + 0x8f20ab82d4f3687f, + 0xdce4982d97f7ac4, + ], + ) + } + + #[test] + fn test_f64() { + test_xxhash64_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0xe1fd6e07fee8ad53, + 0xb71b47ebda15746c, + 0xb71b47ebda15746c, + 0x8cdde022746f8f1f, + 0x793c5c88d313eac7, + 0xc5e60e7b75d9b232, + ], + ) + } + + #[test] + fn test_str() { + let input = [ + "hello", "bar", "", "😁", "天地", "a", "ab", "abc", "abcd", "abcde", + ] + .iter() + .map(|s| Some(s.to_string())) + .collect::>>(); + + test_xxhash64_hash::( + input, + vec![ + 0xc3629e6318d53932, + 0xe7097b6a54378d8a, + 0x98b1582b0977e704, + 0xa80d9d5a6a523bd5, + 0xfcba5f61ac666c61, + 0x88e4fe59adf7b0cc, + 0x259dd873209a3fe3, + 0x13c1d910702770e6, + 0xa17b5eb5dc364dff, + 0xf241303e4a90f299, + ], + ) + } +} diff --git a/native/spark-expr/src/json_funcs/mod.rs b/native/spark-expr/src/json_funcs/mod.rs new file mode 100644 index 0000000000..de3037590d --- /dev/null +++ b/native/spark-expr/src/json_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod to_json; + +pub use to_json::ToJson; diff --git a/native/spark-expr/src/to_json.rs b/native/spark-expr/src/json_funcs/to_json.rs similarity index 99% rename from native/spark-expr/src/to_json.rs rename to native/spark-expr/src/json_funcs/to_json.rs index 91b46c6f04..3389ea3a0e 100644 --- a/native/spark-expr/src/to_json.rs +++ b/native/spark-expr/src/json_funcs/to_json.rs @@ -19,7 +19,7 @@ // of the Spark-specific compatibility features that we need (including // being able to specify Spark-compatible cast from all types to string) -use crate::cast::SparkCastOptions; +use crate::SparkCastOptions; use crate::{spark_cast, EvalMode}; use arrow_array::builder::StringBuilder; use arrow_array::{Array, ArrayRef, RecordBatch, StringArray, StructArray}; @@ -250,7 +250,7 @@ fn struct_to_json(array: &StructArray, timezone: &str) -> Result { #[cfg(test)] mod test { - use crate::to_json::struct_to_json; + use crate::json_funcs::to_json::struct_to_json; use arrow_array::types::Int32Type; use arrow_array::{Array, PrimitiveArray, StringArray}; use arrow_array::{ArrayRef, BooleanArray, Int32Array, StructArray}; diff --git a/native/spark-expr/src/kernels/mod.rs b/native/spark-expr/src/kernels/mod.rs index 3669ff13ad..88aa34b1a3 100644 --- a/native/spark-expr/src/kernels/mod.rs +++ b/native/spark-expr/src/kernels/mod.rs @@ -17,5 +17,4 @@ //! Kernels -pub mod strings; pub(crate) mod temporal; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index f358731004..b6353075b4 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -19,62 +19,48 @@ // The lint makes easier for code reader/reviewer separate references clones from more heavyweight ones #![deny(clippy::clone_on_ref_ptr)] -mod cast; mod error; -mod if_expr; -mod avg; -pub use avg::Avg; -mod bitwise_not; -pub use bitwise_not::{bitwise_not, BitwiseNotExpr}; -mod avg_decimal; -pub use avg_decimal::AvgDecimal; -mod checkoverflow; -pub use checkoverflow::CheckOverflow; -mod correlation; -pub use correlation::Correlation; -mod covariance; -pub use covariance::Covariance; -mod strings; -pub use strings::{Contains, EndsWith, Like, StartsWith, StringSpaceExpr, SubstringExpr}; mod kernels; -mod list; -mod regexp; -pub mod scalar_funcs; mod schema_adapter; pub use schema_adapter::SparkSchemaAdapterFactory; -pub mod spark_hash; -mod stddev; -pub use stddev::Stddev; -mod structs; -mod sum_decimal; -pub use sum_decimal::SumDecimal; -mod negative; -pub use negative::{create_negate_expr, NegativeExpr}; -mod normalize_nan; -mod temporal; - pub mod test_common; pub mod timezone; -mod to_json; mod unbound; pub use unbound::UnboundColumn; pub mod utils; -pub use normalize_nan::NormalizeNaNAndZero; -mod variance; -pub use variance::Variance; +mod agg_funcs; +mod array_funcs; +mod bitwise_funcs; mod comet_scalar_funcs; -pub use cast::{spark_cast, Cast, SparkCastOptions}; +mod conditional_funcs; +mod conversion_funcs; +mod datetime_funcs; +pub mod hash_funcs; +mod json_funcs; +mod math_funcs; +mod predicate_funcs; +mod static_invoke; +mod string_funcs; +mod struct_funcs; + +pub use agg_funcs::*; +pub use 
array_funcs::*; +pub use bitwise_funcs::*; pub use comet_scalar_funcs::create_comet_physical_fun; +pub use conditional_funcs::*; +pub use conversion_funcs::*; +pub use datetime_funcs::*; pub use error::{SparkError, SparkResult}; -pub use if_expr::IfExpr; -pub use list::{ArrayInsert, GetArrayStructFields, ListExtract}; -pub use regexp::RLike; -pub use structs::{CreateNamedStruct, GetStructField}; -pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr}; -pub use to_json::ToJson; +pub use hash_funcs::*; +pub use json_funcs::*; +pub use math_funcs::*; +pub use predicate_funcs::*; +pub use static_invoke::*; +pub use string_funcs::*; +pub use struct_funcs::*; /// Spark supports three evaluation modes when evaluating expressions, which affect /// the behavior when processing input values that are invalid or would result in an diff --git a/native/spark-expr/src/math_funcs/ceil.rs b/native/spark-expr/src/math_funcs/ceil.rs new file mode 100644 index 0000000000..9c0fc9b571 --- /dev/null +++ b/native/spark-expr/src/math_funcs/ceil.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
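Relating to the `lib.rs` changes above (illustrative, not part of the patch): because each new group module is re-exported with `pub use ...::*`, existing call sites keep importing from the crate root, while `pub mod hash_funcs` also allows the fully qualified path used elsewhere in this patch:

```rust
#![allow(unused_imports)]
// Both forms resolve after the reorganization:
use datafusion_comet_spark_expr::spark_decimal_div; // flat re-export from math_funcs
use datafusion_comet_spark_expr::hash_funcs::murmur3::create_murmur3_hashes; // explicit module path

fn main() {}
```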
+ +use crate::downcast_compute_op; +use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar}; +use arrow::array::{Float32Array, Float64Array, Int64Array}; +use arrow_array::{Array, ArrowNativeTypeOp}; +use arrow_schema::DataType; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use num::integer::div_ceil; +use std::sync::Arc; + +/// `ceil` function that simulates Spark `ceil` expression +pub fn spark_ceil( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let value = &args[0]; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float32 => { + let result = downcast_compute_op!(array, "ceil", ceil, Float32Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Float64 => { + let result = downcast_compute_op!(array, "ceil", ceil, Float64Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Int64 => { + let result = array.as_any().downcast_ref::().unwrap(); + Ok(ColumnarValue::Array(Arc::new(result.clone()))) + } + DataType::Decimal128(_, scale) if *scale > 0 => { + let f = decimal_ceil_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_array(array, precision, scale, &f) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ceil", + other, + ))), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.ceil() as i64), + ))), + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.ceil() as i64), + ))), + ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))), + ScalarValue::Decimal128(a, _, scale) if *scale > 0 => { + let f = decimal_ceil_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_scalar(a, precision, scale, &f) + } + _ => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ceil", + value.data_type(), + ))), + }, + } +} + +#[inline] +fn decimal_ceil_f(scale: &i8) -> impl Fn(i128) -> i128 { + let div = 10_i128.pow_wrapping(*scale as u32); + move |x: i128| div_ceil(x, div) +} diff --git a/native/spark-expr/src/math_funcs/div.rs b/native/spark-expr/src/math_funcs/div.rs new file mode 100644 index 0000000000..72c23b9e9b --- /dev/null +++ b/native/spark-expr/src/math_funcs/div.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
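A quick worked example (not part of the patch) of the `decimal_ceil_f` closure above: a positive-scale `Decimal128` is ceiled by dividing its unscaled value by `10^scale` with `div_ceil`:

```rust
use num::integer::div_ceil;

fn main() {
    // 12.345 stored as Decimal128 with scale 3 is the unscaled integer 12_345.
    let scale = 3u32;
    let div = 10_i128.pow(scale);
    assert_eq!(div_ceil(12_345_i128, div), 13); // ceil(12.345)  -> 13
    assert_eq!(div_ceil(-12_345_i128, div), -12); // ceil(-12.345) -> -12
}
```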
+ +use crate::math_funcs::utils::get_precision_scale; +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::Decimal128Type, +}; +use arrow_array::{Array, Decimal128Array}; +use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION}; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::DataFusionError; +use num::{BigInt, Signed, ToPrimitive}; +use std::sync::Arc; + +// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3). +// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to +// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since +// both s2 and s3 are 38 at max., s1 is 77 at max. DataFusion division cannot handle such scale > +// Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal division using BigInt. +pub fn spark_decimal_div( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let left = &args[0]; + let right = &args[1]; + let (p3, s3) = get_precision_scale(data_type); + + let (left, right): (ArrayRef, ArrayRef) = match (left, right) { + (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)), + (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => { + (l.to_array_of_size(r.len())?, Arc::clone(r)) + } + (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => { + (Arc::clone(l), r.to_array_of_size(l.len())?) + } + (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?), + }; + let left = left.as_primitive::(); + let right = right.as_primitive::(); + let (p1, s1) = get_precision_scale(left.data_type()); + let (p2, s2) = get_precision_scale(right.data_type()); + + let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32); + let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32); + let result: Decimal128Array = if p1 as u32 + l_exp > DECIMAL128_MAX_PRECISION as u32 + || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32 + { + let ten = BigInt::from(10); + let l_mul = ten.pow(l_exp); + let r_mul = ten.pow(r_exp); + let five = BigInt::from(5); + let zero = BigInt::from(0); + arrow::compute::kernels::arity::binary(left, right, |l, r| { + let l = BigInt::from(l) * &l_mul; + let r = BigInt::from(r) * &r_mul; + let div = if r.eq(&zero) { zero.clone() } else { &l / &r }; + let res = if div.is_negative() { + div - &five + } else { + div + &five + } / &ten; + res.to_i128().unwrap_or(i128::MAX) + })? + } else { + let l_mul = 10_i128.pow(l_exp); + let r_mul = 10_i128.pow(r_exp); + arrow::compute::kernels::arity::binary(left, right, |l, r| { + let l = l * l_mul; + let r = r * r_mul; + let div = if r == 0 { 0 } else { l / r }; + let res = if div.is_negative() { div - 5 } else { div + 5 } / 10; + res.to_i128().unwrap_or(i128::MAX) + })? + }; + let result = result.with_data_type(DataType::Decimal128(p3, s3)); + Ok(ColumnarValue::Array(Arc::new(result))) +} diff --git a/native/spark-expr/src/math_funcs/floor.rs b/native/spark-expr/src/math_funcs/floor.rs new file mode 100644 index 0000000000..9a95d95afe --- /dev/null +++ b/native/spark-expr/src/math_funcs/floor.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::downcast_compute_op; +use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar}; +use arrow::array::{Float32Array, Float64Array, Int64Array}; +use arrow_array::{Array, ArrowNativeTypeOp}; +use arrow_schema::DataType; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use num::integer::div_floor; +use std::sync::Arc; + +/// `floor` function that simulates Spark `floor` expression +pub fn spark_floor( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let value = &args[0]; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float32 => { + let result = downcast_compute_op!(array, "floor", floor, Float32Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Float64 => { + let result = downcast_compute_op!(array, "floor", floor, Float64Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Int64 => { + let result = array.as_any().downcast_ref::().unwrap(); + Ok(ColumnarValue::Array(Arc::new(result.clone()))) + } + DataType::Decimal128(_, scale) if *scale > 0 => { + let f = decimal_floor_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_array(array, precision, scale, &f) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function floor", + other, + ))), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.floor() as i64), + ))), + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.floor() as i64), + ))), + ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))), + ScalarValue::Decimal128(a, _, scale) if *scale > 0 => { + let f = decimal_floor_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_scalar(a, precision, scale, &f) + } + _ => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function floor", + value.data_type(), + ))), + }, + } +} + +#[inline] +fn decimal_floor_f(scale: &i8) -> impl Fn(i128) -> i128 { + let div = 10_i128.pow_wrapping(*scale as u32); + move |x: i128| div_floor(x, div) +} diff --git a/native/spark-expr/src/scalar_funcs/hex.rs b/native/spark-expr/src/math_funcs/hex.rs similarity index 99% rename from native/spark-expr/src/scalar_funcs/hex.rs rename to native/spark-expr/src/math_funcs/hex.rs index e572ba5ef3..4ccd4f4538 100644 --- a/native/spark-expr/src/scalar_funcs/hex.rs +++ b/native/spark-expr/src/math_funcs/hex.rs @@ -52,7 +52,7 @@ fn hex_encode>(data: T, lower_case: bool) -> String { } #[inline(always)] -pub(super) fn hex_strings>(data: T) -> String { +pub(crate) fn hex_strings>(data: T) -> String { hex_encode(data, true) } diff --git a/native/spark-expr/src/checkoverflow.rs b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs similarity index 100% rename from native/spark-expr/src/checkoverflow.rs rename to native/spark-expr/src/math_funcs/internal/checkoverflow.rs diff --git 
a/native/spark-expr/src/math_funcs/internal/make_decimal.rs b/native/spark-expr/src/math_funcs/internal/make_decimal.rs new file mode 100644 index 0000000000..dd761cd69f --- /dev/null +++ b/native/spark-expr/src/math_funcs/internal/make_decimal.rs @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::math_funcs::utils::get_precision_scale; +use arrow::{ + array::{AsArray, Decimal128Builder}, + datatypes::{validate_decimal_precision, Int64Type}, +}; +use arrow_schema::DataType; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; +use std::sync::Arc; + +/// Spark-compatible `MakeDecimal` expression (internal to Spark optimizer) +pub fn spark_make_decimal( + args: &[ColumnarValue], + data_type: &DataType, +) -> DataFusionResult { + let (precision, scale) = get_precision_scale(data_type); + match &args[0] { + ColumnarValue::Scalar(v) => match v { + ScalarValue::Int64(n) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + long_to_decimal(n, precision), + precision, + scale, + ))), + sv => internal_err!("Expected Int64 but found {sv:?}"), + }, + ColumnarValue::Array(a) => { + let arr = a.as_primitive::(); + let mut result = Decimal128Builder::new(); + for v in arr.into_iter() { + result.append_option(long_to_decimal(&v, precision)) + } + let result_type = DataType::Decimal128(precision, scale); + + Ok(ColumnarValue::Array(Arc::new( + result.finish().with_data_type(result_type), + ))) + } + } +} + +/// Convert the input long to decimal with the given maximum precision. If overflows, returns null +/// instead. +#[inline] +fn long_to_decimal(v: &Option, precision: u8) -> Option { + match v { + Some(v) if validate_decimal_precision(*v as i128, precision).is_ok() => Some(*v as i128), + _ => None, + } +} diff --git a/native/spark-expr/src/math_funcs/internal/mod.rs b/native/spark-expr/src/math_funcs/internal/mod.rs new file mode 100644 index 0000000000..29295f0d52 --- /dev/null +++ b/native/spark-expr/src/math_funcs/internal/mod.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +mod checkoverflow; +mod make_decimal; +mod normalize_nan; +mod unscaled_value; + +pub use checkoverflow::CheckOverflow; +pub use make_decimal::spark_make_decimal; +pub use normalize_nan::NormalizeNaNAndZero; +pub use unscaled_value::spark_unscaled_value; diff --git a/native/spark-expr/src/normalize_nan.rs b/native/spark-expr/src/math_funcs/internal/normalize_nan.rs similarity index 100% rename from native/spark-expr/src/normalize_nan.rs rename to native/spark-expr/src/math_funcs/internal/normalize_nan.rs diff --git a/native/spark-expr/src/math_funcs/internal/unscaled_value.rs b/native/spark-expr/src/math_funcs/internal/unscaled_value.rs new file mode 100644 index 0000000000..053f9b078f --- /dev/null +++ b/native/spark-expr/src/math_funcs/internal/unscaled_value.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{AsArray, Int64Builder}, + datatypes::Decimal128Type, +}; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; +use std::sync::Arc; + +/// Spark-compatible `UnscaledValue` expression (internal to Spark optimizer) +pub fn spark_unscaled_value(args: &[ColumnarValue]) -> DataFusionResult { + match &args[0] { + ColumnarValue::Scalar(v) => match v { + ScalarValue::Decimal128(d, _, _) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + d.map(|n| n as i64), + ))), + dt => internal_err!("Expected Decimal128 but found {dt:}"), + }, + ColumnarValue::Array(a) => { + let arr = a.as_primitive::(); + let mut result = Int64Builder::new(); + for v in arr.into_iter() { + result.append_option(v.map(|v| v as i64)); + } + Ok(ColumnarValue::Array(Arc::new(result.finish()))) + } + } +} diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs new file mode 100644 index 0000000000..c559ae15c0 --- /dev/null +++ b/native/spark-expr/src/math_funcs/mod.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +mod ceil; +mod div; +mod floor; +pub(crate) mod hex; +pub mod internal; +mod negative; +mod round; +pub(crate) mod unhex; +mod utils; + +pub use ceil::spark_ceil; +pub use div::spark_decimal_div; +pub use floor::spark_floor; +pub use hex::spark_hex; +pub use internal::*; +pub use negative::{create_negate_expr, NegativeExpr}; +pub use round::spark_round; +pub use unhex::spark_unhex; diff --git a/native/spark-expr/src/negative.rs b/native/spark-expr/src/math_funcs/negative.rs similarity index 99% rename from native/spark-expr/src/negative.rs rename to native/spark-expr/src/math_funcs/negative.rs index 7fb5089179..cafbcfcbdb 100644 --- a/native/spark-expr/src/negative.rs +++ b/native/spark-expr/src/math_funcs/negative.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use super::arithmetic_overflow_error; +use crate::arithmetic_overflow_error; use crate::SparkError; use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType}; use arrow_array::RecordBatch; diff --git a/native/spark-expr/src/math_funcs/round.rs b/native/spark-expr/src/math_funcs/round.rs new file mode 100644 index 0000000000..a47b7bc294 --- /dev/null +++ b/native/spark-expr/src/math_funcs/round.rs @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::math_funcs::utils::{get_precision_scale, make_decimal_array, make_decimal_scalar}; +use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array}; +use arrow_array::{Array, ArrowNativeTypeOp}; +use arrow_schema::DataType; +use datafusion::{functions::math::round::round, physical_plan::ColumnarValue}; +use datafusion_common::{exec_err, internal_err, DataFusionError, ScalarValue}; +use std::{cmp::min, sync::Arc}; + +macro_rules! integer_round { + ($X:expr, $DIV:expr, $HALF:expr) => {{ + let rem = $X % $DIV; + if rem <= -$HALF { + ($X - rem).sub_wrapping($DIV) + } else if rem >= $HALF { + ($X - rem).add_wrapping($DIV) + } else { + $X - rem + } + }}; +} + +macro_rules! round_integer_array { + ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{ + let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap(); + let ten: $NATIVE = 10; + let result: $TYPE = if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) { + let half = div / 2; + arrow::compute::kernels::arity::unary(array, |x| integer_round!(x, div, half)) + } else { + arrow::compute::kernels::arity::unary(array, |_| 0) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +macro_rules! 
round_integer_scalar { + ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{ + let ten: $NATIVE = 10; + if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) { + let half = div / 2; + Ok(ColumnarValue::Scalar($TYPE( + $SCALAR.map(|x| integer_round!(x, div, half)), + ))) + } else { + Ok(ColumnarValue::Scalar($TYPE(Some(0)))) + } + }}; +} + +/// `round` function that simulates Spark `round` expression +pub fn spark_round( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let value = &args[0]; + let point = &args[1]; + let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else { + return internal_err!("Invalid point argument for Round(): {:#?}", point); + }; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 if *point < 0 => round_integer_array!(array, point, Int64Array, i64), + DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32), + DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16), + DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8), + DataType::Decimal128(_, scale) if *scale >= 0 => { + let f = decimal_round_f(scale, point); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_array(array, precision, scale, &f) + } + DataType::Float32 | DataType::Float64 => { + Ok(ColumnarValue::Array(round(&[Arc::clone(array)])?)) + } + dt => exec_err!("Not supported datatype for ROUND: {dt}"), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Int64(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int64, i64) + } + ScalarValue::Int32(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int32, i32) + } + ScalarValue::Int16(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int16, i16) + } + ScalarValue::Int8(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int8, i8) + } + ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => { + let f = decimal_round_f(scale, point); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_scalar(a, precision, scale, &f) + } + ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar( + ScalarValue::try_from_array(&round(&[a.to_array()?])?, 0)?, + )), + dt => exec_err!("Not supported datatype for ROUND: {dt}"), + }, + } +} + +// Spark uses BigDecimal. See RoundBase implementation in Spark. 
Instead, we do the same by
+// 1) adding half of the divisor, 2) rounding down by division, 3) adjusting precision by multiplication
+#[inline]
+fn decimal_round_f(scale: &i8, point: &i64) -> Box<dyn Fn(i128) -> i128> {
+    if *point < 0 {
+        if let Some(div) = 10_i128.checked_pow((-(*point) as u32) + (*scale as u32)) {
+            let half = div / 2;
+            let mul = 10_i128.pow_wrapping((-(*point)) as u32);
+            // i128 can hold 39 digits of a base 10 number, adding half will not cause overflow
+            Box::new(move |x: i128| (x + x.signum() * half) / div * mul)
+        } else {
+            Box::new(move |_: i128| 0)
+        }
+    } else {
+        let div = 10_i128.pow_wrapping((*scale as u32) - min(*scale as u32, *point as u32));
+        let half = div / 2;
+        Box::new(move |x: i128| (x + x.signum() * half) / div)
+    }
+}
diff --git a/native/spark-expr/src/scalar_funcs/unhex.rs b/native/spark-expr/src/math_funcs/unhex.rs
similarity index 100%
rename from native/spark-expr/src/scalar_funcs/unhex.rs
rename to native/spark-expr/src/math_funcs/unhex.rs
diff --git a/native/spark-expr/src/math_funcs/utils.rs b/native/spark-expr/src/math_funcs/utils.rs
new file mode 100644
index 0000000000..204b7139e4
--- /dev/null
+++ b/native/spark-expr/src/math_funcs/utils.rs
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_array::cast::AsArray;
+use arrow_array::types::Decimal128Type;
+use arrow_array::{ArrayRef, Decimal128Array};
+use arrow_schema::DataType;
+use datafusion_common::{DataFusionError, ScalarValue};
+use datafusion_expr_common::columnar_value::ColumnarValue;
+use std::sync::Arc;
+
+#[macro_export]
+macro_rules! downcast_compute_op {
+    ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{
+        let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
+        match n {
+            Some(array) => {
+                let res: $RESULT =
+                    arrow::compute::kernels::arity::unary(array, |x| x.$FUNC() as i64);
+                Ok(Arc::new(res))
+            }
+            _ => Err(DataFusionError::Internal(format!(
+                "Invalid data type for {}",
+                $NAME
+            ))),
+        }
+    }};
+}
+
+#[inline]
+pub(crate) fn make_decimal_scalar(
+    a: &Option<i128>,
+    precision: u8,
+    scale: i8,
+    f: &dyn Fn(i128) -> i128,
+) -> Result<ColumnarValue, DataFusionError> {
+    let result = ScalarValue::Decimal128(a.map(f), precision, scale);
+    Ok(ColumnarValue::Scalar(result))
+}
+
+#[inline]
+pub(crate) fn make_decimal_array(
+    array: &ArrayRef,
+    precision: u8,
+    scale: i8,
+    f: &dyn Fn(i128) -> i128,
+) -> Result<ColumnarValue, DataFusionError> {
+    let array = array.as_primitive::<Decimal128Type>();
+    let result: Decimal128Array = arrow::compute::kernels::arity::unary(array, f);
+    let result = result.with_data_type(DataType::Decimal128(precision, scale));
+    Ok(ColumnarValue::Array(Arc::new(result)))
+}
+
+#[inline]
+pub(crate) fn get_precision_scale(data_type: &DataType) -> (u8, i8) {
+    let DataType::Decimal128(precision, scale) = data_type else {
+        unreachable!()
+    };
+    (*precision, *scale)
+}
diff --git a/native/spark-expr/src/predicate_funcs/is_nan.rs b/native/spark-expr/src/predicate_funcs/is_nan.rs
new file mode 100644
index 0000000000..bf4d7e0f26
--- /dev/null
+++ b/native/spark-expr/src/predicate_funcs/is_nan.rs
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
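+
+// Note on null semantics: Spark's `IsNaN` predicate is never null, so a NULL input
+// evaluates to false rather than NULL. For example,
+// `SELECT isnan(CAST('NaN' AS DOUBLE)), isnan(CAST(NULL AS DOUBLE))` yields `true, false`.
+// This is why `set_nulls_to_false` below folds the validity mask into the boolean result
+// instead of propagating nulls.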
+
+use arrow::array::{Float32Array, Float64Array};
+use arrow_array::{Array, BooleanArray};
+use arrow_schema::DataType;
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{DataFusionError, ScalarValue};
+use std::sync::Arc;
+
+/// Spark-compatible `isnan` expression
+pub fn spark_isnan(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue {
+        match is_nan.nulls() {
+            Some(nulls) => {
+                let is_not_null = nulls.inner();
+                ColumnarValue::Array(Arc::new(BooleanArray::new(
+                    is_nan.values() & is_not_null,
+                    None,
+                )))
+            }
+            None => ColumnarValue::Array(Arc::new(is_nan)),
+        }
+    }
+    let value = &args[0];
+    match value {
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Float64 => {
+                let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
+                let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
+                Ok(set_nulls_to_false(is_nan))
+            }
+            DataType::Float32 => {
+                let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
+                let is_nan = BooleanArray::from_unary(array, |x| x.is_nan());
+                Ok(set_nulls_to_false(is_nan))
+            }
+            other => Err(DataFusionError::Internal(format!(
+                "Unsupported data type {:?} for function isnan",
+                other,
+            ))),
+        },
+        ColumnarValue::Scalar(a) => match a {
+            ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
+                a.map(|x| x.is_nan()).unwrap_or(false),
+            )))),
+            ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(
+                a.map(|x| x.is_nan()).unwrap_or(false),
+            )))),
+            _ => Err(DataFusionError::Internal(format!(
+                "Unsupported data type {:?} for function isnan",
+                value.data_type(),
+            ))),
+        },
+    }
+}
diff --git a/native/spark-expr/src/predicate_funcs/mod.rs b/native/spark-expr/src/predicate_funcs/mod.rs
new file mode 100644
index 0000000000..5f1f570c05
--- /dev/null
+++ b/native/spark-expr/src/predicate_funcs/mod.rs
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod is_nan;
+mod rlike;
+
+pub use is_nan::spark_isnan;
+pub use rlike::RLike;
diff --git a/native/spark-expr/src/regexp.rs b/native/spark-expr/src/predicate_funcs/rlike.rs
similarity index 100%
rename from native/spark-expr/src/regexp.rs
rename to native/spark-expr/src/predicate_funcs/rlike.rs
diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs
deleted file mode 100644
index 2961f038dc..0000000000
--- a/native/spark-expr/src/scalar_funcs.rs
+++ /dev/null
@@ -1,626 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::compute::kernels::numeric::{add, sub}; -use arrow::datatypes::IntervalDayTime; -use arrow::{ - array::{ - ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int64Builder, Int8Array, OffsetSizeTrait, - }, - datatypes::{validate_decimal_precision, Decimal128Type, Int64Type}, -}; -use arrow_array::builder::{GenericStringBuilder, IntervalDayTimeBuilder}; -use arrow_array::types::{Int16Type, Int32Type, Int8Type}; -use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array}; -use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION}; -use datafusion::physical_expr_common::datum; -use datafusion::{functions::math::round::round, physical_plan::ColumnarValue}; -use datafusion_common::{ - cast::as_generic_string_array, exec_err, internal_err, DataFusionError, - Result as DataFusionResult, ScalarValue, -}; -use num::{ - integer::{div_ceil, div_floor}, - BigInt, Signed, ToPrimitive, -}; -use std::fmt::Write; -use std::{cmp::min, sync::Arc}; - -mod unhex; -pub use unhex::spark_unhex; - -mod hex; -pub use hex::spark_hex; - -mod chr; -pub use chr::SparkChrFunc; - -pub mod hash_expressions; -// exposed for benchmark only -pub use hash_expressions::{spark_murmur3_hash, spark_xxhash64}; - -#[inline] -fn get_precision_scale(data_type: &DataType) -> (u8, i8) { - let DataType::Decimal128(precision, scale) = data_type else { - unreachable!() - }; - (*precision, *scale) -} - -macro_rules! 
downcast_compute_op { - ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{ - let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); - match n { - Some(array) => { - let res: $RESULT = - arrow::compute::kernels::arity::unary(array, |x| x.$FUNC() as i64); - Ok(Arc::new(res)) - } - _ => Err(DataFusionError::Internal(format!( - "Invalid data type for {}", - $NAME - ))), - } - }}; -} - -/// `ceil` function that simulates Spark `ceil` expression -pub fn spark_ceil( - args: &[ColumnarValue], - data_type: &DataType, -) -> Result { - let value = &args[0]; - match value { - ColumnarValue::Array(array) => match array.data_type() { - DataType::Float32 => { - let result = downcast_compute_op!(array, "ceil", ceil, Float32Array, Int64Array); - Ok(ColumnarValue::Array(result?)) - } - DataType::Float64 => { - let result = downcast_compute_op!(array, "ceil", ceil, Float64Array, Int64Array); - Ok(ColumnarValue::Array(result?)) - } - DataType::Int64 => { - let result = array.as_any().downcast_ref::().unwrap(); - Ok(ColumnarValue::Array(Arc::new(result.clone()))) - } - DataType::Decimal128(_, scale) if *scale > 0 => { - let f = decimal_ceil_f(scale); - let (precision, scale) = get_precision_scale(data_type); - make_decimal_array(array, precision, scale, &f) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function ceil", - other, - ))), - }, - ColumnarValue::Scalar(a) => match a { - ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( - a.map(|x| x.ceil() as i64), - ))), - ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( - a.map(|x| x.ceil() as i64), - ))), - ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))), - ScalarValue::Decimal128(a, _, scale) if *scale > 0 => { - let f = decimal_ceil_f(scale); - let (precision, scale) = get_precision_scale(data_type); - make_decimal_scalar(a, precision, scale, &f) - } - _ => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function ceil", - value.data_type(), - ))), - }, - } -} - -/// `floor` function that simulates Spark `floor` expression -pub fn spark_floor( - args: &[ColumnarValue], - data_type: &DataType, -) -> Result { - let value = &args[0]; - match value { - ColumnarValue::Array(array) => match array.data_type() { - DataType::Float32 => { - let result = downcast_compute_op!(array, "floor", floor, Float32Array, Int64Array); - Ok(ColumnarValue::Array(result?)) - } - DataType::Float64 => { - let result = downcast_compute_op!(array, "floor", floor, Float64Array, Int64Array); - Ok(ColumnarValue::Array(result?)) - } - DataType::Int64 => { - let result = array.as_any().downcast_ref::().unwrap(); - Ok(ColumnarValue::Array(Arc::new(result.clone()))) - } - DataType::Decimal128(_, scale) if *scale > 0 => { - let f = decimal_floor_f(scale); - let (precision, scale) = get_precision_scale(data_type); - make_decimal_array(array, precision, scale, &f) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function floor", - other, - ))), - }, - ColumnarValue::Scalar(a) => match a { - ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( - a.map(|x| x.floor() as i64), - ))), - ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( - a.map(|x| x.floor() as i64), - ))), - ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))), - ScalarValue::Decimal128(a, _, scale) if *scale > 0 => { - let f = decimal_floor_f(scale); - let 
(precision, scale) = get_precision_scale(data_type); - make_decimal_scalar(a, precision, scale, &f) - } - _ => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function floor", - value.data_type(), - ))), - }, - } -} - -/// Spark-compatible `UnscaledValue` expression (internal to Spark optimizer) -pub fn spark_unscaled_value(args: &[ColumnarValue]) -> DataFusionResult { - match &args[0] { - ColumnarValue::Scalar(v) => match v { - ScalarValue::Decimal128(d, _, _) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( - d.map(|n| n as i64), - ))), - dt => internal_err!("Expected Decimal128 but found {dt:}"), - }, - ColumnarValue::Array(a) => { - let arr = a.as_primitive::(); - let mut result = Int64Builder::new(); - for v in arr.into_iter() { - result.append_option(v.map(|v| v as i64)); - } - Ok(ColumnarValue::Array(Arc::new(result.finish()))) - } - } -} - -/// Spark-compatible `MakeDecimal` expression (internal to Spark optimizer) -pub fn spark_make_decimal( - args: &[ColumnarValue], - data_type: &DataType, -) -> DataFusionResult { - let (precision, scale) = get_precision_scale(data_type); - match &args[0] { - ColumnarValue::Scalar(v) => match v { - ScalarValue::Int64(n) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( - long_to_decimal(n, precision), - precision, - scale, - ))), - sv => internal_err!("Expected Int64 but found {sv:?}"), - }, - ColumnarValue::Array(a) => { - let arr = a.as_primitive::(); - let mut result = Decimal128Builder::new(); - for v in arr.into_iter() { - result.append_option(long_to_decimal(&v, precision)) - } - let result_type = DataType::Decimal128(precision, scale); - - Ok(ColumnarValue::Array(Arc::new( - result.finish().with_data_type(result_type), - ))) - } - } -} - -/// Convert the input long to decimal with the given maximum precision. If overflows, returns null -/// instead. -#[inline] -fn long_to_decimal(v: &Option, precision: u8) -> Option { - match v { - Some(v) if validate_decimal_precision(*v as i128, precision).is_ok() => Some(*v as i128), - _ => None, - } -} - -#[inline] -fn decimal_ceil_f(scale: &i8) -> impl Fn(i128) -> i128 { - let div = 10_i128.pow_wrapping(*scale as u32); - move |x: i128| div_ceil(x, div) -} - -#[inline] -fn decimal_floor_f(scale: &i8) -> impl Fn(i128) -> i128 { - let div = 10_i128.pow_wrapping(*scale as u32); - move |x: i128| div_floor(x, div) -} - -// Spark uses BigDecimal. See RoundBase implementation in Spark. 
Instead, we do the same by -// 1) add the half of divisor, 2) round down by division, 3) adjust precision by multiplication -#[inline] -fn decimal_round_f(scale: &i8, point: &i64) -> Box i128> { - if *point < 0 { - if let Some(div) = 10_i128.checked_pow((-(*point) as u32) + (*scale as u32)) { - let half = div / 2; - let mul = 10_i128.pow_wrapping((-(*point)) as u32); - // i128 can hold 39 digits of a base 10 number, adding half will not cause overflow - Box::new(move |x: i128| (x + x.signum() * half) / div * mul) - } else { - Box::new(move |_: i128| 0) - } - } else { - let div = 10_i128.pow_wrapping((*scale as u32) - min(*scale as u32, *point as u32)); - let half = div / 2; - Box::new(move |x: i128| (x + x.signum() * half) / div) - } -} - -#[inline] -fn make_decimal_array( - array: &ArrayRef, - precision: u8, - scale: i8, - f: &dyn Fn(i128) -> i128, -) -> Result { - let array = array.as_primitive::(); - let result: Decimal128Array = arrow::compute::kernels::arity::unary(array, f); - let result = result.with_data_type(DataType::Decimal128(precision, scale)); - Ok(ColumnarValue::Array(Arc::new(result))) -} - -#[inline] -fn make_decimal_scalar( - a: &Option, - precision: u8, - scale: i8, - f: &dyn Fn(i128) -> i128, -) -> Result { - let result = ScalarValue::Decimal128(a.map(f), precision, scale); - Ok(ColumnarValue::Scalar(result)) -} - -macro_rules! integer_round { - ($X:expr, $DIV:expr, $HALF:expr) => {{ - let rem = $X % $DIV; - if rem <= -$HALF { - ($X - rem).sub_wrapping($DIV) - } else if rem >= $HALF { - ($X - rem).add_wrapping($DIV) - } else { - $X - rem - } - }}; -} - -macro_rules! round_integer_array { - ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{ - let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap(); - let ten: $NATIVE = 10; - let result: $TYPE = if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) { - let half = div / 2; - arrow::compute::kernels::arity::unary(array, |x| integer_round!(x, div, half)) - } else { - arrow::compute::kernels::arity::unary(array, |_| 0) - }; - Ok(ColumnarValue::Array(Arc::new(result))) - }}; -} - -macro_rules! 
round_integer_scalar { - ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{ - let ten: $NATIVE = 10; - if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) { - let half = div / 2; - Ok(ColumnarValue::Scalar($TYPE( - $SCALAR.map(|x| integer_round!(x, div, half)), - ))) - } else { - Ok(ColumnarValue::Scalar($TYPE(Some(0)))) - } - }}; -} - -/// `round` function that simulates Spark `round` expression -pub fn spark_round( - args: &[ColumnarValue], - data_type: &DataType, -) -> Result { - let value = &args[0]; - let point = &args[1]; - let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else { - return internal_err!("Invalid point argument for Round(): {:#?}", point); - }; - match value { - ColumnarValue::Array(array) => match array.data_type() { - DataType::Int64 if *point < 0 => round_integer_array!(array, point, Int64Array, i64), - DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32), - DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16), - DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8), - DataType::Decimal128(_, scale) if *scale >= 0 => { - let f = decimal_round_f(scale, point); - let (precision, scale) = get_precision_scale(data_type); - make_decimal_array(array, precision, scale, &f) - } - DataType::Float32 | DataType::Float64 => { - Ok(ColumnarValue::Array(round(&[Arc::clone(array)])?)) - } - dt => exec_err!("Not supported datatype for ROUND: {dt}"), - }, - ColumnarValue::Scalar(a) => match a { - ScalarValue::Int64(a) if *point < 0 => { - round_integer_scalar!(a, point, ScalarValue::Int64, i64) - } - ScalarValue::Int32(a) if *point < 0 => { - round_integer_scalar!(a, point, ScalarValue::Int32, i32) - } - ScalarValue::Int16(a) if *point < 0 => { - round_integer_scalar!(a, point, ScalarValue::Int16, i16) - } - ScalarValue::Int8(a) if *point < 0 => { - round_integer_scalar!(a, point, ScalarValue::Int8, i8) - } - ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => { - let f = decimal_round_f(scale, point); - let (precision, scale) = get_precision_scale(data_type); - make_decimal_scalar(a, precision, scale, &f) - } - ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar( - ScalarValue::try_from_array(&round(&[a.to_array()?])?, 0)?, - )), - dt => exec_err!("Not supported datatype for ROUND: {dt}"), - }, - } -} - -/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length -pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result { - match args { - [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => { - match array.data_type() { - DataType::Utf8 => spark_read_side_padding_internal::(array, *length), - DataType::LargeUtf8 => spark_read_side_padding_internal::(array, *length), - // TODO: handle Dictionary types - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {other:?} for function read_side_padding", - ))), - } - } - other => Err(DataFusionError::Internal(format!( - "Unsupported arguments {other:?} for function read_side_padding", - ))), - } -} - -fn spark_read_side_padding_internal( - array: &ArrayRef, - length: i32, -) -> Result { - let string_array = as_generic_string_array::(array)?; - let length = 0.max(length) as usize; - let space_string = " ".repeat(length); - - let mut builder = - GenericStringBuilder::::with_capacity(string_array.len(), string_array.len() * length); - - for string in string_array.iter() { - match string { - 
Some(string) => { - // It looks Spark's UTF8String is closer to chars rather than graphemes - // https://stackoverflow.com/a/46290728 - let char_len = string.chars().count(); - if length <= char_len { - builder.append_value(string); - } else { - // write_str updates only the value buffer, not null nor offset buffer - // This is convenient for concatenating str(s) - builder.write_str(string)?; - builder.append_value(&space_string[char_len..]); - } - } - _ => builder.append_null(), - } - } - Ok(ColumnarValue::Array(Arc::new(builder.finish()))) -} - -// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3). -// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to -// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since -// both s2 and s3 are 38 at max., s1 is 77 at max. DataFusion division cannot handle such scale > -// Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal division using BigInt. -pub fn spark_decimal_div( - args: &[ColumnarValue], - data_type: &DataType, -) -> Result { - let left = &args[0]; - let right = &args[1]; - let (p3, s3) = get_precision_scale(data_type); - - let (left, right): (ArrayRef, ArrayRef) = match (left, right) { - (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (Arc::clone(l), Arc::clone(r)), - (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => { - (l.to_array_of_size(r.len())?, Arc::clone(r)) - } - (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => { - (Arc::clone(l), r.to_array_of_size(l.len())?) - } - (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?), - }; - let left = left.as_primitive::(); - let right = right.as_primitive::(); - let (p1, s1) = get_precision_scale(left.data_type()); - let (p2, s2) = get_precision_scale(right.data_type()); - - let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32); - let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32); - let result: Decimal128Array = if p1 as u32 + l_exp > DECIMAL128_MAX_PRECISION as u32 - || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32 - { - let ten = BigInt::from(10); - let l_mul = ten.pow(l_exp); - let r_mul = ten.pow(r_exp); - let five = BigInt::from(5); - let zero = BigInt::from(0); - arrow::compute::kernels::arity::binary(left, right, |l, r| { - let l = BigInt::from(l) * &l_mul; - let r = BigInt::from(r) * &r_mul; - let div = if r.eq(&zero) { zero.clone() } else { &l / &r }; - let res = if div.is_negative() { - div - &five - } else { - div + &five - } / &ten; - res.to_i128().unwrap_or(i128::MAX) - })? - } else { - let l_mul = 10_i128.pow(l_exp); - let r_mul = 10_i128.pow(r_exp); - arrow::compute::kernels::arity::binary(left, right, |l, r| { - let l = l * l_mul; - let r = r * r_mul; - let div = if r == 0 { 0 } else { l / r }; - let res = if div.is_negative() { div - 5 } else { div + 5 } / 10; - res.to_i128().unwrap_or(i128::MAX) - })? 
- }; - let result = result.with_data_type(DataType::Decimal128(p3, s3)); - Ok(ColumnarValue::Array(Arc::new(result))) -} - -/// Spark-compatible `isnan` expression -pub fn spark_isnan(args: &[ColumnarValue]) -> Result { - fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue { - match is_nan.nulls() { - Some(nulls) => { - let is_not_null = nulls.inner(); - ColumnarValue::Array(Arc::new(BooleanArray::new( - is_nan.values() & is_not_null, - None, - ))) - } - None => ColumnarValue::Array(Arc::new(is_nan)), - } - } - let value = &args[0]; - match value { - ColumnarValue::Array(array) => match array.data_type() { - DataType::Float64 => { - let array = array.as_any().downcast_ref::().unwrap(); - let is_nan = BooleanArray::from_unary(array, |x| x.is_nan()); - Ok(set_nulls_to_false(is_nan)) - } - DataType::Float32 => { - let array = array.as_any().downcast_ref::().unwrap(); - let is_nan = BooleanArray::from_unary(array, |x| x.is_nan()); - Ok(set_nulls_to_false(is_nan)) - } - other => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function isnan", - other, - ))), - }, - ColumnarValue::Scalar(a) => match a { - ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( - a.map(|x| x.is_nan()).unwrap_or(false), - )))), - ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( - a.map(|x| x.is_nan()).unwrap_or(false), - )))), - _ => Err(DataFusionError::Internal(format!( - "Unsupported data type {:?} for function isnan", - value.data_type(), - ))), - }, - } -} - -macro_rules! scalar_date_arithmetic { - ($start:expr, $days:expr, $op:expr) => {{ - let interval = IntervalDayTime::new(*$days as i32, 0); - let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval))); - datum::apply($start, &interval_cv, $op) - }}; -} -macro_rules! array_date_arithmetic { - ($days:expr, $interval_builder:expr, $intType:ty) => {{ - for day in $days.as_primitive::<$intType>().into_iter() { - if let Some(non_null_day) = day { - $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0)); - } else { - $interval_builder.append_null(); - } - } - }}; -} - -/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second -/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the -/// second argument and use DataFusion's interface to apply Arrow's operators. 
-fn spark_date_arithmetic( - args: &[ColumnarValue], - op: impl Fn(&dyn Datum, &dyn Datum) -> Result, -) -> Result { - let start = &args[0]; - match &args[1] { - ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => { - scalar_date_arithmetic!(start, days, op) - } - ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => { - scalar_date_arithmetic!(start, days, op) - } - ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => { - scalar_date_arithmetic!(start, days, op) - } - ColumnarValue::Array(days) => { - let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len()); - match days.data_type() { - DataType::Int8 => { - array_date_arithmetic!(days, interval_builder, Int8Type) - } - DataType::Int16 => { - array_date_arithmetic!(days, interval_builder, Int16Type) - } - DataType::Int32 => { - array_date_arithmetic!(days, interval_builder, Int32Type) - } - _ => { - return Err(DataFusionError::Internal(format!( - "Unsupported data types {:?} for date arithmetic.", - args, - ))) - } - } - let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish())); - datum::apply(start, &interval_cv, op) - } - _ => Err(DataFusionError::Internal(format!( - "Unsupported data types {:?} for date arithmetic.", - args, - ))), - } -} -pub fn spark_date_add(args: &[ColumnarValue]) -> Result { - spark_date_arithmetic(args, add) -} - -pub fn spark_date_sub(args: &[ColumnarValue]) -> Result { - spark_date_arithmetic(args, sub) -} diff --git a/native/spark-expr/src/scalar_funcs/hash_expressions.rs b/native/spark-expr/src/scalar_funcs/hash_expressions.rs deleted file mode 100644 index af423677b7..0000000000 --- a/native/spark-expr/src/scalar_funcs/hash_expressions.rs +++ /dev/null @@ -1,166 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use crate::scalar_funcs::hex::hex_strings; -use crate::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes}; - -use arrow_array::{Array, ArrayRef, Int32Array, Int64Array, StringArray}; -use datafusion::functions::crypto::{sha224, sha256, sha384, sha512}; -use datafusion_common::cast::as_binary_array; -use datafusion_common::{exec_err, internal_err, DataFusionError, ScalarValue}; -use datafusion_expr::{ColumnarValue, ScalarUDF}; -use std::sync::Arc; - -/// Spark compatible murmur3 hash (just `hash` in Spark) in vectorized execution fashion -pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result { - let length = args.len(); - let seed = &args[length - 1]; - match seed { - ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => { - // iterate over the arguments to find out the length of the array - let num_rows = args[0..args.len() - 1] - .iter() - .find_map(|arg| match arg { - ColumnarValue::Array(array) => Some(array.len()), - ColumnarValue::Scalar(_) => None, - }) - .unwrap_or(1); - let mut hashes: Vec = vec![0_u32; num_rows]; - hashes.fill(*seed as u32); - let arrays = args[0..args.len() - 1] - .iter() - .map(|arg| match arg { - ColumnarValue::Array(array) => Arc::clone(array), - ColumnarValue::Scalar(scalar) => { - scalar.clone().to_array_of_size(num_rows).unwrap() - } - }) - .collect::>(); - create_murmur3_hashes(&arrays, &mut hashes)?; - if num_rows == 1 { - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some( - hashes[0] as i32, - )))) - } else { - let hashes: Vec = hashes.into_iter().map(|x| x as i32).collect(); - Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes)))) - } - } - _ => { - internal_err!( - "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.", - seed - ) - } - } -} - -/// Spark compatible xxhash64 in vectorized execution fashion -pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result { - let length = args.len(); - let seed = &args[length - 1]; - match seed { - ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => { - // iterate over the arguments to find out the length of the array - let num_rows = args[0..args.len() - 1] - .iter() - .find_map(|arg| match arg { - ColumnarValue::Array(array) => Some(array.len()), - ColumnarValue::Scalar(_) => None, - }) - .unwrap_or(1); - let mut hashes: Vec = vec![0_u64; num_rows]; - hashes.fill(*seed as u64); - let arrays = args[0..args.len() - 1] - .iter() - .map(|arg| match arg { - ColumnarValue::Array(array) => Arc::clone(array), - ColumnarValue::Scalar(scalar) => { - scalar.clone().to_array_of_size(num_rows).unwrap() - } - }) - .collect::>(); - create_xxhash64_hashes(&arrays, &mut hashes)?; - if num_rows == 1 { - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some( - hashes[0] as i64, - )))) - } else { - let hashes: Vec = hashes.into_iter().map(|x| x as i64).collect(); - Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes)))) - } - } - _ => { - internal_err!( - "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.", - seed - ) - } - } -} - -/// `sha224` function that simulates Spark's `sha2` expression with bit width 224 -pub fn spark_sha224(args: &[ColumnarValue]) -> Result { - wrap_digest_result_as_hex_string(args, sha224()) -} - -/// `sha256` function that simulates Spark's `sha2` expression with bit width 0 or 256 -pub fn spark_sha256(args: &[ColumnarValue]) -> Result { - wrap_digest_result_as_hex_string(args, sha256()) -} - -/// `sha384` function that simulates Spark's `sha2` expression with bit width 384 -pub fn spark_sha384(args: &[ColumnarValue]) -> 
Result { - wrap_digest_result_as_hex_string(args, sha384()) -} - -/// `sha512` function that simulates Spark's `sha2` expression with bit width 512 -pub fn spark_sha512(args: &[ColumnarValue]) -> Result { - wrap_digest_result_as_hex_string(args, sha512()) -} - -// Spark requires hex string as the result of sha2 functions, we have to wrap the -// result of digest functions as hex string -fn wrap_digest_result_as_hex_string( - args: &[ColumnarValue], - digest: Arc, -) -> Result { - let row_count = match &args[0] { - ColumnarValue::Array(array) => array.len(), - ColumnarValue::Scalar(_) => 1, - }; - let value = digest.invoke_batch(args, row_count)?; - match value { - ColumnarValue::Array(array) => { - let binary_array = as_binary_array(&array)?; - let string_array: StringArray = binary_array - .iter() - .map(|opt| opt.map(hex_strings::<_>)) - .collect(); - Ok(ColumnarValue::Array(Arc::new(string_array))) - } - ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar( - ScalarValue::Utf8(opt.map(hex_strings::<_>)), - )), - _ => { - exec_err!( - "digest function should return binary value, but got: {:?}", - value.data_type() - ) - } - } -} diff --git a/native/spark-expr/src/schema_adapter.rs b/native/spark-expr/src/schema_adapter.rs index 161ad6f164..7bb7af6eb0 100644 --- a/native/spark-expr/src/schema_adapter.rs +++ b/native/spark-expr/src/schema_adapter.rs @@ -17,7 +17,7 @@ //! Custom schema adapter that uses Spark-compatible casts -use crate::cast::cast_supported; +use crate::conversion_funcs::cast::cast_supported; use crate::{spark_cast, SparkCastOptions}; use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions}; use arrow_schema::{Schema, SchemaRef}; diff --git a/native/spark-expr/src/spark_hash.rs b/native/spark-expr/src/spark_hash.rs deleted file mode 100644 index 1402f71715..0000000000 --- a/native/spark-expr/src/spark_hash.rs +++ /dev/null @@ -1,712 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This includes utilities for hashing and murmur3 hashing. 
- -use arrow::{ - compute::take, - datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, UInt8Type}, -}; -use std::sync::Arc; -use twox_hash::XxHash64; - -use datafusion::{ - arrow::{ - array::*, - datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, - Int8Type, TimeUnit, - }, - }, - error::{DataFusionError, Result}, -}; - -#[inline] -pub(crate) fn spark_compatible_xxhash64>(data: T, seed: u64) -> u64 { - XxHash64::oneshot(seed, data.as_ref()) -} - -/// Spark-compatible murmur3 hash function -#[inline] -pub fn spark_compatible_murmur3_hash>(data: T, seed: u32) -> u32 { - #[inline] - fn mix_k1(mut k1: i32) -> i32 { - k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32); - k1 = k1.rotate_left(15); - k1 = k1.mul_wrapping(0x1b873593u32 as i32); - k1 - } - - #[inline] - fn mix_h1(mut h1: i32, k1: i32) -> i32 { - h1 ^= k1; - h1 = h1.rotate_left(13); - h1 = h1.mul_wrapping(5).add_wrapping(0xe6546b64u32 as i32); - h1 - } - - #[inline] - fn fmix(mut h1: i32, len: i32) -> i32 { - h1 ^= len; - h1 ^= (h1 as u32 >> 16) as i32; - h1 = h1.mul_wrapping(0x85ebca6bu32 as i32); - h1 ^= (h1 as u32 >> 13) as i32; - h1 = h1.mul_wrapping(0xc2b2ae35u32 as i32); - h1 ^= (h1 as u32 >> 16) as i32; - h1 - } - - #[inline] - unsafe fn hash_bytes_by_int(data: &[u8], seed: u32) -> i32 { - // safety: data length must be aligned to 4 bytes - let mut h1 = seed as i32; - for i in (0..data.len()).step_by(4) { - let ints = data.as_ptr().add(i) as *const i32; - let mut half_word = ints.read_unaligned(); - if cfg!(target_endian = "big") { - half_word = half_word.reverse_bits(); - } - h1 = mix_h1(h1, mix_k1(half_word)); - } - h1 - } - let data = data.as_ref(); - let len = data.len(); - let len_aligned = len - len % 4; - - // safety: - // avoid boundary checking in performance critical codes. - // all operations are guaranteed to be safe - // data is &[u8] so we do not need to check for proper alignment - unsafe { - let mut h1 = if len_aligned > 0 { - hash_bytes_by_int(&data[0..len_aligned], seed) - } else { - seed as i32 - }; - - for i in len_aligned..len { - let half_word = *data.get_unchecked(i) as i8 as i32; - h1 = mix_h1(h1, mix_k1(half_word)); - } - fmix(h1, len as i32) as u32 - } -} - -macro_rules! hash_array { - ($array_type: ident, $column: ident, $hashes: ident, $hash_method: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - if array.null_count() == 0 { - for (i, hash) in $hashes.iter_mut().enumerate() { - *hash = $hash_method(&array.value(i), *hash); - } - } else { - for (i, hash) in $hashes.iter_mut().enumerate() { - if !array.is_null(i) { - *hash = $hash_method(&array.value(i), *hash); - } - } - } - }; -} - -macro_rules! hash_array_boolean { - ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident, $hash_method: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - if array.null_count() == 0 { - for (i, hash) in $hashes.iter_mut().enumerate() { - *hash = $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash); - } - } else { - for (i, hash) in $hashes.iter_mut().enumerate() { - if !array.is_null(i) { - *hash = - $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash); - } - } - } - }; -} - -macro_rules! 
hash_array_primitive { - ($array_type: ident, $column: ident, $ty: ident, $hashes: ident, $hash_method: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - let values = array.values(); - - if array.null_count() == 0 { - for (hash, value) in $hashes.iter_mut().zip(values.iter()) { - *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); - } - } else { - for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { - if !array.is_null(i) { - *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); - } - } - } - }; -} - -macro_rules! hash_array_primitive_float { - ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident, $hash_method: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - let values = array.values(); - - if array.null_count() == 0 { - for (hash, value) in $hashes.iter_mut().zip(values.iter()) { - // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. - if *value == 0.0 && value.is_sign_negative() { - *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash); - } else { - *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); - } - } - } else { - for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { - if !array.is_null(i) { - // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. - if *value == 0.0 && value.is_sign_negative() { - *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash); - } else { - *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); - } - } - } - } - }; -} - -macro_rules! hash_array_decimal { - ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => { - let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); - - if array.null_count() == 0 { - for (i, hash) in $hashes.iter_mut().enumerate() { - *hash = $hash_method(array.value(i).to_le_bytes(), *hash); - } - } else { - for (i, hash) in $hashes.iter_mut().enumerate() { - if !array.is_null(i) { - *hash = $hash_method(array.value(i).to_le_bytes(), *hash); - } - } - } - }; -} - -/// Hash the values in a dictionary array -fn create_hashes_dictionary( - array: &ArrayRef, - hashes_buffer: &mut [u32], - first_col: bool, -) -> Result<()> { - let dict_array = array.as_any().downcast_ref::>().unwrap(); - if !first_col { - // unpack the dictionary array as each row may have a different hash input - let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; - create_murmur3_hashes(&[unpacked], hashes_buffer)?; - } else { - // For the first column, hash each dictionary value once, and then use - // that computed hash for each key value to avoid a potentially - // expensive redundant hashing for large dictionary elements (e.g. 
strings) - let dict_values = Arc::clone(dict_array.values()); - // same initial seed as Spark - let mut dict_hashes = vec![42; dict_values.len()]; - create_murmur3_hashes(&[dict_values], &mut dict_hashes)?; - for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { - if let Some(key) = key { - let idx = key.to_usize().ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert key value {:?} to usize in dictionary of type {:?}", - key, - dict_array.data_type() - )) - })?; - *hash = dict_hashes[idx] - } // no update for Null, consistent with other hashes - } - } - Ok(()) -} - -// Hash the values in a dictionary array using xxhash64 -fn create_xxhash64_hashes_dictionary( - array: &ArrayRef, - hashes_buffer: &mut [u64], - first_col: bool, -) -> Result<()> { - let dict_array = array.as_any().downcast_ref::>().unwrap(); - if !first_col { - let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; - create_xxhash64_hashes(&[unpacked], hashes_buffer)?; - } else { - // Hash each dictionary value once, and then use that computed - // hash for each key value to avoid a potentially expensive - // redundant hashing for large dictionary elements (e.g. strings) - let dict_values = Arc::clone(dict_array.values()); - // same initial seed as Spark - let mut dict_hashes = vec![42u64; dict_values.len()]; - create_xxhash64_hashes(&[dict_values], &mut dict_hashes)?; - - for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { - if let Some(key) = key { - let idx = key.to_usize().ok_or_else(|| { - DataFusionError::Internal(format!( - "Can not convert key value {:?} to usize in dictionary of type {:?}", - key, - dict_array.data_type() - )) - })?; - *hash = dict_hashes[idx] - } // no update for Null, consistent with other hashes - } - } - Ok(()) -} - -/// Creates hash values for every row, based on the values in the -/// columns. -/// -/// The number of rows to hash is determined by `hashes_buffer.len()`. -/// `hashes_buffer` should be pre-sized appropriately -/// -/// `hash_method` is the hash function to use. -/// `create_dictionary_hash_method` is the function to create hashes for dictionary arrays input. -macro_rules! 
create_hashes_internal { - ($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident) => { - for (i, col) in $arrays.iter().enumerate() { - let first_col = i == 0; - match col.data_type() { - DataType::Boolean => { - hash_array_boolean!(BooleanArray, col, i32, $hashes_buffer, $hash_method); - } - DataType::Int8 => { - hash_array_primitive!(Int8Array, col, i32, $hashes_buffer, $hash_method); - } - DataType::Int16 => { - hash_array_primitive!(Int16Array, col, i32, $hashes_buffer, $hash_method); - } - DataType::Int32 => { - hash_array_primitive!(Int32Array, col, i32, $hashes_buffer, $hash_method); - } - DataType::Int64 => { - hash_array_primitive!(Int64Array, col, i64, $hashes_buffer, $hash_method); - } - DataType::Float32 => { - hash_array_primitive_float!( - Float32Array, - col, - f32, - i32, - $hashes_buffer, - $hash_method - ); - } - DataType::Float64 => { - hash_array_primitive_float!( - Float64Array, - col, - f64, - i64, - $hashes_buffer, - $hash_method - ); - } - DataType::Timestamp(TimeUnit::Second, _) => { - hash_array_primitive!( - TimestampSecondArray, - col, - i64, - $hashes_buffer, - $hash_method - ); - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - hash_array_primitive!( - TimestampMillisecondArray, - col, - i64, - $hashes_buffer, - $hash_method - ); - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - hash_array_primitive!( - TimestampMicrosecondArray, - col, - i64, - $hashes_buffer, - $hash_method - ); - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - hash_array_primitive!( - TimestampNanosecondArray, - col, - i64, - $hashes_buffer, - $hash_method - ); - } - DataType::Date32 => { - hash_array_primitive!(Date32Array, col, i32, $hashes_buffer, $hash_method); - } - DataType::Date64 => { - hash_array_primitive!(Date64Array, col, i64, $hashes_buffer, $hash_method); - } - DataType::Utf8 => { - hash_array!(StringArray, col, $hashes_buffer, $hash_method); - } - DataType::LargeUtf8 => { - hash_array!(LargeStringArray, col, $hashes_buffer, $hash_method); - } - DataType::Binary => { - hash_array!(BinaryArray, col, $hashes_buffer, $hash_method); - } - DataType::LargeBinary => { - hash_array!(LargeBinaryArray, col, $hashes_buffer, $hash_method); - } - DataType::FixedSizeBinary(_) => { - hash_array!(FixedSizeBinaryArray, col, $hashes_buffer, $hash_method); - } - DataType::Decimal128(_, _) => { - hash_array_decimal!(Decimal128Array, col, $hashes_buffer, $hash_method); - } - DataType::Dictionary(index_type, _) => match **index_type { - DataType::Int8 => { - $create_dictionary_hash_method::(col, $hashes_buffer, first_col)?; - } - DataType::Int16 => { - $create_dictionary_hash_method::( - col, - $hashes_buffer, - first_col, - )?; - } - DataType::Int32 => { - $create_dictionary_hash_method::( - col, - $hashes_buffer, - first_col, - )?; - } - DataType::Int64 => { - $create_dictionary_hash_method::( - col, - $hashes_buffer, - first_col, - )?; - } - DataType::UInt8 => { - $create_dictionary_hash_method::( - col, - $hashes_buffer, - first_col, - )?; - } - DataType::UInt16 => { - $create_dictionary_hash_method::( - col, - $hashes_buffer, - first_col, - )?; - } - DataType::UInt32 => { - $create_dictionary_hash_method::( - col, - $hashes_buffer, - first_col, - )?; - } - DataType::UInt64 => { - $create_dictionary_hash_method::( - col, - $hashes_buffer, - first_col, - )?; - } - _ => { - return Err(DataFusionError::Internal(format!( - "Unsupported dictionary type in hasher hashing: {}", - col.data_type(), - ))) - } - }, - _ => { - // This is 
internal because we should have caught this before. - return Err(DataFusionError::Internal(format!( - "Unsupported data type in hasher: {}", - col.data_type() - ))); - } - } - } - }; -} - -/// Creates hash values for every row, based on the values in the -/// columns. -/// -/// The number of rows to hash is determined by `hashes_buffer.len()`. -/// `hashes_buffer` should be pre-sized appropriately -pub fn create_murmur3_hashes<'a>( - arrays: &[ArrayRef], - hashes_buffer: &'a mut [u32], -) -> Result<&'a mut [u32]> { - create_hashes_internal!( - arrays, - hashes_buffer, - spark_compatible_murmur3_hash, - create_hashes_dictionary - ); - Ok(hashes_buffer) -} - -/// Creates xxhash64 hash values for every row, based on the values in the -/// columns. -/// -/// The number of rows to hash is determined by `hashes_buffer.len()`. -/// `hashes_buffer` should be pre-sized appropriately -pub fn create_xxhash64_hashes<'a>( - arrays: &[ArrayRef], - hashes_buffer: &'a mut [u64], -) -> Result<&'a mut [u64]> { - create_hashes_internal!( - arrays, - hashes_buffer, - spark_compatible_xxhash64, - create_xxhash64_hashes_dictionary - ); - Ok(hashes_buffer) -} - -#[cfg(test)] -mod tests { - use arrow::array::{Float32Array, Float64Array}; - use std::sync::Arc; - - use super::{create_murmur3_hashes, create_xxhash64_hashes}; - use datafusion::arrow::array::{ArrayRef, Int32Array, Int64Array, Int8Array, StringArray}; - - macro_rules! test_hashes_internal { - ($hash_method: ident, $input: expr, $initial_seeds: expr, $expected: expr) => { - let i = $input; - let mut hashes = $initial_seeds.clone(); - $hash_method(&[i], &mut hashes).unwrap(); - assert_eq!(hashes, $expected); - }; - } - - macro_rules! test_hashes_with_nulls { - ($method: ident, $t: ty, $values: ident, $expected: ident, $seed_type: ty) => { - // copied before inserting nulls - let mut input_with_nulls = $values.clone(); - let mut expected_with_nulls = $expected.clone(); - // test before inserting nulls - let len = $values.len(); - let initial_seeds = vec![42 as $seed_type; len]; - let i = Arc::new(<$t>::from($values)) as ArrayRef; - test_hashes_internal!($method, i, initial_seeds, $expected); - - // test with nulls - let median = len / 2; - input_with_nulls.insert(0, None); - input_with_nulls.insert(median, None); - expected_with_nulls.insert(0, 42 as $seed_type); - expected_with_nulls.insert(median, 42 as $seed_type); - let len_with_nulls = len + 2; - let initial_seeds_with_nulls = vec![42 as $seed_type; len_with_nulls]; - let nullable_input = Arc::new(<$t>::from(input_with_nulls)) as ArrayRef; - test_hashes_internal!( - $method, - nullable_input, - initial_seeds_with_nulls, - expected_with_nulls - ); - }; - } - - fn test_murmur3_hash>> + 'static>( - values: Vec>, - expected: Vec, - ) { - test_hashes_with_nulls!(create_murmur3_hashes, T, values, expected, u32); - } - - fn test_xxhash64_hash>> + 'static>( - values: Vec>, - expected: Vec, - ) { - test_hashes_with_nulls!(create_xxhash64_hashes, T, values, expected, u64); - } - - #[test] - fn test_i8() { - test_murmur3_hash::( - vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], - vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365], - ); - test_xxhash64_hash::( - vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], - vec![ - 0xa309b38455455929, - 0x3229fbc4681e48f3, - 0x1bfdda8861c06e45, - 0x77cc15d9f9f2cdc2, - 0x39bc22b9e94d81d0, - ], - ); - } - - #[test] - fn test_i32() { - test_murmur3_hash::( - vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], - 
vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6], - ); - test_xxhash64_hash::( - vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], - vec![ - 0xa309b38455455929, - 0x3229fbc4681e48f3, - 0x1bfdda8861c06e45, - 0x14f0ac009c21721c, - 0x1cc7cb8d034769cd, - ], - ); - } - - #[test] - fn test_i64() { - test_murmur3_hash::( - vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], - vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb], - ); - test_xxhash64_hash::( - vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], - vec![ - 0x9ed50fd59358d232, - 0xb71b47ebda15746c, - 0x358ae035bfb46fd2, - 0xd2f1c616ae7eb306, - 0x88608019c494c1f4, - ], - ); - } - - #[test] - fn test_f32() { - test_murmur3_hash::( - vec![ - Some(1.0), - Some(0.0), - Some(-0.0), - Some(-1.0), - Some(99999999999.99999999999), - Some(-99999999999.99999999999), - ], - vec![ - 0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 0xc0361c86, - ], - ); - test_xxhash64_hash::( - vec![ - Some(1.0), - Some(0.0), - Some(-0.0), - Some(-1.0), - Some(99999999999.99999999999), - Some(-99999999999.99999999999), - ], - vec![ - 0x9b92689757fcdbd, - 0x3229fbc4681e48f3, - 0x3229fbc4681e48f3, - 0xa2becc0e61bb3823, - 0x8f20ab82d4f3687f, - 0xdce4982d97f7ac4, - ], - ) - } - - #[test] - fn test_f64() { - test_murmur3_hash::( - vec![ - Some(1.0), - Some(0.0), - Some(-0.0), - Some(-1.0), - Some(99999999999.99999999999), - Some(-99999999999.99999999999), - ], - vec![ - 0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 0xa0eef9f9, - ], - ); - - test_xxhash64_hash::( - vec![ - Some(1.0), - Some(0.0), - Some(-0.0), - Some(-1.0), - Some(99999999999.99999999999), - Some(-99999999999.99999999999), - ], - vec![ - 0xe1fd6e07fee8ad53, - 0xb71b47ebda15746c, - 0xb71b47ebda15746c, - 0x8cdde022746f8f1f, - 0x793c5c88d313eac7, - 0xc5e60e7b75d9b232, - ], - ) - } - - #[test] - fn test_str() { - let input = [ - "hello", "bar", "", "😁", "天地", "a", "ab", "abc", "abcd", "abcde", - ] - .iter() - .map(|s| Some(s.to_string())) - .collect::>>(); - let expected: Vec = vec![ - 3286402344, 2486176763, 142593372, 885025535, 2395000894, 1485273170, 0xfa37157b, - 1322437556, 0xe860e5cc, 814637928, - ]; - - test_murmur3_hash::(input.clone(), expected); - test_xxhash64_hash::( - input, - vec![ - 0xc3629e6318d53932, - 0xe7097b6a54378d8a, - 0x98b1582b0977e704, - 0xa80d9d5a6a523bd5, - 0xfcba5f61ac666c61, - 0x88e4fe59adf7b0cc, - 0x259dd873209a3fe3, - 0x13c1d910702770e6, - 0xa17b5eb5dc364dff, - 0xf241303e4a90f299, - ], - ) - } -} diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs new file mode 100644 index 0000000000..fff6134dab --- /dev/null +++ b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
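
For reference, a minimal sketch of how the relocated murmur3 entry point above is called. The import path assumes `create_murmur3_hashes` remains publicly reachable under the new `hash_funcs::murmur3` module (as the other call sites in this PR suggest); the column values are made up for illustration.

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, StringArray};
use datafusion_comet_spark_expr::hash_funcs::murmur3::create_murmur3_hashes;

fn main() -> datafusion_common::Result<()> {
    // Two columns, three rows; rows are hashed column by column.
    let ids: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(-1)]));
    let names: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), Some("b"), None]));

    // The buffer length decides how many rows get hashed; Spark seeds murmur3 with 42,
    // and (as the tests above show) a null value leaves the running hash untouched.
    let mut hashes = vec![42u32; ids.len()];
    create_murmur3_hashes(&[ids, names], &mut hashes)?;
    println!("{hashes:?}");
    Ok(())
}
```

The same pattern applies to `create_xxhash64_hashes`, with a `u64` seed buffer instead of `u32`.
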
See the License for the +// specific language governing permissions and limitations +// under the License. + +mod read_side_padding; + +pub use read_side_padding::spark_read_side_padding; diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs new file mode 100644 index 0000000000..15807bf57d --- /dev/null +++ b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow_array::builder::GenericStringBuilder; +use arrow_array::Array; +use arrow_schema::DataType; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{cast::as_generic_string_array, DataFusionError, ScalarValue}; +use std::fmt::Write; +use std::sync::Arc; + +/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length +pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result { + match args { + [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => { + match array.data_type() { + DataType::Utf8 => spark_read_side_padding_internal::(array, *length), + DataType::LargeUtf8 => spark_read_side_padding_internal::(array, *length), + // TODO: handle Dictionary types + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function read_side_padding", + ))), + } + } + other => Err(DataFusionError::Internal(format!( + "Unsupported arguments {other:?} for function read_side_padding", + ))), + } +} + +fn spark_read_side_padding_internal( + array: &ArrayRef, + length: i32, +) -> Result { + let string_array = as_generic_string_array::(array)?; + let length = 0.max(length) as usize; + let space_string = " ".repeat(length); + + let mut builder = + GenericStringBuilder::::with_capacity(string_array.len(), string_array.len() * length); + + for string in string_array.iter() { + match string { + Some(string) => { + // It looks Spark's UTF8String is closer to chars rather than graphemes + // https://stackoverflow.com/a/46290728 + let char_len = string.chars().count(); + if length <= char_len { + builder.append_value(string); + } else { + // write_str updates only the value buffer, not null nor offset buffer + // This is convenient for concatenating str(s) + builder.write_str(string)?; + builder.append_value(&space_string[char_len..]); + } + } + _ => builder.append_null(), + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) +} diff --git a/native/spark-expr/src/static_invoke/mod.rs b/native/spark-expr/src/static_invoke/mod.rs new file mode 100644 index 0000000000..2a5351bb7f --- /dev/null +++ 
b/native/spark-expr/src/static_invoke/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod char_varchar_utils; + +pub use char_varchar_utils::*; diff --git a/native/spark-expr/src/scalar_funcs/chr.rs b/native/spark-expr/src/string_funcs/chr.rs similarity index 100% rename from native/spark-expr/src/scalar_funcs/chr.rs rename to native/spark-expr/src/string_funcs/chr.rs diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs new file mode 100644 index 0000000000..d56b5662c3 --- /dev/null +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
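
A minimal sketch of how the `spark_read_side_padding` helper added above is invoked. The crate-root import is an assumption (the re-export chain `static_invoke` → `char_varchar_utils` → `read_side_padding` is what this hunk adds, but the root `lib.rs` is not shown here); the sample strings are illustrative only.

```rust
use std::sync::Arc;

use arrow_array::StringArray;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::ScalarValue;
// Assumed crate-root re-export of the new static_invoke helper.
use datafusion_comet_spark_expr::spark_read_side_padding;

fn main() -> datafusion_common::Result<()> {
    let col = ColumnarValue::Array(Arc::new(StringArray::from(vec![
        Some("ab"),
        None,
        Some("abcdef"),
    ])));
    let length = ColumnarValue::Scalar(ScalarValue::Int32(Some(4)));

    // "ab" is padded with spaces to 4 chars, the null stays null, and "abcdef" is
    // already longer than 4 so it is returned as-is (this is where the helper
    // differs from a plain rpad, which would truncate).
    if let ColumnarValue::Array(padded) = spark_read_side_padding(&[col, length])? {
        println!("{padded:?}");
    }
    Ok(())
}
```
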
+ +mod chr; +mod prediction; +mod string_space; +mod substring; + +pub use chr::SparkChrFunc; +pub use prediction::*; +pub use string_space::StringSpaceExpr; +pub use substring::SubstringExpr; diff --git a/native/spark-expr/src/strings.rs b/native/spark-expr/src/string_funcs/prediction.rs similarity index 54% rename from native/spark-expr/src/strings.rs rename to native/spark-expr/src/string_funcs/prediction.rs index c2706b5896..d2ef82fcbe 100644 --- a/native/spark-expr/src/strings.rs +++ b/native/spark-expr/src/string_funcs/prediction.rs @@ -17,7 +17,6 @@ #![allow(deprecated)] -use crate::kernels::strings::{string_space, substring}; use arrow::{ compute::{ contains_dyn, contains_utf8_scalar_dyn, ends_with_dyn, ends_with_utf8_scalar_dyn, like_dyn, @@ -136,155 +135,3 @@ make_predicate_function!(StartsWith, starts_with_dyn, starts_with_utf8_scalar_dy make_predicate_function!(EndsWith, ends_with_dyn, ends_with_utf8_scalar_dyn); make_predicate_function!(Contains, contains_dyn, contains_utf8_scalar_dyn); - -#[derive(Debug, Eq)] -pub struct SubstringExpr { - pub child: Arc, - pub start: i64, - pub len: u64, -} - -impl Hash for SubstringExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.start.hash(state); - self.len.hash(state); - } -} - -impl PartialEq for SubstringExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.start.eq(&other.start) && self.len.eq(&other.len) - } -} -#[derive(Debug, Eq)] -pub struct StringSpaceExpr { - pub child: Arc, -} - -impl Hash for StringSpaceExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - } -} - -impl PartialEq for StringSpaceExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) - } -} - -impl SubstringExpr { - pub fn new(child: Arc, start: i64, len: u64) -> Self { - Self { child, start, len } - } -} - -impl StringSpaceExpr { - pub fn new(child: Arc) -> Self { - Self { child } - } -} - -impl Display for SubstringExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "StringSpace [start: {}, len: {}, child: {}]", - self.start, self.len, self.child - ) - } -} - -impl Display for StringSpaceExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "StringSpace [child: {}] ", self.child) - } -} - -impl PhysicalExpr for SubstringExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - self.child.data_type(input_schema) - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let result = substring(&array, self.start, self.len)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Substring(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> datafusion_common::Result> { - Ok(Arc::new(SubstringExpr::new( - Arc::clone(&children[0]), - self.start, - self.len, - ))) - } -} - -impl PhysicalExpr for StringSpaceExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - match self.child.data_type(input_schema)? 
{ - DataType::Dictionary(key_type, _) => { - Ok(DataType::Dictionary(key_type, Box::new(DataType::Utf8))) - } - _ => Ok(DataType::Utf8), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let result = string_space(&array)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "StringSpace(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> datafusion_common::Result> { - Ok(Arc::new(StringSpaceExpr::new(Arc::clone(&children[0])))) - } -} diff --git a/native/spark-expr/src/kernels/strings.rs b/native/spark-expr/src/string_funcs/string_space.rs similarity index 52% rename from native/spark-expr/src/kernels/strings.rs rename to native/spark-expr/src/string_funcs/string_space.rs index bb275fbb9f..7c9885738e 100644 --- a/native/spark-expr/src/kernels/strings.rs +++ b/native/spark-expr/src/string_funcs/string_space.rs @@ -15,24 +15,106 @@ // specific language governing permissions and limitations // under the License. -//! String kernels +#![allow(deprecated)] -use std::sync::Arc; - -use arrow::{ - array::*, - buffer::MutableBuffer, - compute::kernels::substring::{substring as arrow_substring, substring_by_char}, - datatypes::{DataType, Int32Type}, +use arrow::record_batch::RecordBatch; +use arrow_array::cast::as_dictionary_array; +use arrow_array::types::Int32Type; +use arrow_array::{ + make_array, Array, ArrayRef, DictionaryArray, GenericStringArray, Int32Array, OffsetSizeTrait, }; +use arrow_buffer::MutableBuffer; +use arrow_data::ArrayData; +use arrow_schema::{DataType, Schema}; +use datafusion::logical_expr::ColumnarValue; use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::Hash, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct StringSpaceExpr { + pub child: Arc, +} + +impl Hash for StringSpaceExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + } +} + +impl PartialEq for StringSpaceExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + } +} + +impl StringSpaceExpr { + pub fn new(child: Arc) -> Self { + Self { child } + } +} + +impl Display for StringSpaceExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "StringSpace [child: {}] ", self.child) + } +} + +impl PhysicalExpr for StringSpaceExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema)? 
{ + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Utf8))) + } + _ => Ok(DataType::Utf8), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let result = string_space_kernel(&array)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "StringSpace(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(StringSpaceExpr::new(Arc::clone(&children[0])))) + } +} /// Returns an ArrayRef with a string consisting of `length` spaces. /// /// # Preconditions /// /// - elements in `length` must not be negative -pub fn string_space(length: &dyn Array) -> Result { +pub fn string_space_kernel(length: &dyn Array) -> Result { match length.data_type() { DataType::Int32 => { let array = length.as_any().downcast_ref::().unwrap(); @@ -40,7 +122,7 @@ pub fn string_space(length: &dyn Array) -> Result { } DataType::Dictionary(_, _) => { let dict = as_dictionary_array::(length); - let values = string_space(dict.values())?; + let values = string_space_kernel(dict.values())?; let result = DictionaryArray::try_new(dict.keys().clone(), values)?; Ok(Arc::new(result)) } @@ -51,41 +133,6 @@ pub fn string_space(length: &dyn Array) -> Result { } } -pub fn substring(array: &dyn Array, start: i64, length: u64) -> Result { - match array.data_type() { - DataType::LargeUtf8 => substring_by_char( - array - .as_any() - .downcast_ref::() - .expect("A large string is expected"), - start, - Some(length), - ) - .map_err(|e| e.into()) - .map(|t| make_array(t.into_data())), - DataType::Utf8 => substring_by_char( - array - .as_any() - .downcast_ref::() - .expect("A string is expected"), - start, - Some(length), - ) - .map_err(|e| e.into()) - .map(|t| make_array(t.into_data())), - DataType::Binary | DataType::LargeBinary => { - arrow_substring(array, start, Some(length)).map_err(|e| e.into()) - } - DataType::Dictionary(_, _) => { - let dict = as_dictionary_array::(array); - let values = substring(dict.values(), start, length)?; - let result = DictionaryArray::try_new(dict.keys().clone(), values)?; - Ok(Arc::new(result)) - } - dt => panic!("Unsupported input type for function 'substring': {:?}", dt), - } -} - fn generic_string_space(length: &Int32Array) -> ArrayRef { let array_len = length.len(); let mut offsets = MutableBuffer::new((array_len + 1) * std::mem::size_of::()); diff --git a/native/spark-expr/src/string_funcs/substring.rs b/native/spark-expr/src/string_funcs/substring.rs new file mode 100644 index 0000000000..2563faec7c --- /dev/null +++ b/native/spark-expr/src/string_funcs/substring.rs @@ -0,0 +1,155 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
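
For context, the relocated `StringSpaceExpr` above is still evaluated like any other `PhysicalExpr`. A minimal sketch, assuming the struct stays re-exported from the crate root and using a made-up single-column batch of lengths:

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_comet_spark_expr::StringSpaceExpr; // assumed crate-root re-export
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;

fn main() -> datafusion_common::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Int32Array::from(vec![0, 2, 5])) as ArrayRef],
    )?;

    // Each (non-negative) length value becomes a string of that many spaces.
    let expr = StringSpaceExpr::new(Arc::new(Column::new("n", 0)));
    if let ColumnarValue::Array(spaces) = expr.evaluate(&batch)? {
        println!("{spaces:?}");
    }
    Ok(())
}
```
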
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#![allow(deprecated)] + +use arrow::record_batch::RecordBatch; +use arrow_array::cast::as_dictionary_array; +use arrow_array::{make_array, Array, ArrayRef, DictionaryArray, LargeStringArray, StringArray}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::Hash, + sync::Arc, +}; + +use arrow::{ + compute::kernels::substring::{substring as arrow_substring, substring_by_char}, + datatypes::{DataType, Int32Type, Schema}, +}; + +#[derive(Debug, Eq)] +pub struct SubstringExpr { + pub child: Arc, + pub start: i64, + pub len: u64, +} + +impl Hash for SubstringExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.start.hash(state); + self.len.hash(state); + } +} + +impl PartialEq for SubstringExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.start.eq(&other.start) && self.len.eq(&other.len) + } +} + +impl SubstringExpr { + pub fn new(child: Arc, start: i64, len: u64) -> Self { + Self { child, start, len } + } +} + +impl Display for SubstringExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "StringSpace [start: {}, len: {}, child: {}]", + self.start, self.len, self.child + ) + } +} + +impl PhysicalExpr for SubstringExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + self.child.data_type(input_schema) + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let result = substring_kernel(&array, self.start, self.len)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Substring(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(SubstringExpr::new( + Arc::clone(&children[0]), + self.start, + self.len, + ))) + } +} + +pub fn substring_kernel( + array: &dyn Array, + start: i64, + length: u64, +) -> Result { + match array.data_type() { + DataType::LargeUtf8 => substring_by_char( + array + .as_any() + .downcast_ref::() + .expect("A large string is expected"), + start, + Some(length), + ) + .map_err(|e| e.into()) + .map(|t| make_array(t.into_data())), + DataType::Utf8 => substring_by_char( + array + .as_any() + .downcast_ref::() + .expect("A string is expected"), + start, + Some(length), + ) + .map_err(|e| e.into()) + .map(|t| make_array(t.into_data())), + DataType::Binary | DataType::LargeBinary => { + arrow_substring(array, start, Some(length)).map_err(|e| e.into()) + } + DataType::Dictionary(_, _) => { + let dict = as_dictionary_array::(array); + let values = substring_kernel(dict.values(), start, length)?; + let result = DictionaryArray::try_new(dict.keys().clone(), 
values)?; + Ok(Arc::new(result)) + } + dt => panic!("Unsupported input type for function 'substring': {:?}", dt), + } +} diff --git a/native/spark-expr/src/structs.rs b/native/spark-expr/src/struct_funcs/create_named_struct.rs similarity index 64% rename from native/spark-expr/src/structs.rs rename to native/spark-expr/src/struct_funcs/create_named_struct.rs index 7cc49e4281..df63127412 100644 --- a/native/spark-expr/src/structs.rs +++ b/native/spark-expr/src/struct_funcs/create_named_struct.rs @@ -16,10 +16,10 @@ // under the License. use arrow::record_batch::RecordBatch; -use arrow_array::{Array, StructArray}; +use arrow_array::StructArray; use arrow_schema::{DataType, Field, Schema}; use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue}; +use datafusion_common::Result as DataFusionResult; use datafusion_physical_expr::PhysicalExpr; use std::{ any::Any, @@ -106,102 +106,6 @@ impl Display for CreateNamedStruct { } } -#[derive(Debug, Eq)] -pub struct GetStructField { - child: Arc, - ordinal: usize, -} - -impl Hash for GetStructField { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.ordinal.hash(state); - } -} -impl PartialEq for GetStructField { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal) - } -} - -impl GetStructField { - pub fn new(child: Arc, ordinal: usize) -> Self { - Self { child, ordinal } - } - - fn child_field(&self, input_schema: &Schema) -> DataFusionResult> { - match self.child.data_type(input_schema)? { - DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])), - data_type => Err(DataFusionError::Plan(format!( - "Expect struct field, got {:?}", - data_type - ))), - } - } -} - -impl PhysicalExpr for GetStructField { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> DataFusionResult { - Ok(self.child_field(input_schema)?.data_type().clone()) - } - - fn nullable(&self, input_schema: &Schema) -> DataFusionResult { - Ok(self.child_field(input_schema)?.is_nullable()) - } - - fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { - let child_value = self.child.evaluate(batch)?; - - match child_value { - ColumnarValue::Array(array) => { - let struct_array = array - .as_any() - .downcast_ref::() - .expect("A struct is expected"); - - Ok(ColumnarValue::Array(Arc::clone( - struct_array.column(self.ordinal), - ))) - } - ColumnarValue::Scalar(ScalarValue::Struct(struct_array)) => Ok(ColumnarValue::Array( - Arc::clone(struct_array.column(self.ordinal)), - )), - value => Err(DataFusionError::Execution(format!( - "Expected a struct array, got {:?}", - value - ))), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> datafusion_common::Result> { - Ok(Arc::new(GetStructField::new( - Arc::clone(&children[0]), - self.ordinal, - ))) - } -} - -impl Display for GetStructField { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "GetStructField [child: {:?}, ordinal: {:?}]", - self.child, self.ordinal - ) - } -} - #[cfg(test)] mod test { use super::CreateNamedStruct; diff --git a/native/spark-expr/src/struct_funcs/get_struct_field.rs b/native/spark-expr/src/struct_funcs/get_struct_field.rs new file mode 100644 index 0000000000..c4e1a1e239 --- /dev/null +++ b/native/spark-expr/src/struct_funcs/get_struct_field.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) 
under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::record_batch::RecordBatch; +use arrow_array::{Array, StructArray}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue}; +use datafusion_physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::Hash, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct GetStructField { + child: Arc, + ordinal: usize, +} + +impl Hash for GetStructField { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.ordinal.hash(state); + } +} +impl PartialEq for GetStructField { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal) + } +} + +impl GetStructField { + pub fn new(child: Arc, ordinal: usize) -> Self { + Self { child, ordinal } + } + + fn child_field(&self, input_schema: &Schema) -> DataFusionResult> { + match self.child.data_type(input_schema)? { + DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])), + data_type => Err(DataFusionError::Plan(format!( + "Expect struct field, got {:?}", + data_type + ))), + } + } +} + +impl PhysicalExpr for GetStructField { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> DataFusionResult { + Ok(self.child_field(input_schema)?.data_type().clone()) + } + + fn nullable(&self, input_schema: &Schema) -> DataFusionResult { + Ok(self.child_field(input_schema)?.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let child_value = self.child.evaluate(batch)?; + + match child_value { + ColumnarValue::Array(array) => { + let struct_array = array + .as_any() + .downcast_ref::() + .expect("A struct is expected"); + + Ok(ColumnarValue::Array(Arc::clone( + struct_array.column(self.ordinal), + ))) + } + ColumnarValue::Scalar(ScalarValue::Struct(struct_array)) => Ok(ColumnarValue::Array( + Arc::clone(struct_array.column(self.ordinal)), + )), + value => Err(DataFusionError::Execution(format!( + "Expected a struct array, got {:?}", + value + ))), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(GetStructField::new( + Arc::clone(&children[0]), + self.ordinal, + ))) + } +} + +impl Display for GetStructField { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "GetStructField [child: {:?}, ordinal: {:?}]", + self.child, self.ordinal + ) + } +} diff --git a/native/spark-expr/src/struct_funcs/mod.rs b/native/spark-expr/src/struct_funcs/mod.rs new file mode 100644 index 0000000000..86edcceac9 --- /dev/null +++ b/native/spark-expr/src/struct_funcs/mod.rs @@ -0,0 +1,22 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod create_named_struct; +mod get_struct_field; + +pub use create_named_struct::CreateNamedStruct; +pub use get_struct_field::GetStructField; diff --git a/native/spark-expr/src/temporal.rs b/native/spark-expr/src/temporal.rs deleted file mode 100644 index fb549f9ce8..0000000000 --- a/native/spark-expr/src/temporal.rs +++ /dev/null @@ -1,510 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
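
Similarly, the `GetStructField` expression moved into `struct_funcs` above keeps its `(child, ordinal)` constructor. A minimal evaluation sketch, assuming the crate-root re-export and an illustrative two-field struct column:

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray};
use arrow_schema::{DataType, Field, Fields, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_comet_spark_expr::GetStructField; // assumed crate-root re-export
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;

fn main() -> datafusion_common::Result<()> {
    let fields = Fields::from(vec![
        Field::new("id", DataType::Int32, true),
        Field::new("name", DataType::Utf8, true),
    ]);
    let rows = StructArray::new(
        fields.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
            Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef,
        ],
        None,
    );
    let schema = Arc::new(Schema::new(vec![Field::new(
        "s",
        DataType::Struct(fields),
        true,
    )]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(rows) as ArrayRef])?;

    // Ordinal 1 projects the "name" child array out of the struct column.
    let expr = GetStructField::new(Arc::new(Column::new("s", 0)), 1);
    if let ColumnarValue::Array(names) = expr.evaluate(&batch)? {
        println!("{names:?}");
    }
    Ok(())
}
```
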
- -use crate::utils::array_with_timezone; -use arrow::{ - compute::{date_part, DatePart}, - record_batch::RecordBatch, -}; -use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; -use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{DataFusionError, ScalarValue::Utf8}; -use datafusion_physical_expr::PhysicalExpr; -use std::hash::Hash; -use std::{ - any::Any, - fmt::{Debug, Display, Formatter}, - sync::Arc, -}; - -use crate::kernels::temporal::{ - date_trunc_array_fmt_dyn, date_trunc_dyn, timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn, -}; - -#[derive(Debug, Eq)] -pub struct HourExpr { - /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) - child: Arc, - timezone: String, -} - -impl Hash for HourExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.timezone.hash(state); - } -} -impl PartialEq for HourExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.timezone.eq(&other.timezone) - } -} - -impl HourExpr { - pub fn new(child: Arc, timezone: String) -> Self { - HourExpr { child, timezone } - } -} - -impl Display for HourExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Hour [timezone:{}, child: {}]", - self.timezone, self.child - ) - } -} - -impl PhysicalExpr for HourExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - match self.child.data_type(input_schema).unwrap() { - DataType::Dictionary(key_type, _) => { - Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) - } - _ => Ok(DataType::Int32), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let array = array_with_timezone( - array, - self.timezone.clone(), - Some(&DataType::Timestamp( - Microsecond, - Some(self.timezone.clone().into()), - )), - )?; - let result = date_part(&array, DatePart::Hour)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Hour(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(HourExpr::new( - Arc::clone(&children[0]), - self.timezone.clone(), - ))) - } -} - -#[derive(Debug, Eq)] -pub struct MinuteExpr { - /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) - child: Arc, - timezone: String, -} - -impl Hash for MinuteExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.timezone.hash(state); - } -} -impl PartialEq for MinuteExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.timezone.eq(&other.timezone) - } -} - -impl MinuteExpr { - pub fn new(child: Arc, timezone: String) -> Self { - MinuteExpr { child, timezone } - } -} - -impl Display for MinuteExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Minute [timezone:{}, child: {}]", - self.timezone, self.child - ) - } -} - -impl PhysicalExpr for MinuteExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - match self.child.data_type(input_schema).unwrap() { - DataType::Dictionary(key_type, _) => { - Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) - 
} - _ => Ok(DataType::Int32), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let array = array_with_timezone( - array, - self.timezone.clone(), - Some(&DataType::Timestamp( - Microsecond, - Some(self.timezone.clone().into()), - )), - )?; - let result = date_part(&array, DatePart::Minute)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Minute(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(MinuteExpr::new( - Arc::clone(&children[0]), - self.timezone.clone(), - ))) - } -} - -#[derive(Debug, Eq)] -pub struct SecondExpr { - /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) - child: Arc, - timezone: String, -} - -impl Hash for SecondExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.timezone.hash(state); - } -} -impl PartialEq for SecondExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.timezone.eq(&other.timezone) - } -} - -impl SecondExpr { - pub fn new(child: Arc, timezone: String) -> Self { - SecondExpr { child, timezone } - } -} - -impl Display for SecondExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Second (timezone:{}, child: {}]", - self.timezone, self.child - ) - } -} - -impl PhysicalExpr for SecondExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - match self.child.data_type(input_schema).unwrap() { - DataType::Dictionary(key_type, _) => { - Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) - } - _ => Ok(DataType::Int32), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let array = array_with_timezone( - array, - self.timezone.clone(), - Some(&DataType::Timestamp( - Microsecond, - Some(self.timezone.clone().into()), - )), - )?; - let result = date_part(&array, DatePart::Second)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Second(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(SecondExpr::new( - Arc::clone(&children[0]), - self.timezone.clone(), - ))) - } -} - -#[derive(Debug, Eq)] -pub struct DateTruncExpr { - /// An array with DataType::Date32 - child: Arc, - /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc - format: Arc, -} - -impl Hash for DateTruncExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.format.hash(state); - } -} -impl PartialEq for DateTruncExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.format.eq(&other.format) - } -} - -impl DateTruncExpr { - pub fn new(child: Arc, format: Arc) -> Self { - DateTruncExpr { child, format } - } -} - -impl Display for DateTruncExpr { - fn fmt(&self, f: 
&mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "DateTrunc [child:{}, format: {}]", - self.child, self.format - ) - } -} - -impl PhysicalExpr for DateTruncExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - self.child.data_type(input_schema) - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let date = self.child.evaluate(batch)?; - let format = self.format.evaluate(batch)?; - match (date, format) { - (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => { - let result = date_trunc_dyn(&date, format)?; - Ok(ColumnarValue::Array(result)) - } - (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => { - let result = date_trunc_array_fmt_dyn(&date, &formats)?; - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar) or \ - (PrimitiveArray, StringArray)".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(DateTruncExpr::new( - Arc::clone(&children[0]), - Arc::clone(&self.format), - ))) - } -} - -#[derive(Debug, Eq)] -pub struct TimestampTruncExpr { - /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) - child: Arc, - /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc - format: Arc, - /// String containing a timezone name. The name must be found in the standard timezone - /// database (https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). The string is - /// later parsed into a chrono::TimeZone. - /// Timestamp arrays in this implementation are kept in arrays of UTC timestamps (in micros) - /// along with a single value for the associated TimeZone. The timezone offset is applied - /// just before any operations on the timestamp - timezone: String, -} - -impl Hash for TimestampTruncExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.format.hash(state); - self.timezone.hash(state); - } -} -impl PartialEq for TimestampTruncExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) - && self.format.eq(&other.format) - && self.timezone.eq(&other.timezone) - } -} - -impl TimestampTruncExpr { - pub fn new( - child: Arc, - format: Arc, - timezone: String, - ) -> Self { - TimestampTruncExpr { - child, - format, - timezone, - } - } -} - -impl Display for TimestampTruncExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "TimestampTrunc [child:{}, format:{}, timezone: {}]", - self.child, self.format, self.timezone - ) - } -} - -impl PhysicalExpr for TimestampTruncExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - match self.child.data_type(input_schema)? 
{ - DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary( - key_type, - Box::new(DataType::Timestamp(Microsecond, None)), - )), - _ => Ok(DataType::Timestamp(Microsecond, None)), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let timestamp = self.child.evaluate(batch)?; - let format = self.format.evaluate(batch)?; - let tz = self.timezone.clone(); - match (timestamp, format) { - (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => { - let ts = array_with_timezone( - ts, - tz.clone(), - Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), - )?; - let result = timestamp_trunc_dyn(&ts, format)?; - Ok(ColumnarValue::Array(result)) - } - (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => { - let ts = array_with_timezone( - ts, - tz.clone(), - Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), - )?; - let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?; - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Invalid input to function TimestampTrunc. \ - Expected (PrimitiveArray, Scalar, String) or \ - (PrimitiveArray, StringArray, String)" - .to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(TimestampTruncExpr::new( - Arc::clone(&children[0]), - Arc::clone(&self.format), - self.timezone.clone(), - ))) - } -}
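
The temporal expressions deleted from `temporal.rs` above (Hour, Minute, Second, DateTrunc, TimestampTrunc) are presumably relocated elsewhere in this reorganization rather than dropped; their constructors are unchanged, so a `DateTruncExpr` evaluation still looks like the sketch below. The import path is an assumption, since the new module is outside this hunk, and the sample date is illustrative.

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Date32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion_common::ScalarValue;
// Placeholder import: the expression's post-move module is not shown in this hunk.
use datafusion_comet_spark_expr::DateTruncExpr;
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr::PhysicalExpr;

fn main() -> datafusion_common::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("d", DataType::Date32, false)]));
    // 19737 days after 1970-01-01 is 2024-01-15.
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(Date32Array::from(vec![19737])) as ArrayRef],
    )?;

    // The format literal follows Spark's trunc(): "YEAR", "MONTH", "WEEK", ...
    let format = Arc::new(Literal::new(ScalarValue::Utf8(Some("MONTH".to_string()))));
    let expr = DateTruncExpr::new(Arc::new(Column::new("d", 0)), format);
    if let ColumnarValue::Array(truncated) = expr.evaluate(&batch)? {
        println!("{truncated:?}"); // expect 2024-01-01
    }
    Ok(())
}
```
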