Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor: Move group accumulator for aggregate function to physical-expr-common, and add ahash physical-expr-common #10574

Merged
merged 8 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ version = "38.0.0"
# for the inherited dependency but cannot do the reverse (override from true to false).
#
# See for more detaiils: https://github.com/rust-lang/cargo/issues/11329
ahash = { version = "0.8", default-features = false, features = [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

"runtime-rng",
] }
arrow = { version = "51.0.0", features = ["prettyprint"] }
arrow-array = { version = "51.0.0", default-features = false, features = ["chrono-tz"] }
arrow-buffer = { version = "51.0.0", default-features = false }
Expand Down
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ backtrace = []
pyarrow = ["pyo3", "arrow/pyarrow", "parquet"]

[dependencies]
ahash = { version = "0.8", default-features = false, features = [
"runtime-rng",
] }
ahash = { workspace = true }
apache-avro = { version = "0.16", default-features = false, features = [
"bzip",
"snappy",
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ unicode_expressions = [
]

[dependencies]
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
ahash = { workspace = true }
apache-avro = { version = "0.16", optional = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
Expand Down
4 changes: 1 addition & 3 deletions datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ path = "src/lib.rs"
[features]

[dependencies]
ahash = { version = "0.8", default-features = false, features = [
"runtime-rng",
] }
ahash = { workspace = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
chrono = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ path = "src/lib.rs"
arrow = { workspace = true }
datafusion-common = { workspace = true, default-features = true }
datafusion-expr = { workspace = true }
rand = { workspace = true }
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
//!
//! [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator

use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::datatypes::ArrowPrimitiveType;
use arrow_array::{Array, BooleanArray, PrimitiveArray};
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};

use datafusion_expr::EmitTo;
/// Track the accumulator null state per row: if any values for that
Expand Down Expand Up @@ -462,9 +462,9 @@ fn initialize_builder(
mod test {
use super::*;

use arrow_array::UInt32Array;
use hashbrown::HashSet;
use arrow::array::UInt32Array;
use rand::{rngs::ThreadRng, Rng};
use std::collections::HashSet;

#[test]
fn accumulate() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

use std::sync::Arc;

use arrow::array::AsArray;
use arrow_array::{ArrayRef, BooleanArray};
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder};
use arrow::array::{ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder};
use arrow::buffer::BooleanBuffer;
use datafusion_common::Result;
use datafusion_expr::{EmitTo, GroupsAccumulator};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Utilities for implementing GroupsAccumulator

pub mod accumulate;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps we could add a note that this module is for implementing GroupsAccumulator

Suggested change
pub mod accumulate;
//! Utilities for implementing [`GroupsAccumulator`]
pub mod accumulate;

pub mod bool_op;
pub mod prim_op;
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

use std::sync::Arc;

use arrow::{array::AsArray, datatypes::ArrowPrimitiveType};
use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray};
use arrow_schema::DataType;
use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
use arrow::datatypes::ArrowPrimitiveType;
use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_expr::{EmitTo, GroupsAccumulator};

Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.

pub mod groups_accumulator;
pub mod stats;
pub mod utils;

Expand Down
162 changes: 161 additions & 1 deletion datafusion/physical-expr-common/src/aggregate/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,18 @@

use std::{any::Any, sync::Arc};

use arrow::datatypes::ArrowNativeType;
use arrow::{
array::{ArrayRef, ArrowNativeTypeOp, AsArray},
compute::SortOptions,
datatypes::{DataType, Field},
datatypes::{
DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
ToByteSlice,
},
};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::Accumulator;

use crate::sort_expr::PhysicalSortExpr;

Expand All @@ -43,6 +51,60 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
}
}

/// Convert scalar values from an accumulator into arrays.
pub fn get_accum_scalar_values_as_arrays(
accum: &mut dyn Accumulator,
) -> Result<Vec<ArrayRef>> {
accum
.state()?
.iter()
.map(|s| s.to_array_of_size(1))
.collect()
}

/// Adjust array type metadata if needed
///
/// Since `Decimal128Arrays` created from `Vec<NativeType>` have
/// default precision and scale, this function adjusts the output to
/// match `data_type`, if necessary
pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result<ArrayRef> {
let array = match data_type {
DataType::Decimal128(p, s) => Arc::new(
array
.as_primitive::<Decimal128Type>()
.clone()
.with_precision_and_scale(*p, *s)?,
) as ArrayRef,
DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new(
array
.as_primitive::<TimestampNanosecondType>()
.clone()
.with_timezone_opt(tz.clone()),
),
DataType::Timestamp(TimeUnit::Microsecond, tz) => Arc::new(
array
.as_primitive::<TimestampMicrosecondType>()
.clone()
.with_timezone_opt(tz.clone()),
),
DataType::Timestamp(TimeUnit::Millisecond, tz) => Arc::new(
array
.as_primitive::<TimestampMillisecondType>()
.clone()
.with_timezone_opt(tz.clone()),
),
DataType::Timestamp(TimeUnit::Second, tz) => Arc::new(
array
.as_primitive::<TimestampSecondType>()
.clone()
.with_timezone_opt(tz.clone()),
),
// no adjustment needed for other arrays
_ => array,
};
Ok(array)
}

/// Construct corresponding fields for lexicographical ordering requirement expression
pub fn ordering_fields(
ordering_req: &[PhysicalSortExpr],
Expand All @@ -67,3 +129,101 @@ pub fn ordering_fields(
pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec<SortOptions> {
ordering_req.iter().map(|item| item.options).collect()
}

/// A wrapper around a type to provide hash for floats
#[derive(Copy, Clone, Debug)]
pub struct Hashable<T>(pub T);

impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.to_byte_slice().hash(state)
}
}

impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
fn eq(&self, other: &Self) -> bool {
self.0.is_eq(other.0)
}
}

impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}

/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
///
/// This is needed because different precisions for Decimal128/Decimal256 can
/// store different ranges of values and thus sum/count may not fit in
/// the target type.
///
/// For example, the precision is 3, the max of value is `999` and the min
/// value is `-999`
pub struct DecimalAverager<T: DecimalType> {
/// scale factor for sum values (10^sum_scale)
sum_mul: T::Native,
/// scale factor for target (10^target_scale)
target_mul: T::Native,
/// the output precision
target_precision: u8,
}

impl<T: DecimalType> DecimalAverager<T> {
/// Create a new `DecimalAverager`:
///
/// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
/// * target_precision: the output precision
/// * target_scale: the output scale
///
/// Errors if the resulting data can not be stored
pub fn try_new(
sum_scale: i8,
target_precision: u8,
target_scale: i8,
) -> Result<Self> {
let sum_mul = T::Native::from_usize(10_usize)
.map(|b| b.pow_wrapping(sum_scale as u32))
.ok_or(DataFusionError::Internal(
"Failed to compute sum_mul in DecimalAverager".to_string(),
))?;

let target_mul = T::Native::from_usize(10_usize)
.map(|b| b.pow_wrapping(target_scale as u32))
.ok_or(DataFusionError::Internal(
"Failed to compute target_mul in DecimalAverager".to_string(),
))?;

if target_mul >= sum_mul {
Ok(Self {
sum_mul,
target_mul,
target_precision,
})
} else {
// can't convert the lit decimal to the returned data type
exec_err!("Arithmetic Overflow in AvgAccumulator")
}
}

/// Returns the `sum`/`count` as a i128/i256 Decimal128/Decimal256 with
/// target_scale and target_precision and reporting overflow.
///
/// * sum: The total sum value stored as Decimal128 with sum_scale
/// (passed to `Self::try_new`)
/// * count: total count, stored as a i128/i256 (*NOT* a Decimal128/Decimal256 value)
#[inline(always)]
pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
let new_value = value.div_wrapping(count);

let validate =
T::validate_decimal_precision(new_value, self.target_precision);

if validate.is_ok() {
Ok(new_value)
} else {
exec_err!("Arithmetic Overflow in AvgAccumulator")
}
} else {
// can't convert the lit decimal to the returned data type
exec_err!("Arithmetic Overflow in AvgAccumulator")
}
}
}
4 changes: 1 addition & 3 deletions datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ encoding_expressions = ["base64", "hex"]
regex_expressions = ["regex"]

[dependencies]
ahash = { version = "0.8", default-features = false, features = [
"runtime-rng",
] }
ahash = { workspace = true }
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
Expand Down
17 changes: 13 additions & 4 deletions datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@
// specific language governing permissions and limitations
// under the License.

pub(crate) mod accumulate;
mod adapter;
pub use accumulate::NullState;
pub use adapter::GroupsAccumulatorAdapter;

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 for backwards compatibility

Maybe we can leave a comment explaining it is for backwards compatibility

pub(crate) mod bool_op;
pub(crate) mod prim_op;
// Backward compatibility
pub(crate) mod accumulate {
pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::{accumulate_indices, NullState};
}

pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState;

pub(crate) mod bool_op {
pub use datafusion_physical_expr_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator;
}
pub(crate) mod prim_op {
pub use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
}
7 changes: 6 additions & 1 deletion datafusion/physical-expr/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ pub(crate) mod variance;

pub mod build_in;
pub mod moving_min_max;
pub mod utils;
pub mod utils {
pub use datafusion_physical_expr_common::aggregate::utils::{
adjust_output_array, down_cast_any_ref, get_accum_scalar_values_as_arrays,
get_sort_options, ordering_fields, DecimalAverager, Hashable,
};
}

/// Checks whether the given aggregate expression is order-sensitive.
/// For instance, a `SUM` aggregation doesn't depend on the order of its inputs.
Expand Down
Loading