diff --git a/rust/datafusion/benches/aggregate_query_sql.rs b/rust/datafusion/benches/aggregate_query_sql.rs
index 54741043fd2..1f9f56044b6 100644
--- a/rust/datafusion/benches/aggregate_query_sql.rs
+++ b/rust/datafusion/benches/aggregate_query_sql.rs
@@ -19,8 +19,7 @@ extern crate criterion;
 use criterion::Criterion;
 
-use rand::seq::SliceRandom;
-use rand::Rng;
+use rand::{rngs::StdRng, seq::SliceRandom, Rng, SeedableRng};
 use std::sync::{Arc, Mutex};
 use tokio::runtime::Runtime;
 
@@ -40,6 +39,10 @@ use datafusion::datasource::MemTable;
 use datafusion::error::Result;
 use datafusion::execution::context::ExecutionContext;
 
+pub fn seedable_rng() -> StdRng {
+    StdRng::seed_from_u64(42)
+}
+
 fn query(ctx: Arc<Mutex<ExecutionContext>>, sql: &str) {
     let mut rt = Runtime::new().unwrap();
 
@@ -50,7 +53,7 @@ fn query(ctx: Arc<Mutex<ExecutionContext>>, sql: &str) {
 
 fn create_data(size: usize, null_density: f64) -> Vec<Option<f64>> {
     // use random numbers to avoid spurious compiler optimizations wrt to branching
-    let mut rng = rand::thread_rng();
+    let mut rng = seedable_rng();
 
     (0..size)
         .map(|_| {
@@ -65,7 +68,7 @@ fn create_data(size: usize, null_density: f64) -> Vec<Option<f64>> {
 
 fn create_integer_data(size: usize, value_density: f64) -> Vec<Option<u64>> {
     // use random numbers to avoid spurious compiler optimizations wrt to branching
-    let mut rng = rand::thread_rng();
+    let mut rng = seedable_rng();
 
     (0..size)
         .map(|_| {
@@ -98,6 +101,8 @@ fn create_context(
         Field::new("u64_narrow", DataType::UInt64, false),
     ]));
 
+    let mut rng = seedable_rng();
+
     // define data.
     let partitions = (0..partitions_len)
         .map(|_| {
@@ -109,7 +114,7 @@ fn create_context(
             let keys: Vec<String> = (0..batch_size)
                 .map(
                     // use random numbers to avoid spurious compiler optimizations wrt to branching
-                    |_| format!("hi{:?}", vs.choose(&mut rand::thread_rng())),
+                    |_| format!("hi{:?}", vs.choose(&mut rng)),
                 )
                 .collect();
             let keys: Vec<&str> = keys.iter().map(|e| &**e).collect();
@@ -122,11 +127,7 @@ fn create_context(
             // Integer values between [0, 9].
             let integer_values_narrow_choices = (0..10).collect::<Vec<u64>>();
             let integer_values_narrow = (0..batch_size)
-                .map(|_| {
-                    *integer_values_narrow_choices
-                        .choose(&mut rand::thread_rng())
-                        .unwrap()
-                })
+                .map(|_| *integer_values_narrow_choices.choose(&mut rng).unwrap())
                 .collect::<Vec<u64>>();
 
             RecordBatch::try_new(
@@ -216,6 +217,27 @@ fn criterion_benchmark(c: &mut Criterion) {
             )
         })
     });
+
+    c.bench_function("aggregate_query_group_by_u64 15 12", |b| {
+        b.iter(|| {
+            query(
+                ctx.clone(),
+                "SELECT u64_narrow, MIN(f64), AVG(f64), COUNT(f64) \
+                 FROM t GROUP BY u64_narrow",
+            )
+        })
+    });
+
+    c.bench_function("aggregate_query_group_by_with_filter_u64 15 12", |b| {
+        b.iter(|| {
+            query(
+                ctx.clone(),
+                "SELECT u64_narrow, MIN(f64), AVG(f64), COUNT(f64) \
+                 FROM t \
+                 WHERE f32 > 10 AND f32 < 20 GROUP BY u64_narrow",
+            )
+        })
+    });
 }
 
 criterion_group!(benches, criterion_benchmark);
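The bench changes above replace per-call `rand::thread_rng()` with a single fixed-seed `StdRng`, so every run of the suite aggregates byte-identical input data and timings are comparable across runs and machines. A minimal standalone sketch of that guarantee (not part of the patch; it assumes the `rand` crate's `SeedableRng` API, and note that `StdRng`'s underlying algorithm is only stable within a given `rand` major version):

```rust
use rand::{rngs::StdRng, Rng, SeedableRng};

fn main() {
    // Two generators built from the same seed emit the same sequence,
    // so benchmark inputs are reproducible run to run.
    let mut a = StdRng::seed_from_u64(42);
    let mut b = StdRng::seed_from_u64(42);
    let xs: Vec<u32> = (0..4).map(|_| a.gen()).collect();
    let ys: Vec<u32> = (0..4).map(|_| b.gen()).collect();
    assert_eq!(xs, ys);
}
```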
diff --git a/rust/datafusion/src/physical_plan/group_scalar.rs b/rust/datafusion/src/physical_plan/group_scalar.rs
index bb1e204c7f5..8c11a6be65a 100644
--- a/rust/datafusion/src/physical_plan/group_scalar.rs
+++ b/rust/datafusion/src/physical_plan/group_scalar.rs
@@ -34,7 +34,7 @@ pub(crate) enum GroupByScalar {
     Int16(i16),
     Int32(i32),
     Int64(i64),
-    Utf8(String),
+    Utf8(Box<String>),
 }
 
 impl TryFrom<&ScalarValue> for GroupByScalar {
@@ -50,7 +50,7 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
             ScalarValue::UInt16(Some(v)) => GroupByScalar::UInt16(*v),
             ScalarValue::UInt32(Some(v)) => GroupByScalar::UInt32(*v),
             ScalarValue::UInt64(Some(v)) => GroupByScalar::UInt64(*v),
-            ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(v.clone()),
+            ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
             ScalarValue::Int8(None)
             | ScalarValue::Int16(None)
             | ScalarValue::Int32(None)
@@ -86,7 +86,7 @@ impl From<&GroupByScalar> for ScalarValue {
             GroupByScalar::UInt16(v) => ScalarValue::UInt16(Some(*v)),
             GroupByScalar::UInt32(v) => ScalarValue::UInt32(Some(*v)),
             GroupByScalar::UInt64(v) => ScalarValue::UInt64(Some(*v)),
-            GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.clone())),
+            GroupByScalar::Utf8(v) => ScalarValue::Utf8(Some(v.to_string())),
         }
     }
 }
@@ -131,4 +131,9 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn size_of_group_by_scalar() {
+        assert_eq!(std::mem::size_of::<GroupByScalar>(), 16);
+    }
 }
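Boxing the `Utf8` payload is what the new `size_of_group_by_scalar` test pins down: `String` is three words (pointer, length, capacity), so the widest variant pushed the whole enum to 32 bytes on a 64-bit target, while `Box<String>` is a single word and brings it down to 16. (A `Box<str>` would be a two-word fat pointer, giving 24 bytes, which is presumably why `Box<String>` was preferred despite the extra indirection.) A standalone sketch of the layout effect, with hypothetical enum names that only mimic `GroupByScalar`:

```rust
// Illustrative only: `Inline` and `Boxed` stand in for the old and new layouts.
enum Inline {
    UInt64(u64),
    Utf8(String), // 3 words: ptr + len + cap
}

enum Boxed {
    UInt64(u64),
    Utf8(Box<String>), // 1 word
}

fn main() {
    // On a 64-bit target the discriminant rounds each enum up to 8-byte alignment.
    assert_eq!(std::mem::size_of::<Inline>(), 32);
    assert_eq!(std::mem::size_of::<Boxed>(), 16);
}
```

Halving the key element size matters because group-by keys are hashed, compared, and cloned in the hot loop of `group_aggregate_batch` below.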
diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs
index 6be23263c04..68daad92a87 100644
--- a/rust/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs
@@ -250,6 +250,8 @@ fn group_aggregate_batch(
         key.push(GroupByScalar::UInt32(0));
     }
 
+    let mut key = key.into_boxed_slice();
+
     // 1.1 construct the key from the group values
     // 1.2 construct the mapping key if it does not exist
     // 1.3 add the row' index to `indices`
@@ -270,7 +272,7 @@ fn group_aggregate_batch(
             .or_insert_with(|| {
                 // We can safely unwrap here as we checked we can create an accumulator before
                 let accumulator_set = create_accumulators(aggr_expr).unwrap();
-                (key.clone(), (accumulator_set, Box::new(vec![row as u32])))
+                (key.clone(), (accumulator_set, vec![row as u32]))
             });
     }
 
@@ -296,7 +298,7 @@ fn group_aggregate_batch(
                     // 2.3
                     compute::take(
                         array,
-                        &UInt32Array::from(*indices.clone()),
+                        &UInt32Array::from(indices.clone()),
                         None, // None: no index check
                     )
                     .unwrap()
@@ -389,7 +391,7 @@ impl GroupedHashAggregateStream {
 
 type AccumulatorSet = Vec<Box<dyn Accumulator>>;
 type Accumulators =
-    HashMap<Vec<GroupByScalar>, (AccumulatorSet, Box<Vec<u32>>), RandomState>;
+    HashMap<Box<[GroupByScalar]>, (AccumulatorSet, Vec<u32>), RandomState>;
 
 impl Stream for GroupedHashAggregateStream {
    type Item = ArrowResult<RecordBatch>;
@@ -658,7 +660,9 @@ fn create_batch_from_map(
                 GroupByScalar::UInt16(n) => Arc::new(UInt16Array::from(vec![*n])),
                 GroupByScalar::UInt32(n) => Arc::new(UInt32Array::from(vec![*n])),
                 GroupByScalar::UInt64(n) => Arc::new(UInt64Array::from(vec![*n])),
-                GroupByScalar::Utf8(str) => Arc::new(StringArray::from(vec![&**str])),
+                GroupByScalar::Utf8(str) => {
+                    Arc::new(StringArray::from(vec![&***str]))
+                }
             })
             .collect::<Vec<ArrayRef>>();
 
@@ -726,7 +730,7 @@ fn finalize_aggregation(
 pub(crate) fn create_key(
     group_by_keys: &[ArrayRef],
     row: usize,
-    vec: &mut Vec<GroupByScalar>,
+    vec: &mut Box<[GroupByScalar]>,
 ) -> Result<()> {
     for i in 0..group_by_keys.len() {
         let col = &group_by_keys[i];
@@ -765,7 +769,7 @@ pub(crate) fn create_key(
             }
             DataType::Utf8 => {
                 let array = col.as_any().downcast_ref::<StringArray>().unwrap();
-                vec[i] = GroupByScalar::Utf8(String::from(array.value(row)))
+                vec[i] = GroupByScalar::Utf8(Box::new(array.value(row).into()))
             }
             _ => {
                 // This is internal because we should have caught this before.
diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs
index d2bb8cf7c41..971968a4efd 100644
--- a/rust/datafusion/src/physical_plan/hash_join.rs
+++ b/rust/datafusion/src/physical_plan/hash_join.rs
@@ -52,7 +52,7 @@ type JoinIndex = Option<(usize, usize)>;
 // Maps ["on" value] -> [list of indices with this key's value]
 // E.g. [1, 2] -> [(0, 3), (1, 6), (0, 8)] indicates that (column1, column2) = [1, 2] is true
 // for rows 3 and 8 from batch 0 and row 6 from batch 1.
-type JoinHashMap = HashMap<Vec<GroupByScalar>, Vec<Index>, RandomState>;
+type JoinHashMap = HashMap<Box<[GroupByScalar]>, Vec<Index>, RandomState>;
 type JoinLeftData = (JoinHashMap, Vec<RecordBatch>);
 
@@ -209,6 +209,8 @@ fn update_hash(
         key.push(GroupByScalar::UInt32(0));
     }
 
+    let mut key = key.into_boxed_slice();
+
     // update the hash map
     for row in 0..batch.num_rows() {
         create_key(&keys_values, row, &mut key)?;
@@ -368,8 +370,9 @@ fn build_join_indexes(
         JoinType::Inner => {
             // inner => key intersection
             // unfortunately rust does not support intersection of map keys :(
-            let left_set: HashSet<Vec<GroupByScalar>> = left.keys().cloned().collect();
-            let left_right: HashSet<Vec<GroupByScalar>> = right.keys().cloned().collect();
+            let left_set: HashSet<Box<[GroupByScalar]>> = left.keys().cloned().collect();
+            let left_right: HashSet<Box<[GroupByScalar]>> =
+                right.keys().cloned().collect();
             let inner = left_set.intersection(&left_right);
             let mut indexes = Vec::new(); // unknown a prior size
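The common thread in the `hash_aggregate.rs` and `hash_join.rs` hunks is the key type: `Box<[GroupByScalar]>` is a two-word fat pointer (pointer + length) instead of `Vec`'s three words, and it encodes that a key's length never changes once built; the patch also keeps one mutable key buffer per batch, overwriting it in place via `create_key` and cloning only when a new entry is inserted, and it stores row indices as a plain `Vec<u32>` rather than `Box<Vec<u32>>`, dropping a needless indirection. A minimal sketch of the boxed-slice-as-key pattern, not DataFusion code, with `u64` standing in for `GroupByScalar`:

```rust
use std::collections::HashMap;

fn main() {
    // On a 64-bit target: a boxed slice is (ptr, len), two words
    // versus Vec's three (ptr, len, cap).
    assert_eq!(std::mem::size_of::<Vec<u64>>(), 24);
    assert_eq!(std::mem::size_of::<Box<[u64]>>(), 16);

    // into_boxed_slice reuses the allocation (shedding excess capacity),
    // and Box<[T]> implements Hash + Eq, so it works as a map key.
    let key: Box<[u64]> = vec![1, 2, 3].into_boxed_slice();
    let mut groups: HashMap<Box<[u64]>, Vec<u32>> = HashMap::new();
    groups.entry(key).or_insert_with(Vec::new).push(0);

    // Lookups can borrow the key as a plain slice via Borrow<[u64]>.
    assert!(groups.contains_key(&[1u64, 2, 3][..]));
}
```

The triple deref in `vec![&***str]` follows from the nesting: `str` is `&Box<String>`, so `***str` is the `str` slice and `&***str` the `&str` that `StringArray::from` expects.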