diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs
index 68daad92a87..3916e614b93 100644
--- a/rust/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs
@@ -43,8 +43,8 @@ use arrow::{
 use pin_project_lite::pin_project;
 
 use super::{
-    common, expressions::Column, group_scalar::GroupByScalar, RecordBatchStream,
-    SendableRecordBatchStream,
+    common, expressions::Column, group_scalar::GroupByScalar, hash_join::create_key,
+    RecordBatchStream, SendableRecordBatchStream,
 };
 use ahash::RandomState;
 use hashbrown::HashMap;
@@ -245,12 +245,14 @@ fn group_aggregate_batch(
     // create vector large enough to hold the grouping key
     // this is an optimization to avoid allocating `key` on every row.
     // it will be overwritten on every iteration of the loop below
-    let mut key = Vec::with_capacity(group_values.len());
+    let mut group_by_values = Vec::with_capacity(group_values.len());
     for _ in 0..group_values.len() {
-        key.push(GroupByScalar::UInt32(0));
+        group_by_values.push(GroupByScalar::UInt32(0));
     }
-    let mut key = key.into_boxed_slice();
+    let mut group_by_values = group_by_values.into_boxed_slice();
+
+    let mut key = Vec::with_capacity(group_values.len());
 
     // 1.1 construct the key from the group values
     // 1.2 construct the mapping key if it does not exist
@@ -263,16 +265,21 @@
         // 1.1
         create_key(&group_values, row, &mut key)
             .map_err(DataFusionError::into_arrow_external_error)?;
+
         accumulators
             .raw_entry_mut()
             .from_key(&key)
             // 1.3
-            .and_modify(|_, (_, v)| v.push(row as u32))
+            .and_modify(|_, (_, _, v)| v.push(row as u32))
             // 1.2
             .or_insert_with(|| {
                 // We can safely unwrap here as we checked we can create an accumulator before
                 let accumulator_set = create_accumulators(aggr_expr).unwrap();
-                (key.clone(), (accumulator_set, vec![row as u32]))
+                let _ = create_group_by_values(&group_values, row, &mut group_by_values);
+                (
+                    key.clone(),
+                    (group_by_values.clone(), accumulator_set, vec![row as u32]),
+                )
             });
     }
 
@@ -284,7 +291,7 @@
     accumulators
         .iter_mut()
         // 2.1
-        .map(|(_, (accumulator_set, indices))| {
+        .map(|(_, (_, accumulator_set, indices))| {
             // 2.2
             accumulator_set
                 .into_iter()
@@ -391,7 +398,7 @@ impl GroupedHashAggregateStream {
 type AccumulatorSet = Vec<Box<dyn Accumulator>>;
 type Accumulators =
-    HashMap<Box<[GroupByScalar]>, (AccumulatorSet, Vec<u32>), RandomState>;
+    HashMap<Vec<u8>, (Box<[GroupByScalar]>, AccumulatorSet, Vec<u32>), RandomState>;
 
 impl Stream for GroupedHashAggregateStream {
     type Item = ArrowResult<RecordBatch>;
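Note: the `type Accumulators` hunk above is the core of this change. The map is now keyed by the serialized bytes of the group values (`Vec<u8>`), which hash and compare faster than `Box<[GroupByScalar]>`, while the typed values are kept in the map's value so `create_batch_from_map` can still rebuild the output columns. A minimal std-only sketch of that layout (the `AccumulatorSet` element and error handling are omitted, and all names here are stand-ins for the DataFusion types, not the real ones):

```rust
use std::collections::HashMap;

// Stand-in for DataFusion's GroupByScalar so the sketch is self-contained.
#[derive(Clone, Debug)]
enum GroupByScalar {
    Int64(i64),
    Utf8(String),
}

// Shape of the map after this diff: serialized bytes as the key, typed
// group values carried in the value next to the matching row indices.
type Accumulators = HashMap<Vec<u8>, (Box<[GroupByScalar]>, Vec<u32>)>;

fn main() {
    let mut accumulators: Accumulators = HashMap::new();

    // Byte key for the group (1i64, "a"): little-endian integer bytes,
    // then a length-prefixed string, mirroring create_key below.
    let mut key: Vec<u8> = Vec::new();
    key.extend(1i64.to_le_bytes().iter());
    key.extend(1usize.to_le_bytes().iter()); // string length
    key.extend(b"a");

    let group = vec![GroupByScalar::Int64(1), GroupByScalar::Utf8("a".into())];
    let entry = accumulators
        .entry(key)
        .or_insert_with(|| (group.into_boxed_slice(), Vec::new()));
    entry.1.push(0); // row 0 belongs to this group

    assert_eq!(accumulators.len(), 1);
}
```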
@@ -646,10 +653,10 @@ fn create_batch_from_map(
     // 5. concatenate the arrays over the second index [j] into a single vec.
     let arrays = accumulators
         .iter()
-        .map(|(k, (accumulator_set, _))| {
+        .map(|(_, (group_by_values, accumulator_set, _))| {
             // 2.
             let mut groups = (0..num_group_expr)
-                .map(|i| match &k[i] {
+                .map(|i| match &group_by_values[i] {
                     GroupByScalar::Int8(n) => {
                         Arc::new(Int8Array::from(vec![*n])) as ArrayRef
                     }
@@ -726,8 +733,8 @@ fn finalize_aggregation(
     }
 }
 
-/// Create a Vec<GroupByScalar> that can be used as a map key
-pub(crate) fn create_key(
+/// Create a Box<[GroupByScalar]> for the group by values
+pub(crate) fn create_group_by_values(
     group_by_keys: &[ArrayRef],
     row: usize,
     vec: &mut Box<[GroupByScalar]>,
diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs
index 971968a4efd..1bcfe2c5704 100644
--- a/rust/datafusion/src/physical_plan/hash_join.rs
+++ b/rust/datafusion/src/physical_plan/hash_join.rs
@@ -18,6 +18,7 @@
 //! Defines the join plan for executing partitions in parallel and then joining the results
 //! into a set of partitions.
 
+use arrow::array::ArrayRef;
 use std::sync::Arc;
 use std::{any::Any, collections::HashSet};
 
@@ -26,21 +27,24 @@ use futures::{Stream, StreamExt, TryStreamExt};
 use hashbrown::HashMap;
 
 use arrow::array::{make_array, Array, MutableArrayData};
+use arrow::datatypes::DataType;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 
-use super::{expressions::col, hash_aggregate::create_key};
+use arrow::array::{
+    Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, UInt32Array,
+    UInt64Array, UInt8Array,
+};
+
+use super::expressions::col;
 use super::{
     hash_utils::{build_join_schema, check_join_is_valid, JoinOn, JoinType},
     merge::MergeExec,
 };
 use crate::error::{DataFusionError, Result};
 
-use super::{
-    group_scalar::GroupByScalar, ExecutionPlan, Partitioning, RecordBatchStream,
-    SendableRecordBatchStream,
-};
+use super::{ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream};
 use ahash::RandomState;
 
 // An index of (batch, row) uniquely identifying a row in a part.
@@ -52,7 +56,7 @@ type JoinIndex = Option<(usize, usize)>;
 // Maps ["on" value] -> [list of indices with this key's value]
 // E.g. [1, 2] -> [(0, 3), (1, 6), (0, 8)] indicates that (column1, column2) = [1, 2] is true
 // for rows 3 and 8 from batch 0 and row 6 from batch 1.
-type JoinHashMap = HashMap<Box<[GroupByScalar]>, Vec<Index>, RandomState>;
+type JoinHashMap = HashMap<Vec<u8>, Vec<Index>, RandomState>;
 
 type JoinLeftData = (JoinHashMap, Vec<RecordBatch>);
 
 /// join execution plan executes partitions in parallel and combines them into a set of
@@ -205,11 +209,6 @@ fn update_hash(
         .collect::<Result<Vec<_>>>()?;
 
     let mut key = Vec::with_capacity(keys_values.len());
-    for _ in 0..keys_values.len() {
-        key.push(GroupByScalar::UInt32(0));
-    }
-
-    let mut key = key.into_boxed_slice();
 
     // update the hash map
     for row in 0..batch.num_rows() {
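Note: the pre-fill loop removed from `update_hash` above was only needed when the key was a fixed-width `Box<[GroupByScalar]>` whose slots were overwritten row by row. With a byte key, `create_key` clears and refills a single `Vec<u8>` instead, so one allocation still serves the whole batch. A std-only sketch of that reuse pattern (plain `Vec<i64>` columns stand in for `ArrayRef`; names are illustrative, not part of this diff):

```rust
// One reusable buffer: cleared and refilled for every row instead of
// allocating a fresh key per row.
fn create_key(columns: &[Vec<i64>], row: usize, key: &mut Vec<u8>) {
    key.clear();
    for col in columns {
        key.extend(col[row].to_le_bytes().iter());
    }
}

fn main() {
    let columns = vec![vec![1i64, 1, 2], vec![10i64, 10, 20]];
    let mut key = Vec::new(); // allocated once, before the row loop
    for row in 0..3 {
        create_key(&columns, row, &mut key);
        println!("row {} -> key {:?}", row, key);
    }
}
```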
@@ -318,6 +317,67 @@ fn build_batch_from_indices(
     Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
 }
 
+/// Create a key `Vec<u8>` that is used as key for the hashmap
+pub(crate) fn create_key(
+    group_by_keys: &[ArrayRef],
+    row: usize,
+    vec: &mut Vec<u8>,
+) -> Result<()> {
+    vec.clear();
+    for i in 0..group_by_keys.len() {
+        let col = &group_by_keys[i];
+        match col.data_type() {
+            DataType::UInt8 => {
+                let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::UInt16 => {
+                let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::UInt32 => {
+                let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::UInt64 => {
+                let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::Int8 => {
+                let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::Int16 => {
+                let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::Int32 => {
+                let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::Int64 => {
+                let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::Utf8 => {
+                let array = col.as_any().downcast_ref::<StringArray>().unwrap();
+                let value = array.value(row);
+                // store the size
+                vec.extend(value.len().to_le_bytes().iter());
+                // store the string value
+                vec.extend(array.value(row).as_bytes().iter());
+            }
+            _ => {
+                // This is internal because we should have caught this before.
+                return Err(DataFusionError::Internal(
+                    "Unsupported GROUP BY data type".to_string(),
+                ));
+            }
+        }
+    }
+    Ok(())
+}
+
 fn build_batch(
     batch: &RecordBatch,
     left_data: &JoinLeftData,
@@ -370,9 +430,8 @@ fn build_join_indexes(
         JoinType::Inner => {
            // inner => key intersection
            // unfortunately rust does not support intersection of map keys :(
-            let left_set: HashSet<Box<[GroupByScalar]>> = left.keys().cloned().collect();
-            let left_right: HashSet<Box<[GroupByScalar]>> =
-                right.keys().cloned().collect();
+            let left_set: HashSet<Vec<u8>> = left.keys().cloned().collect();
+            let left_right: HashSet<Vec<u8>> = right.keys().cloned().collect();
             let inner = left_set.intersection(&left_right);
 
             let mut indexes = Vec::new(); // unknown a prior size
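Note on the `DataType::Utf8` arm of `create_key` above: the length prefix written before the string bytes ("store the size") is what keeps variable-width values unambiguous; without it, distinct groups whose strings merely concatenate to the same bytes would collide in the hash map. A small self-contained check (the `encode` helper is hypothetical, not part of this diff):

```rust
// Encode a group of strings the way create_key does: length prefix first,
// then the raw bytes, for each value in turn.
fn encode(values: &[&str]) -> Vec<u8> {
    let mut key = Vec::new();
    for v in values {
        key.extend(v.len().to_le_bytes().iter()); // store the size
        key.extend(v.as_bytes().iter()); // store the string value
    }
    key
}

fn main() {
    // Both groups concatenate to "abc", but the prefixes keep them distinct.
    assert_ne!(encode(&["ab", "c"]), encode(&["a", "bc"]));
    assert_ne!(encode(&["ab", "c"]), encode(&["abc"]));
}
```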