From efcf5c6caa18464d17f104d29fd166b2c9d986b1 Mon Sep 17 00:00:00 2001 From: Xiangpeng Hao Date: Fri, 19 Jul 2024 20:40:22 -0400 Subject: [PATCH] Enable `GroupValueBytesView` for aggregation with StringView types (#11519) * add functions * Update `string-view` branch to arrow-rs main (#10966) * Pin to arrow main * Fix clippy with latest arrow * Uncomment test that needs new arrow-rs to work * Update datafusion-cli Cargo.lock * Update Cargo.lock * tapelo * merge * update cast * consistent dep * fix ci * avoid unused dep * update dep * update * fix cargo check * better group value view aggregation * update --------- Co-authored-by: Andrew Lamb --- datafusion/functions-aggregate/src/count.rs | 4 + .../src/aggregate/count_distinct/bytes.rs | 61 +++++++++ .../src/aggregate/count_distinct/mod.rs | 1 + .../physical-expr-common/src/binary_map.rs | 6 + .../src/binary_view_map.rs | 21 +-- .../src/aggregates/group_values/bytes_view.rs | 129 ++++++++++++++++++ .../src/aggregates/group_values/mod.rs | 33 +++-- 7 files changed, 236 insertions(+), 19 deletions(-) create mode 100644 datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 0a667d35dce5..7d190482f255 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -16,6 +16,7 @@ // under the License. 
use ahash::RandomState; +use datafusion_physical_expr_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; use std::collections::HashSet; use std::ops::BitAnd; use std::{fmt::Debug, sync::Arc}; @@ -230,6 +231,9 @@ impl AggregateUDFImpl for Count { DataType::Utf8 => { Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) } + DataType::Utf8View => { + Box::new(BytesViewDistinctCountAccumulator::new(OutputType::Utf8View)) + } DataType::LargeUtf8 => { Box::new(BytesDistinctCountAccumulator::::new(OutputType::Utf8)) } diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs index 27094b0c819a..360d64ce0141 100644 --- a/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/bytes.rs @@ -18,6 +18,7 @@ //! [`BytesDistinctCountAccumulator`] for Utf8/LargeUtf8/Binary/LargeBinary values use crate::binary_map::{ArrowBytesSet, OutputType}; +use crate::binary_view_map::ArrowBytesViewSet; use arrow::array::{ArrayRef, OffsetSizeTrait}; use datafusion_common::cast::as_list_array; use datafusion_common::utils::array_into_list_array_nullable; @@ -88,3 +89,63 @@ impl Accumulator for BytesDistinctCountAccumulator { std::mem::size_of_val(self) + self.0.size() } } + +/// Specialized implementation of +/// `COUNT DISTINCT` for [`StringViewArray`] and [`BinaryViewArray`]. 
+/// +/// [`StringViewArray`]: arrow::array::StringViewArray +/// [`BinaryViewArray`]: arrow::array::BinaryViewArray +#[derive(Debug)] +pub struct BytesViewDistinctCountAccumulator(ArrowBytesViewSet); + +impl BytesViewDistinctCountAccumulator { + pub fn new(output_type: OutputType) -> Self { + Self(ArrowBytesViewSet::new(output_type)) + } +} + +impl Accumulator for BytesViewDistinctCountAccumulator { + fn state(&mut self) -> datafusion_common::Result> { + let set = self.0.take(); + let arr = set.into_state(); + let list = Arc::new(array_into_list_array_nullable(arr)); + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + if values.is_empty() { + return Ok(()); + } + + self.0.insert(&values[0]); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + self.0.insert(&list); + }; + Ok(()) + }) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + Ok(ScalarValue::Int64(Some(self.0.non_null_len() as i64))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + self.0.size() + } +} diff --git a/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs index f216406d0dd7..7d772f7c649d 100644 --- a/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/count_distinct/mod.rs @@ -19,5 +19,6 @@ mod bytes; mod native; pub use bytes::BytesDistinctCountAccumulator; +pub use bytes::BytesViewDistinctCountAccumulator; pub use native::FloatDistinctCountAccumulator; pub use native::PrimitiveDistinctCountAccumulator; diff --git 
a/datafusion/physical-expr-common/src/binary_map.rs b/datafusion/physical-expr-common/src/binary_map.rs index bff571f5b5be..548ca16e4dbf 100644 --- a/datafusion/physical-expr-common/src/binary_map.rs +++ b/datafusion/physical-expr-common/src/binary_map.rs @@ -40,8 +40,12 @@ use std::sync::Arc; pub enum OutputType { /// `StringArray` or `LargeStringArray` Utf8, + /// `StringViewArray` + Utf8View, /// `BinaryArray` or `LargeBinaryArray` Binary, + /// `BinaryViewArray` + BinaryView, } /// HashSet optimized for storing string or binary values that can produce that @@ -318,6 +322,7 @@ where observe_payload_fn, ) } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), }; } @@ -516,6 +521,7 @@ where GenericStringArray::new_unchecked(offsets, values, nulls) }) } + _ => unreachable!("View types should use `ArrowBytesViewMap`"), } } diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index 3eeab4a5af02..db4e38501248 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -28,14 +28,7 @@ use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; use std::fmt::Debug; use std::sync::Arc; -/// Should the output be a String or Binary? -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum OutputType { - /// `StringViewArray` - Utf8View, - /// `BinaryViewArray` - BinaryView, -} +use crate::binary_map::OutputType; /// HashSet optimized for storing string or binary values that can produce that /// the final set as a `GenericBinaryViewArray` with minimal copies. 
@@ -55,6 +48,14 @@ impl ArrowBytesViewSet { .insert_if_new(values, make_payload_fn, observe_payload_fn); } + /// Return the contents of this map and replace it with a new empty map with + /// the same output type + pub fn take(&mut self) -> Self { + let mut new_self = Self::new(self.0.output_type); + std::mem::swap(self, &mut new_self); + new_self + } + /// Converts this set into a `StringViewArray` or `BinaryViewArray` /// containing each distinct value that was interned. /// This is done without copying the values. @@ -216,6 +217,7 @@ where observe_payload_fn, ) } + _ => unreachable!("Utf8/Binary should use `ArrowBytesSet`"), }; } @@ -327,6 +329,9 @@ where let array = unsafe { array.to_string_view_unchecked() }; Arc::new(array) } + _ => { + unreachable!("Utf8/Binary should use `ArrowBytesMap`") + } } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs new file mode 100644 index 000000000000..1a0cb90a16d4 --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/bytes_view.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::aggregates::group_values::GroupValues; +use arrow_array::{Array, ArrayRef, RecordBatch}; +use datafusion_expr::EmitTo; +use datafusion_physical_expr::binary_map::OutputType; +use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; + +/// A [`GroupValues`] storing single column of Utf8View/BinaryView values +/// +/// This specialization is significantly faster than using the more general +/// purpose `Row`s format +pub struct GroupValuesBytesView { + /// Map string/binary values to group index + map: ArrowBytesViewMap, + /// The total number of groups so far (used to assign group_index) + num_groups: usize, +} + +impl GroupValuesBytesView { + pub fn new(output_type: OutputType) -> Self { + Self { + map: ArrowBytesViewMap::new(output_type), + num_groups: 0, + } + } +} + +impl GroupValues for GroupValuesBytesView { + fn intern( + &mut self, + cols: &[ArrayRef], + groups: &mut Vec, + ) -> datafusion_common::Result<()> { + assert_eq!(cols.len(), 1); + + // look up / add entries in the table + let arr = &cols[0]; + + groups.clear(); + self.map.insert_if_new( + arr, + // called for each new group + |_value| { + // assign new group index on each insert + let group_idx = self.num_groups; + self.num_groups += 1; + group_idx + }, + // called for each group + |group_idx| { + groups.push(group_idx); + }, + ); + + // ensure we assigned a group to for each row + assert_eq!(groups.len(), arr.len()); + Ok(()) + } + + fn size(&self) -> usize { + self.map.size() + std::mem::size_of::() + } + + fn is_empty(&self) -> bool { + self.num_groups == 0 + } + + fn len(&self) -> usize { + self.num_groups + } + + fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { + // Reset the map to default, and convert it into a single array + let map_contents = self.map.take().into_state(); + + let group_values = match emit_to { + EmitTo::All => { + self.num_groups -= map_contents.len(); + map_contents + } + EmitTo::First(n) if n == self.len() => { + 
self.num_groups -= map_contents.len(); + map_contents + } + EmitTo::First(n) => { + // if we only wanted to take the first n, insert the rest back + // into the map we could potentially avoid this reallocation, at + // the expense of much more complex code. + // see https://github.com/apache/datafusion/issues/9195 + let emit_group_values = map_contents.slice(0, n); + let remaining_group_values = + map_contents.slice(n, map_contents.len() - n); + + self.num_groups = 0; + let mut group_indexes = vec![]; + self.intern(&[remaining_group_values], &mut group_indexes)?; + + // Verify that the group indexes were assigned in the correct order + assert_eq!(0, group_indexes[0]); + + emit_group_values + } + }; + + Ok(vec![group_values]) + } + + fn clear_shrink(&mut self, _batch: &RecordBatch) { + // in theory we could potentially avoid this reallocation and clear the + // contents of the maps, but for now we just reset the map from the beginning + self.map.take(); + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index b5bc923b467d..be7ac934d7bc 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -18,6 +18,7 @@ use arrow::record_batch::RecordBatch; use arrow_array::{downcast_primitive, ArrayRef}; use arrow_schema::{DataType, SchemaRef}; +use bytes_view::GroupValuesBytesView; use datafusion_common::Result; pub(crate) mod primitive; @@ -28,6 +29,7 @@ mod row; use row::GroupValuesRows; mod bytes; +mod bytes_view; use bytes::GroupValuesByes; use datafusion_physical_expr::binary_map::OutputType; @@ -67,17 +69,26 @@ pub fn new_group_values(schema: SchemaRef) -> Result> { _ => {} } - if let DataType::Utf8 = d { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); - } - if let DataType::LargeUtf8 = d { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); - } - if let 
DataType::Binary = d { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Binary))); - } - if let DataType::LargeBinary = d { - return Ok(Box::new(GroupValuesByes::::new(OutputType::Binary))); + match d { + DataType::Utf8 => { + return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); + } + DataType::LargeUtf8 => { + return Ok(Box::new(GroupValuesByes::::new(OutputType::Utf8))); + } + DataType::Utf8View => { + return Ok(Box::new(GroupValuesBytesView::new(OutputType::Utf8View))); + } + DataType::Binary => { + return Ok(Box::new(GroupValuesByes::::new(OutputType::Binary))); + } + DataType::LargeBinary => { + return Ok(Box::new(GroupValuesByes::::new(OutputType::Binary))); + } + DataType::BinaryView => { + return Ok(Box::new(GroupValuesBytesView::new(OutputType::BinaryView))); + } + _ => {} } }