From 815a816a7166615a4a23cfeff3c35e503af08b71 Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Sun, 5 Jan 2025 11:25:17 +0200
Subject: [PATCH 1/2] extract static invoke expressions to folders based on spark grouping

---
 native/spark-expr/src/comet_scalar_funcs.rs   |  5 +-
 native/spark-expr/src/lib.rs                  |  2 +
 native/spark-expr/src/scalar_funcs.rs         | 59 +-------------
 .../static_invoke/char_varchar_utils/mod.rs   | 20 +++++
 .../char_varchar_utils/read_side_padding.rs   | 76 +++++++++++++++++++
 native/spark-expr/src/static_invoke/mod.rs    | 20 +++++
 6 files changed, 124 insertions(+), 58 deletions(-)
 create mode 100644 native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
 create mode 100644 native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
 create mode 100644 native/spark-expr/src/static_invoke/mod.rs

diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 71ff0e9dcc..a4c2e9d70e 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -20,9 +20,10 @@ use crate::scalar_funcs::hash_expressions::{
 };
 use crate::scalar_funcs::{
     spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
-    spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_read_side_padding, spark_round,
-    spark_unhex, spark_unscaled_value, spark_xxhash64, SparkChrFunc,
+    spark_isnan, spark_make_decimal, spark_murmur3_hash, spark_round, spark_unhex,
+    spark_unscaled_value, spark_xxhash64, SparkChrFunc,
 };
+use crate::spark_read_side_padding;
 use arrow_schema::DataType;
 use datafusion_common::{DataFusionError, Result as DataFusionResult};
 use datafusion_expr::registry::FunctionRegistry;
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index f358731004..ab5039a457 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -42,7 +42,9 @@ mod list;
 mod regexp;
 pub mod scalar_funcs;
 mod schema_adapter;
+mod static_invoke;
 pub use schema_adapter::SparkSchemaAdapterFactory;
+pub use static_invoke::*;
 
 pub mod spark_hash;
 mod stddev;
diff --git a/native/spark-expr/src/scalar_funcs.rs b/native/spark-expr/src/scalar_funcs.rs
index 2961f038dc..345a800a30 100644
--- a/native/spark-expr/src/scalar_funcs.rs
+++ b/native/spark-expr/src/scalar_funcs.rs
@@ -20,25 +20,23 @@ use arrow::datatypes::IntervalDayTime;
 use arrow::{
     array::{
         ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, Int16Array, Int32Array,
-        Int64Array, Int64Builder, Int8Array, OffsetSizeTrait,
+        Int64Array, Int64Builder, Int8Array,
     },
     datatypes::{validate_decimal_precision, Decimal128Type, Int64Type},
 };
-use arrow_array::builder::{GenericStringBuilder, IntervalDayTimeBuilder};
+use arrow_array::builder::IntervalDayTimeBuilder;
 use arrow_array::types::{Int16Type, Int32Type, Int8Type};
 use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Datum, Decimal128Array};
 use arrow_schema::{ArrowError, DataType, DECIMAL128_MAX_PRECISION};
 use datafusion::physical_expr_common::datum;
 use datafusion::{functions::math::round::round, physical_plan::ColumnarValue};
 use datafusion_common::{
-    cast::as_generic_string_array, exec_err, internal_err, DataFusionError,
-    Result as DataFusionResult, ScalarValue,
+    exec_err, internal_err, DataFusionError, Result as DataFusionResult, ScalarValue,
 };
 use num::{
     integer::{div_ceil, div_floor},
     BigInt, Signed, ToPrimitive,
 };
-use std::fmt::Write;
 use std::{cmp::min, sync::Arc};
 
 mod unhex;
@@ -390,57 +388,6 @@ pub fn spark_round(
     }
 }
 
-/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length
-pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
-    match args {
-        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
-            match array.data_type() {
-                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
-                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
-                // TODO: handle Dictionary types
-                other => Err(DataFusionError::Internal(format!(
-                    "Unsupported data type {other:?} for function read_side_padding",
-                ))),
-            }
-        }
-        other => Err(DataFusionError::Internal(format!(
-            "Unsupported arguments {other:?} for function read_side_padding",
-        ))),
-    }
-}
-
-fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
-    array: &ArrayRef,
-    length: i32,
-) -> Result<ColumnarValue, DataFusionError> {
-    let string_array = as_generic_string_array::<T>(array)?;
-    let length = 0.max(length) as usize;
-    let space_string = " ".repeat(length);
-
-    let mut builder =
-        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
-
-    for string in string_array.iter() {
-        match string {
-            Some(string) => {
-                // It looks Spark's UTF8String is closer to chars rather than graphemes
-                // https://stackoverflow.com/a/46290728
-                let char_len = string.chars().count();
-                if length <= char_len {
-                    builder.append_value(string);
-                } else {
-                    // write_str updates only the value buffer, not null nor offset buffer
-                    // This is convenient for concatenating str(s)
-                    builder.write_str(string)?;
-                    builder.append_value(&space_string[char_len..]);
-                }
-            }
-            _ => builder.append_null(),
-        }
-    }
-    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
-}
-
 // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
 // Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
 // get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
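
The trailing context lines above cite the scale-widening rule used by Comet's decimal division: widen s1 to s2 + s3 + 1 before dividing. A minimal arithmetic sketch of why that widening leaves room for rounding, using invented values that are not part of the patch:

```rust
// Sketch only: the s1 -> s2 + s3 + 1 widening described in the comment above,
// on invented values. Dividing 1.00 by 3.00 with result scale s3 = 6.
fn widening_sketch() {
    let (s1, s2, s3) = (2u32, 2u32, 6u32); // scales: dividend, divisor, result
    let widened_s1 = s2 + s3 + 1; // 2 + 6 + 1 = 9
    let dividend = 100_i128 * 10_i128.pow(widened_s1 - s1); // 1.00 widened to scale 9
    let divisor = 300_i128; // 3.00 at scale 2
    let quotient = dividend / divisor; // 3_333_333, i.e. 0.3333333 at scale 9 - 2 = 7
    // Scale 7 keeps one digit beyond s3 = 6, enough to round to 0.333333.
    assert_eq!(quotient, 3_333_333);
}
```
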
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
new file mode 100644
index 0000000000..fff6134dab
--- /dev/null
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/mod.rs
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod read_side_padding;
+
+pub use read_side_padding::spark_read_side_padding;
diff --git a/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
new file mode 100644
index 0000000000..15807bf57d
--- /dev/null
+++ b/native/spark-expr/src/static_invoke/char_varchar_utils/read_side_padding.rs
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, OffsetSizeTrait};
+use arrow_array::builder::GenericStringBuilder;
+use arrow_array::Array;
+use arrow_schema::DataType;
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
+use std::fmt::Write;
+use std::sync::Arc;
+
+/// Similar to DataFusion's `rpad`, but does not truncate when the string is already longer than `length`.
+pub fn spark_read_side_padding(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    match args {
+        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => {
+            match array.data_type() {
+                DataType::Utf8 => spark_read_side_padding_internal::<i32>(array, *length),
+                DataType::LargeUtf8 => spark_read_side_padding_internal::<i64>(array, *length),
+                // TODO: handle Dictionary types
+                other => Err(DataFusionError::Internal(format!(
+                    "Unsupported data type {other:?} for function read_side_padding",
+                ))),
+            }
+        }
+        other => Err(DataFusionError::Internal(format!(
+            "Unsupported arguments {other:?} for function read_side_padding",
+        ))),
+    }
+}
+
+fn spark_read_side_padding_internal<T: OffsetSizeTrait>(
+    array: &ArrayRef,
+    length: i32,
+) -> Result<ColumnarValue, DataFusionError> {
+    let string_array = as_generic_string_array::<T>(array)?;
+    let length = 0.max(length) as usize;
+    let space_string = " ".repeat(length);
+
+    let mut builder =
+        GenericStringBuilder::<T>::with_capacity(string_array.len(), string_array.len() * length);
+
+    for string in string_array.iter() {
+        match string {
+            Some(string) => {
+                // It looks like Spark's UTF8String is closer to chars than to graphemes
+                // https://stackoverflow.com/a/46290728
+                let char_len = string.chars().count();
+                if length <= char_len {
+                    builder.append_value(string);
+                } else {
+                    // write_str updates only the value buffer, not the null or offset buffers,
+                    // which is convenient for concatenating strings
+                    builder.write_str(string)?;
+                    builder.append_value(&space_string[char_len..]);
+                }
+            }
+            _ => builder.append_null(),
+        }
+    }
+    Ok(ColumnarValue::Array(Arc::new(builder.finish())))
+}
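
Because the moved function's contract is easiest to see at a call site, here is a minimal usage sketch; it is not part of the patch and assumes `spark_read_side_padding` is in scope, for example via the crate-root re-export added in lib.rs:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{DataFusionError, ScalarValue};

fn demo() -> Result<(), DataFusionError> {
    // Pad a Utf8 column to width 5. Values longer than the target width are
    // left untruncated, which is the difference from DataFusion's `rpad`.
    let input: ArrayRef = Arc::new(StringArray::from(vec![Some("ab"), Some("abcdefgh"), None]));
    let args = [
        ColumnarValue::Array(input),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(5))),
    ];
    let padded = spark_read_side_padding(&args)?;
    // Expected: ["ab   ", "abcdefgh", NULL]. "ab" gains three trailing
    // spaces, the eight-char value is unchanged, and nulls stay null.
    let _ = padded;
    Ok(())
}
```
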
diff --git a/native/spark-expr/src/static_invoke/mod.rs b/native/spark-expr/src/static_invoke/mod.rs
new file mode 100644
index 0000000000..2a5351bb7f
--- /dev/null
+++ b/native/spark-expr/src/static_invoke/mod.rs
@@ -0,0 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod char_varchar_utils;
+
+pub use char_varchar_utils::*;

From c0a368ba49a3b9d346f8c028cb44e44235ccc5e7 Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Sun, 5 Jan 2025 21:17:25 +0200
Subject: [PATCH 2/2] Update native/spark-expr/src/static_invoke/mod.rs

Co-authored-by: Andy Grove
---
 native/spark-expr/src/static_invoke/mod.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/native/spark-expr/src/static_invoke/mod.rs b/native/spark-expr/src/static_invoke/mod.rs
index 2a5351bb7f..4072e13b70 100644
--- a/native/spark-expr/src/static_invoke/mod.rs
+++ b/native/spark-expr/src/static_invoke/mod.rs
@@ -17,4 +17,4 @@
 
 mod char_varchar_utils;
 
-pub use char_varchar_utils::*;
+pub use char_varchar_utils::spark_read_side_padding;
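
With the second patch applied, the function still reaches the crate root through lib.rs's `pub use static_invoke::*;`. Assuming the crate is published as `datafusion-comet-spark-expr` (the name is not stated in this diff), an external caller would import it as:

```rust
// The crate name below is an assumption; within the crate, the first patch
// itself uses `use crate::spark_read_side_padding;` in comet_scalar_funcs.rs.
use datafusion_comet_spark_expr::spark_read_side_padding;
```

Swapping the glob for a named re-export keeps `static_invoke`'s public surface explicit: only `spark_read_side_padding` is exported, rather than whatever `char_varchar_utils` happens to make public.
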