diff --git a/rust/datafusion/Cargo.toml b/rust/datafusion/Cargo.toml
index bd2ae974c38..193ad6b807d 100644
--- a/rust/datafusion/Cargo.toml
+++ b/rust/datafusion/Cargo.toml
@@ -45,6 +45,7 @@ cli = ["rustyline"]
 
 [dependencies]
 ahash = "0.6"
+hashbrown = "0.9"
 arrow = { path = "../arrow", version = "3.0.0-SNAPSHOT", features = ["prettyprint"] }
 parquet = { path = "../parquet", version = "3.0.0-SNAPSHOT", features = ["arrow"] }
 sqlparser = "0.6.1"
diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs
index 4e4222d97da..6f12f4c2602 100644
--- a/rust/datafusion/src/lib.rs
+++ b/rust/datafusion/src/lib.rs
@@ -14,7 +14,6 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
-
 #![warn(missing_docs)]
 // Clippy lints, some should be disabled incrementally
 #![allow(
diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs
index 0c64bc9ea02..6be23263c04 100644
--- a/rust/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs
@@ -47,7 +47,7 @@ use super::{
     SendableRecordBatchStream,
 };
 use ahash::RandomState;
-use std::collections::HashMap;
+use hashbrown::HashMap;
 
 use async_trait::async_trait;
 
@@ -253,23 +253,25 @@ fn group_aggregate_batch(
     // 1.1 construct the key from the group values
     // 1.2 construct the mapping key if it does not exist
     // 1.3 add the row' index to `indices`
+
+    // Make sure we can create the accumulators or otherwise return an error
+    create_accumulators(aggr_expr).map_err(DataFusionError::into_arrow_external_error)?;
+
     for row in 0..batch.num_rows() {
         // 1.1
         create_key(&group_values, row, &mut key)
             .map_err(DataFusionError::into_arrow_external_error)?;
-
-        match accumulators.get_mut(&key) {
-            // 1.2
-            None => {
-                let accumulator_set = create_accumulators(aggr_expr)
-                    .map_err(DataFusionError::into_arrow_external_error)?;
-
-                accumulators
-                    .insert(key.clone(), (accumulator_set, Box::new(vec![row as u32])));
-            }
+        accumulators
+            .raw_entry_mut()
+            .from_key(&key)
             // 1.3
-            Some((_, v)) => v.push(row as u32),
-        }
+            .and_modify(|_, (_, v)| v.push(row as u32))
+            // 1.2
+            .or_insert_with(|| {
+                // We can safely unwrap here as we checked we can create an accumulator before
+                let accumulator_set = create_accumulators(aggr_expr).unwrap();
+                (key.clone(), (accumulator_set, Box::new(vec![row as u32])))
+            });
     }
 
     // 2.1 for each key
diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs
index 17d9f69545d..d2bb8cf7c41 100644
--- a/rust/datafusion/src/physical_plan/hash_join.rs
+++ b/rust/datafusion/src/physical_plan/hash_join.rs
@@ -19,13 +19,11 @@
 //! into a set of partitions.
 
 use std::sync::Arc;
-use std::{
-    any::Any,
-    collections::{HashMap, HashSet},
-};
+use std::{any::Any, collections::HashSet};
 
 use async_trait::async_trait;
 use futures::{Stream, StreamExt, TryStreamExt};
+use hashbrown::HashMap;
 
 use arrow::array::{make_array, Array, MutableArrayData};
 use arrow::datatypes::{Schema, SchemaRef};
@@ -214,12 +212,11 @@ fn update_hash(
     // update the hash map
     for row in 0..batch.num_rows() {
         create_key(&keys_values, row, &mut key)?;
-        match hash.get_mut(&key) {
-            Some(v) => v.push((index, row)),
-            None => {
-                hash.insert(key.clone(), vec![(index, row)]);
-            }
-        };
+
+        hash.raw_entry_mut()
+            .from_key(&key)
+            .and_modify(|_, v| v.push((index, row)))
+            .or_insert_with(|| (key.clone(), vec![(index, row)]));
     }
     Ok(())
 }
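For reference, below is a minimal standalone sketch of the hashbrown raw-entry pattern the two hunks above adopt. It is not part of the patch: the `groups` map and the toy input are hypothetical names used only for illustration, and it assumes the `hashbrown = "0.9"` dependency added in Cargo.toml. `raw_entry_mut().from_key(..)` hashes and probes the table once per row; `and_modify` handles the existing-key path without cloning the key, and `or_insert_with` clones it only when a new entry is actually created. The std `HashMap` used before either pays a second lookup (`get_mut` followed by `insert`) or requires an owned key up front with the `entry` API, which is presumably why the diff switches to `hashbrown::HashMap`, where the raw-entry API is available on stable Rust.

```rust
// Sketch only: names and data here are illustrative, not taken from the PR.
use hashbrown::HashMap;

fn main() {
    // Pretend these are (group key, row index) pairs coming out of a batch.
    let rows = [("a", 1u32), ("b", 2), ("a", 3)];

    let mut groups: HashMap<String, Vec<u32>> = HashMap::new();

    for &(key, row) in rows.iter() {
        groups
            .raw_entry_mut()
            .from_key(key) // hash + probe the table exactly once per row
            .and_modify(|_k, v| v.push(row)) // existing key: no clone, no second lookup
            .or_insert_with(|| (key.to_string(), vec![row])); // new key: clone only here
    }

    assert_eq!(groups["a"], vec![1, 3]);
    assert_eq!(groups["b"], vec![2]);
}
```

The same shape appears in both `group_aggregate_batch` and `update_hash`: the closure passed to `or_insert_with` is the only place the key is cloned, and in the aggregate case the up-front `create_accumulators` call is what makes the later `unwrap()` inside that closure safe.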