From 48e3f94021d572032f5e4a456ed6bb2497790106 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Jan 2025 11:55:08 +0200 Subject: [PATCH] extract datetime_funcs expressions to folders based on spark grouping --- native/spark-expr/src/comet_scalar_funcs.rs | 7 +- .../src/datetime_funcs/date_arithmetic.rs | 102 ++++ .../src/datetime_funcs/date_trunc.rs | 113 ++++ native/spark-expr/src/datetime_funcs/hour.rs | 122 +++++ .../spark-expr/src/datetime_funcs/minute.rs | 122 +++++ native/spark-expr/src/datetime_funcs/mod.rs | 30 ++ .../spark-expr/src/datetime_funcs/second.rs | 122 +++++ .../src/datetime_funcs/timestamp_trunc.rs | 152 ++++++ native/spark-expr/src/lib.rs | 5 +- native/spark-expr/src/scalar_funcs.rs | 83 +-- native/spark-expr/src/temporal.rs | 510 ------------------ 11 files changed, 773 insertions(+), 595 deletions(-) create mode 100644 native/spark-expr/src/datetime_funcs/date_arithmetic.rs create mode 100644 native/spark-expr/src/datetime_funcs/date_trunc.rs create mode 100644 native/spark-expr/src/datetime_funcs/hour.rs create mode 100644 native/spark-expr/src/datetime_funcs/minute.rs create mode 100644 native/spark-expr/src/datetime_funcs/mod.rs create mode 100644 native/spark-expr/src/datetime_funcs/second.rs create mode 100644 native/spark-expr/src/datetime_funcs/timestamp_trunc.rs delete mode 100644 native/spark-expr/src/temporal.rs diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 71ff0e9dcc..ece1c46b08 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::datetime_funcs::*; use crate::scalar_funcs::hash_expressions::{ spark_sha224, spark_sha256, spark_sha384, spark_sha512, }; use crate::scalar_funcs::{ - spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex, - spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_read_side_padding, spark_round, - spark_unhex, spark_unscaled_value, spark_xxhash64, SparkChrFunc, + spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal, + spark_murmur3_hash, spark_read_side_padding, spark_round, spark_unhex, spark_unscaled_value, + spark_xxhash64, SparkChrFunc, }; use arrow_schema::DataType; use datafusion_common::{DataFusionError, Result as DataFusionResult}; diff --git a/native/spark-expr/src/datetime_funcs/date_arithmetic.rs b/native/spark-expr/src/datetime_funcs/date_arithmetic.rs new file mode 100644 index 0000000000..cc4da9af70 --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/date_arithmetic.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, AsArray}; +use arrow::compute::kernels::numeric::{add, sub}; +use arrow::datatypes::IntervalDayTime; +use arrow_array::builder::IntervalDayTimeBuilder; +use arrow_array::types::{Int16Type, Int32Type, Int8Type}; +use arrow_array::{Array, Datum}; +use arrow_schema::{ArrowError, DataType}; +use datafusion::physical_expr_common::datum; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use std::sync::Arc; + +macro_rules! scalar_date_arithmetic { + ($start:expr, $days:expr, $op:expr) => {{ + let interval = IntervalDayTime::new(*$days as i32, 0); + let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval))); + datum::apply($start, &interval_cv, $op) + }}; +} +macro_rules! array_date_arithmetic { + ($days:expr, $interval_builder:expr, $intType:ty) => {{ + for day in $days.as_primitive::<$intType>().into_iter() { + if let Some(non_null_day) = day { + $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0)); + } else { + $interval_builder.append_null(); + } + } + }}; +} + +/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second +/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the +/// second argument and use DataFusion's interface to apply Arrow's operators. +fn spark_date_arithmetic( + args: &[ColumnarValue], + op: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + let start = &args[0]; + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Array(days) => { + let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len()); + match days.data_type() { + DataType::Int8 => { + array_date_arithmetic!(days, interval_builder, Int8Type) + } + DataType::Int16 => { + array_date_arithmetic!(days, interval_builder, Int16Type) + } + DataType::Int32 => { + array_date_arithmetic!(days, interval_builder, Int32Type) + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported data types {:?} for date arithmetic.", + args, + ))) + } + } + let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish())); + datum::apply(start, &interval_cv, op) + } + _ => Err(DataFusionError::Internal(format!( + "Unsupported data types {:?} for date arithmetic.", + args, + ))), + } +} + +pub fn spark_date_add(args: &[ColumnarValue]) -> Result { + spark_date_arithmetic(args, add) +} + +pub fn spark_date_sub(args: &[ColumnarValue]) -> Result { + spark_date_arithmetic(args, sub) +} diff --git a/native/spark-expr/src/datetime_funcs/date_trunc.rs b/native/spark-expr/src/datetime_funcs/date_trunc.rs new file mode 100644 index 0000000000..5c044945d0 --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/date_trunc.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue::Utf8}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn}; + +#[derive(Debug, Eq)] +pub struct DateTruncExpr { + /// An array with DataType::Date32 + child: Arc, + /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc + format: Arc, +} + +impl Hash for DateTruncExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.format.hash(state); + } +} +impl PartialEq for DateTruncExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.format.eq(&other.format) + } +} + +impl DateTruncExpr { + pub fn new(child: Arc, format: Arc) -> Self { + DateTruncExpr { child, format } + } +} + +impl Display for DateTruncExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "DateTrunc [child:{}, format: {}]", + self.child, self.format + ) + } +} + +impl PhysicalExpr for DateTruncExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + self.child.data_type(input_schema) + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let date = self.child.evaluate(batch)?; + let format = self.format.evaluate(batch)?; + match (date, format) { + (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => { + let result = date_trunc_dyn(&date, format)?; + Ok(ColumnarValue::Array(result)) + } + (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => { + let result = date_trunc_array_fmt_dyn(&date, &formats)?; + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar) or \ + (PrimitiveArray, StringArray)".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(DateTruncExpr::new( + Arc::clone(&children[0]), + Arc::clone(&self.format), + ))) + } +} diff --git a/native/spark-expr/src/datetime_funcs/hour.rs b/native/spark-expr/src/datetime_funcs/hour.rs new file mode 100644 index 0000000000..faf9529a51 --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/hour.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::array_with_timezone; +use arrow::{ + compute::{date_part, DatePart}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct HourExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl Hash for HourExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for HourExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.timezone.eq(&other.timezone) + } +} + +impl HourExpr { + pub fn new(child: Arc, timezone: String) -> Self { + HourExpr { child, timezone } + } +} + +impl Display for HourExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Hour [timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PhysicalExpr for HourExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Hour)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Hour(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(HourExpr::new( + Arc::clone(&children[0]), + self.timezone.clone(), + ))) + } +} diff --git a/native/spark-expr/src/datetime_funcs/minute.rs b/native/spark-expr/src/datetime_funcs/minute.rs new file mode 100644 index 0000000000..b7facc1673 --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/minute.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::array_with_timezone; +use arrow::{ + compute::{date_part, DatePart}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct MinuteExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl Hash for MinuteExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for MinuteExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.timezone.eq(&other.timezone) + } +} + +impl MinuteExpr { + pub fn new(child: Arc, timezone: String) -> Self { + MinuteExpr { child, timezone } + } +} + +impl Display for MinuteExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Minute [timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PhysicalExpr for MinuteExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Minute)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Minute(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(MinuteExpr::new( + Arc::clone(&children[0]), + self.timezone.clone(), + ))) + } +} diff --git a/native/spark-expr/src/datetime_funcs/mod.rs b/native/spark-expr/src/datetime_funcs/mod.rs new file mode 100644 index 0000000000..1f4d427282 --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/mod.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod date_arithmetic; +mod date_trunc; +mod hour; +mod minute; +mod second; +mod timestamp_trunc; + +pub use date_arithmetic::{spark_date_add, spark_date_sub}; +pub use date_trunc::DateTruncExpr; +pub use hour::HourExpr; +pub use minute::MinuteExpr; +pub use second::SecondExpr; +pub use timestamp_trunc::TimestampTruncExpr; diff --git a/native/spark-expr/src/datetime_funcs/second.rs b/native/spark-expr/src/datetime_funcs/second.rs new file mode 100644 index 0000000000..76a4dd9a2c --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/second.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::array_with_timezone; +use arrow::{ + compute::{date_part, DatePart}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct SecondExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl Hash for SecondExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for SecondExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.timezone.eq(&other.timezone) + } +} + +impl SecondExpr { + pub fn new(child: Arc, timezone: String) -> Self { + SecondExpr { child, timezone } + } +} + +impl Display for SecondExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Second (timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PhysicalExpr for SecondExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Second)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Second(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(SecondExpr::new( + Arc::clone(&children[0]), + self.timezone.clone(), + ))) + } +} diff --git a/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs b/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs new file mode 100644 index 0000000000..349992322f --- /dev/null +++ b/native/spark-expr/src/datetime_funcs/timestamp_trunc.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::array_with_timezone; +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue::Utf8}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +use crate::kernels::temporal::{timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn}; + +#[derive(Debug, Eq)] +pub struct TimestampTruncExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc + format: Arc, + /// String containing a timezone name. The name must be found in the standard timezone + /// database (https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). The string is + /// later parsed into a chrono::TimeZone. + /// Timestamp arrays in this implementation are kept in arrays of UTC timestamps (in micros) + /// along with a single value for the associated TimeZone. The timezone offset is applied + /// just before any operations on the timestamp + timezone: String, +} + +impl Hash for TimestampTruncExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.format.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for TimestampTruncExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + && self.format.eq(&other.format) + && self.timezone.eq(&other.timezone) + } +} + +impl TimestampTruncExpr { + pub fn new( + child: Arc, + format: Arc, + timezone: String, + ) -> Self { + TimestampTruncExpr { + child, + format, + timezone, + } + } +} + +impl Display for TimestampTruncExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "TimestampTrunc [child:{}, format:{}, timezone: {}]", + self.child, self.format, self.timezone + ) + } +} + +impl PhysicalExpr for TimestampTruncExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema)? { + DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary( + key_type, + Box::new(DataType::Timestamp(Microsecond, None)), + )), + _ => Ok(DataType::Timestamp(Microsecond, None)), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let timestamp = self.child.evaluate(batch)?; + let format = self.format.evaluate(batch)?; + let tz = self.timezone.clone(); + match (timestamp, format) { + (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => { + let ts = array_with_timezone( + ts, + tz.clone(), + Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), + )?; + let result = timestamp_trunc_dyn(&ts, format)?; + Ok(ColumnarValue::Array(result)) + } + (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => { + let ts = array_with_timezone( + ts, + tz.clone(), + Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), + )?; + let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?; + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Invalid input to function TimestampTrunc. \ + Expected (PrimitiveArray, Scalar, String) or \ + (PrimitiveArray, StringArray, String)" + .to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(TimestampTruncExpr::new( + Arc::clone(&children[0]), + Arc::clone(&self.format), + self.timezone.clone(), + ))) + } +} diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index f358731004..6d7bb7c8db 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -53,7 +53,6 @@ pub use sum_decimal::SumDecimal; mod negative; pub use negative::{create_negate_expr, NegativeExpr}; mod normalize_nan; -mod temporal; pub mod test_common; pub mod timezone; @@ -66,14 +65,16 @@ pub use normalize_nan::NormalizeNaNAndZero; mod variance; pub use variance::Variance; mod comet_scalar_funcs; +mod datetime_funcs; + pub use cast::{spark_cast, Cast, SparkCastOptions}; pub use comet_scalar_funcs::create_comet_physical_fun; +pub use datetime_funcs::*; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; pub use list::{ArrayInsert, GetArrayStructFields, ListExtract}; pub use regexp::RLike; pub use structs::{CreateNamedStruct, GetStructField}; -pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr}; pub use to_json::ToJson; /// Spark supports three evaluation modes when evaluating expressions, which affect diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs index 2961f038dc..6b87dc96ad 100644 --- a/native/spark-expr/src/scalar_funcs.rs +++ b/native/spark-expr/src/scalar_funcs.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -use arrow::compute::kernels::numeric::{add, sub}; -use arrow::datatypes::IntervalDayTime; use arrow::{ array::{ ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array, @@ -24,11 +22,9 @@ use arrow::{ }, datatypes::{validate_decimal_precision, Decimal128Type, Int64Type}, }; -use arrow_array::builder::{GenericStringBuilder, IntervalDayTimeBuilder}; -use arrow_array::types::{Int16Type, Int32Type, Int8Type}; -use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array}; -use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION}; -use datafusion::physical_expr_common::datum; +use arrow_array::builder::GenericStringBuilder; +use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array}; +use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION}; use datafusion::{functions::math::round::round, physical_plan::ColumnarValue}; use datafusion_common::{ cast::as_generic_string_array, exec_err, internal_err, DataFusionError, @@ -551,76 +547,3 @@ pub fn spark_isnan(args: &[ColumnarValue]) -> Result {{ - let interval = IntervalDayTime::new(*$days as i32, 0); - let interval_cv = ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval))); - datum::apply($start, &interval_cv, $op) - }}; -} -macro_rules! array_date_arithmetic { - ($days:expr, $interval_builder:expr, $intType:ty) => {{ - for day in $days.as_primitive::<$intType>().into_iter() { - if let Some(non_null_day) = day { - $interval_builder.append_value(IntervalDayTime::new(non_null_day as i32, 0)); - } else { - $interval_builder.append_null(); - } - } - }}; -} - -/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second -/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the -/// second argument and use DataFusion's interface to apply Arrow's operators. -fn spark_date_arithmetic( - args: &[ColumnarValue], - op: impl Fn(&dyn Datum, &dyn Datum) -> Result, -) -> Result { - let start = &args[0]; - match &args[1] { - ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => { - scalar_date_arithmetic!(start, days, op) - } - ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => { - scalar_date_arithmetic!(start, days, op) - } - ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => { - scalar_date_arithmetic!(start, days, op) - } - ColumnarValue::Array(days) => { - let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len()); - match days.data_type() { - DataType::Int8 => { - array_date_arithmetic!(days, interval_builder, Int8Type) - } - DataType::Int16 => { - array_date_arithmetic!(days, interval_builder, Int16Type) - } - DataType::Int32 => { - array_date_arithmetic!(days, interval_builder, Int32Type) - } - _ => { - return Err(DataFusionError::Internal(format!( - "Unsupported data types {:?} for date arithmetic.", - args, - ))) - } - } - let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish())); - datum::apply(start, &interval_cv, op) - } - _ => Err(DataFusionError::Internal(format!( - "Unsupported data types {:?} for date arithmetic.", - args, - ))), - } -} -pub fn spark_date_add(args: &[ColumnarValue]) -> Result { - spark_date_arithmetic(args, add) -} - -pub fn spark_date_sub(args: &[ColumnarValue]) -> Result { - spark_date_arithmetic(args, sub) -} diff --git a/native/spark-expr/src/temporal.rs b/native/spark-expr/src/temporal.rs deleted file mode 100644 index fb549f9ce8..0000000000 --- a/native/spark-expr/src/temporal.rs +++ /dev/null @@ -1,510 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use crate::utils::array_with_timezone; -use arrow::{ - compute::{date_part, DatePart}, - record_batch::RecordBatch, -}; -use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; -use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{DataFusionError, ScalarValue::Utf8}; -use datafusion_physical_expr::PhysicalExpr; -use std::hash::Hash; -use std::{ - any::Any, - fmt::{Debug, Display, Formatter}, - sync::Arc, -}; - -use crate::kernels::temporal::{ - date_trunc_array_fmt_dyn, date_trunc_dyn, timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn, -}; - -#[derive(Debug, Eq)] -pub struct HourExpr { - /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) - child: Arc, - timezone: String, -} - -impl Hash for HourExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.timezone.hash(state); - } -} -impl PartialEq for HourExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.timezone.eq(&other.timezone) - } -} - -impl HourExpr { - pub fn new(child: Arc, timezone: String) -> Self { - HourExpr { child, timezone } - } -} - -impl Display for HourExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Hour [timezone:{}, child: {}]", - self.timezone, self.child - ) - } -} - -impl PhysicalExpr for HourExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - match self.child.data_type(input_schema).unwrap() { - DataType::Dictionary(key_type, _) => { - Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) - } - _ => Ok(DataType::Int32), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let array = array_with_timezone( - array, - self.timezone.clone(), - Some(&DataType::Timestamp( - Microsecond, - Some(self.timezone.clone().into()), - )), - )?; - let result = date_part(&array, DatePart::Hour)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Hour(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(HourExpr::new( - Arc::clone(&children[0]), - self.timezone.clone(), - ))) - } -} - -#[derive(Debug, Eq)] -pub struct MinuteExpr { - /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) - child: Arc, - timezone: String, -} - -impl Hash for MinuteExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.timezone.hash(state); - } -} -impl PartialEq for MinuteExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.timezone.eq(&other.timezone) - } -} - -impl MinuteExpr { - pub fn new(child: Arc, timezone: String) -> Self { - MinuteExpr { child, timezone } - } -} - -impl Display for MinuteExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Minute [timezone:{}, child: {}]", - self.timezone, self.child - ) - } -} - -impl PhysicalExpr for MinuteExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - match self.child.data_type(input_schema).unwrap() { - DataType::Dictionary(key_type, _) => { - Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) - } - _ => Ok(DataType::Int32), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let array = array_with_timezone( - array, - self.timezone.clone(), - Some(&DataType::Timestamp( - Microsecond, - Some(self.timezone.clone().into()), - )), - )?; - let result = date_part(&array, DatePart::Minute)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Minute(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(MinuteExpr::new( - Arc::clone(&children[0]), - self.timezone.clone(), - ))) - } -} - -#[derive(Debug, Eq)] -pub struct SecondExpr { - /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) - child: Arc, - timezone: String, -} - -impl Hash for SecondExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.timezone.hash(state); - } -} -impl PartialEq for SecondExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.timezone.eq(&other.timezone) - } -} - -impl SecondExpr { - pub fn new(child: Arc, timezone: String) -> Self { - SecondExpr { child, timezone } - } -} - -impl Display for SecondExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "Second (timezone:{}, child: {}]", - self.timezone, self.child - ) - } -} - -impl PhysicalExpr for SecondExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - match self.child.data_type(input_schema).unwrap() { - DataType::Dictionary(key_type, _) => { - Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) - } - _ => Ok(DataType::Int32), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => { - let array = array_with_timezone( - array, - self.timezone.clone(), - Some(&DataType::Timestamp( - Microsecond, - Some(self.timezone.clone().into()), - )), - )?; - let result = date_part(&array, DatePart::Second)?; - - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Second(scalar) should be fold in Spark JVM side.".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(SecondExpr::new( - Arc::clone(&children[0]), - self.timezone.clone(), - ))) - } -} - -#[derive(Debug, Eq)] -pub struct DateTruncExpr { - /// An array with DataType::Date32 - child: Arc, - /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc - format: Arc, -} - -impl Hash for DateTruncExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.format.hash(state); - } -} -impl PartialEq for DateTruncExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) && self.format.eq(&other.format) - } -} - -impl DateTruncExpr { - pub fn new(child: Arc, format: Arc) -> Self { - DateTruncExpr { child, format } - } -} - -impl Display for DateTruncExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "DateTrunc [child:{}, format: {}]", - self.child, self.format - ) - } -} - -impl PhysicalExpr for DateTruncExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - self.child.data_type(input_schema) - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let date = self.child.evaluate(batch)?; - let format = self.format.evaluate(batch)?; - match (date, format) { - (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => { - let result = date_trunc_dyn(&date, format)?; - Ok(ColumnarValue::Array(result)) - } - (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => { - let result = date_trunc_array_fmt_dyn(&date, &formats)?; - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar) or \ - (PrimitiveArray, StringArray)".to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(DateTruncExpr::new( - Arc::clone(&children[0]), - Arc::clone(&self.format), - ))) - } -} - -#[derive(Debug, Eq)] -pub struct TimestampTruncExpr { - /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) - child: Arc, - /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc - format: Arc, - /// String containing a timezone name. The name must be found in the standard timezone - /// database (https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). The string is - /// later parsed into a chrono::TimeZone. - /// Timestamp arrays in this implementation are kept in arrays of UTC timestamps (in micros) - /// along with a single value for the associated TimeZone. The timezone offset is applied - /// just before any operations on the timestamp - timezone: String, -} - -impl Hash for TimestampTruncExpr { - fn hash(&self, state: &mut H) { - self.child.hash(state); - self.format.hash(state); - self.timezone.hash(state); - } -} -impl PartialEq for TimestampTruncExpr { - fn eq(&self, other: &Self) -> bool { - self.child.eq(&other.child) - && self.format.eq(&other.format) - && self.timezone.eq(&other.timezone) - } -} - -impl TimestampTruncExpr { - pub fn new( - child: Arc, - format: Arc, - timezone: String, - ) -> Self { - TimestampTruncExpr { - child, - format, - timezone, - } - } -} - -impl Display for TimestampTruncExpr { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "TimestampTrunc [child:{}, format:{}, timezone: {}]", - self.child, self.format, self.timezone - ) - } -} - -impl PhysicalExpr for TimestampTruncExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { - match self.child.data_type(input_schema)? { - DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary( - key_type, - Box::new(DataType::Timestamp(Microsecond, None)), - )), - _ => Ok(DataType::Timestamp(Microsecond, None)), - } - } - - fn nullable(&self, _: &Schema) -> datafusion_common::Result { - Ok(true) - } - - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { - let timestamp = self.child.evaluate(batch)?; - let format = self.format.evaluate(batch)?; - let tz = self.timezone.clone(); - match (timestamp, format) { - (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => { - let ts = array_with_timezone( - ts, - tz.clone(), - Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), - )?; - let result = timestamp_trunc_dyn(&ts, format)?; - Ok(ColumnarValue::Array(result)) - } - (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => { - let ts = array_with_timezone( - ts, - tz.clone(), - Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), - )?; - let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?; - Ok(ColumnarValue::Array(result)) - } - _ => Err(DataFusionError::Execution( - "Invalid input to function TimestampTrunc. \ - Expected (PrimitiveArray, Scalar, String) or \ - (PrimitiveArray, StringArray, String)" - .to_string(), - )), - } - } - - fn children(&self) -> Vec<&Arc> { - vec![&self.child] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result, DataFusionError> { - Ok(Arc::new(TimestampTruncExpr::new( - Arc::clone(&children[0]), - Arc::clone(&self.format), - self.timezone.clone(), - ))) - } -}