rust/datafusion/src/physical_plan/hash_aggregate.rs (20 additions, 13 deletions)
@@ -43,8 +43,8 @@ use arrow::{
 use pin_project_lite::pin_project;
 
 use super::{
-    common, expressions::Column, group_scalar::GroupByScalar, RecordBatchStream,
-    SendableRecordBatchStream,
+    common, expressions::Column, group_scalar::GroupByScalar, hash_join::create_key,
+    RecordBatchStream, SendableRecordBatchStream,
 };
 use ahash::RandomState;
 use hashbrown::HashMap;
@@ -245,12 +245,14 @@ fn group_aggregate_batch(
     // create vector large enough to hold the grouping key
     // this is an optimization to avoid allocating `key` on every row.
     // it will be overwritten on every iteration of the loop below
-    let mut key = Vec::with_capacity(group_values.len());
+    let mut group_by_values = Vec::with_capacity(group_values.len());
     for _ in 0..group_values.len() {
-        key.push(GroupByScalar::UInt32(0));
+        group_by_values.push(GroupByScalar::UInt32(0));
     }
 
-    let mut key = key.into_boxed_slice();
+    let mut group_by_values = group_by_values.into_boxed_slice();
 
+    let mut key = Vec::with_capacity(group_values.len());
+
     // 1.1 construct the key from the group values
     // 1.2 construct the mapping key if it does not exist
@@ -263,16 +265,21 @@
         // 1.1
         create_key(&group_values, row, &mut key)
             .map_err(DataFusionError::into_arrow_external_error)?;
+
         accumulators
             .raw_entry_mut()
             .from_key(&key)
             // 1.3
-            .and_modify(|_, (_, v)| v.push(row as u32))
+            .and_modify(|_, (_, _, v)| v.push(row as u32))
             // 1.2
             .or_insert_with(|| {
                 // We can safely unwrap here as we checked we can create an accumulator before
                 let accumulator_set = create_accumulators(aggr_expr).unwrap();
-                (key.clone(), (accumulator_set, vec![row as u32]))
+                let _ = create_group_by_values(&group_values, row, &mut group_by_values);
+                (
+                    key.clone(),
+                    (group_by_values.clone(), accumulator_set, vec![row as u32]),
+                )
             });
     }
 
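Note: a minimal sketch (not part of this diff) of the raw-entry upsert used above, with a simplified value type. `raw_entry_mut().from_key(..)` probes the map with a borrowed key, so the scratch `key` buffer is only cloned when a new group is actually inserted; in the real code the value tuple also carries the decoded `GroupByScalar`s (elided here), because the serialized `Vec<u8>` key cannot be turned back into output columns.

    use ahash::RandomState;
    use hashbrown::HashMap;

    fn upsert_row(
        groups: &mut HashMap<Vec<u8>, Vec<u32>, RandomState>,
        key: &Vec<u8>,
        row: u32,
    ) {
        groups
            .raw_entry_mut()
            .from_key(key)
            // existing group: just record the row index
            .and_modify(|_key, rows| rows.push(row))
            // new group: clone the scratch key only now
            .or_insert_with(|| (key.clone(), vec![row]));
    }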
@@ -284,7 +291,7 @@ fn group_aggregate_batch(
     accumulators
         .iter_mut()
         // 2.1
-        .map(|(_, (accumulator_set, indices))| {
+        .map(|(_, (_, accumulator_set, indices))| {
             // 2.2
             accumulator_set
                 .into_iter()
@@ -391,7 +398,7 @@ impl GroupedHashAggregateStream {
 
 type AccumulatorSet = Vec<Box<dyn Accumulator>>;
 type Accumulators =
-    HashMap<Box<[GroupByScalar]>, (AccumulatorSet, Vec<u32>), RandomState>;
+    HashMap<Vec<u8>, (Box<[GroupByScalar]>, AccumulatorSet, Vec<u32>), RandomState>;
 
 impl Stream for GroupedHashAggregateStream {
     type Item = ArrowResult<RecordBatch>;
@@ -646,10 +653,10 @@ fn create_batch_from_map(
     // 5. concatenate the arrays over the second index [j] into a single vec<ArrayRef>.
     let arrays = accumulators
         .iter()
-        .map(|(k, (accumulator_set, _))| {
+        .map(|(_, (group_by_values, accumulator_set, _))| {
             // 2.
             let mut groups = (0..num_group_expr)
-                .map(|i| match &k[i] {
+                .map(|i| match &group_by_values[i] {
                     GroupByScalar::Int8(n) => {
                         Arc::new(Int8Array::from(vec![*n])) as ArrayRef
                     }
@@ -726,8 +733,8 @@ fn finalize_aggregation(
     }
 }
 
-/// Create a Vec<GroupByScalar> that can be used as a map key
-pub(crate) fn create_key(
+/// Create a Box<[GroupByScalar]> for the group by values
+pub(crate) fn create_group_by_values(
     group_by_keys: &[ArrayRef],
     row: usize,
     vec: &mut Box<[GroupByScalar]>,
rust/datafusion/src/physical_plan/hash_join.rs (73 additions, 14 deletions)
@@ -18,6 +18,7 @@
 //! Defines the join plan for executing partitions in parallel and then joining the results
 //! into a set of partitions.
 
+use arrow::array::ArrayRef;
 use std::sync::Arc;
 use std::{any::Any, collections::HashSet};
 
@@ -26,21 +27,24 @@ use futures::{Stream, StreamExt, TryStreamExt};
 use hashbrown::HashMap;
 
 use arrow::array::{make_array, Array, MutableArrayData};
+use arrow::datatypes::DataType;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 
-use super::{expressions::col, hash_aggregate::create_key};
+use arrow::array::{
+    Int16Array, Int32Array, Int64Array, Int8Array, StringArray, UInt16Array, UInt32Array,
+    UInt64Array, UInt8Array,
+};
+
+use super::expressions::col;
 use super::{
     hash_utils::{build_join_schema, check_join_is_valid, JoinOn, JoinType},
     merge::MergeExec,
 };
 use crate::error::{DataFusionError, Result};
 
-use super::{
-    group_scalar::GroupByScalar, ExecutionPlan, Partitioning, RecordBatchStream,
-    SendableRecordBatchStream,
-};
+use super::{ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream};
 use ahash::RandomState;
 
 // An index of (batch, row) uniquely identifying a row in a part.
@@ -52,7 +56,7 @@ type JoinIndex = Option<(usize, usize)>;
 // Maps ["on" value] -> [list of indices with this key's value]
 // E.g. [1, 2] -> [(0, 3), (1, 6), (0, 8)] indicates that (column1, column2) = [1, 2] is true
 // for rows 3 and 8 from batch 0 and row 6 from batch 1.
-type JoinHashMap = HashMap<Box<[GroupByScalar]>, Vec<Index>, RandomState>;
+type JoinHashMap = HashMap<Vec<u8>, Vec<Index>, RandomState>;
 type JoinLeftData = (JoinHashMap, Vec<RecordBatch>);
 
 /// join execution plan executes partitions in parallel and combines them into a set of
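For illustration (not part of this diff, and assuming `Index` is the `(batch, row)` pair described above): with the new key type, the `[1, 2]` example from the comment is keyed by the little-endian bytes of the two "on" values rather than by a boxed slice of `GroupByScalar`s.

    let mut map: JoinHashMap = HashMap::with_hasher(RandomState::new());
    let key: Vec<u8> = [1u32.to_le_bytes(), 2u32.to_le_bytes()].concat();
    map.insert(key, vec![(0, 3), (1, 6), (0, 8)]);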
@@ -205,11 +209,6 @@ fn update_hash(
         .collect::<Result<Vec<_>>>()?;
 
     let mut key = Vec::with_capacity(keys_values.len());
-    for _ in 0..keys_values.len() {
-        key.push(GroupByScalar::UInt32(0));
-    }
-
-    let mut key = key.into_boxed_slice();
 
     // update the hash map
     for row in 0..batch.num_rows() {
@@ -318,6 +317,67 @@ fn build_batch_from_indices(
     Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
 }
 
+/// Create a key `Vec<u8>` that is used as key for the hashmap
+pub(crate) fn create_key(
+    group_by_keys: &[ArrayRef],
+    row: usize,
+    vec: &mut Vec<u8>,
+) -> Result<()> {
+    vec.clear();
+    for i in 0..group_by_keys.len() {
+        let col = &group_by_keys[i];
+        match col.data_type() {
+            DataType::UInt8 => {
+                let array = col.as_any().downcast_ref::<UInt8Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::UInt16 => {
+                let array = col.as_any().downcast_ref::<UInt16Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::UInt32 => {
+                let array = col.as_any().downcast_ref::<UInt32Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::UInt64 => {
+                let array = col.as_any().downcast_ref::<UInt64Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::Int8 => {
+                let array = col.as_any().downcast_ref::<Int8Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::Int16 => {
+                let array = col.as_any().downcast_ref::<Int16Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::Int32 => {
+                let array = col.as_any().downcast_ref::<Int32Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::Int64 => {
+                let array = col.as_any().downcast_ref::<Int64Array>().unwrap();
+                vec.extend(array.value(row).to_le_bytes().iter());
+            }
+            DataType::Utf8 => {
+                let array = col.as_any().downcast_ref::<StringArray>().unwrap();
+                let value = array.value(row);
+                // store the size
+                vec.extend(value.len().to_le_bytes().iter());
+                // store the string value
+                vec.extend(array.value(row).as_bytes().iter());
+            }
+            _ => {
+                // This is internal because we should have caught this before.
+                return Err(DataFusionError::Internal(
+                    "Unsupported GROUP BY data type".to_string(),
+                ));
+            }
+        }
+    }
+    Ok(())
+}
+
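As a sanity check (not part of this diff; relies on this file's imports), here is what `create_key` would produce for a hypothetical two-column key of types `UInt32` and `Utf8`. The length prefix on strings matters: without it, keys like `("ab", "c")` and `("a", "bc")` would serialize to the same bytes. The length is a `usize`, i.e. 8 bytes on a 64-bit target.

    fn demo() -> Result<()> {
        let a: ArrayRef = Arc::new(UInt32Array::from(vec![7u32]));
        let b: ArrayRef = Arc::new(StringArray::from(vec!["ab"]));
        let mut key = Vec::new();
        create_key(&[a, b], 0, &mut key)?;
        assert_eq!(&key[0..4], &7u32.to_le_bytes()[..]); // column 1, little-endian
        assert_eq!(&key[4..12], &2usize.to_le_bytes()[..]); // string length prefix
        assert_eq!(&key[12..], &b"ab"[..]); // string bytes
        Ok(())
    }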
 fn build_batch(
     batch: &RecordBatch,
     left_data: &JoinLeftData,
@@ -370,9 +430,8 @@ fn build_join_indexes(
         JoinType::Inner => {
             // inner => key intersection
             // unfortunately rust does not support intersection of map keys :(
-            let left_set: HashSet<Box<[GroupByScalar]>> = left.keys().cloned().collect();
-            let left_right: HashSet<Box<[GroupByScalar]>> =
-                right.keys().cloned().collect();
+            let left_set: HashSet<Vec<u8>> = left.keys().cloned().collect();
+            let left_right: HashSet<Vec<u8>> = right.keys().cloned().collect();
             let inner = left_set.intersection(&left_right);
 
             let mut indexes = Vec::new(); // unknown a prior size
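A small sketch (not part of this diff) of the intersection step above: the map keys are copied into `HashSet`s first, since `HashMap` offers no direct key intersection; with `Vec<u8>` keys this is a plain byte-buffer clone per key.

    use hashbrown::HashMap;
    use std::collections::HashSet;

    fn common_keys<V>(
        left: &HashMap<Vec<u8>, V, ahash::RandomState>,
        right: &HashMap<Vec<u8>, V, ahash::RandomState>,
    ) -> Vec<Vec<u8>> {
        let l: HashSet<Vec<u8>> = left.keys().cloned().collect();
        let r: HashSet<Vec<u8>> = right.keys().cloned().collect();
        l.intersection(&r).cloned().collect()
    }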