diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index 25630a9ec8e..7ca769a5303 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -23,11 +23,12 @@ use ahash::RandomState; use arrow::{ array::{ - ArrayRef, BooleanArray, LargeStringArray, TimestampMicrosecondArray, - TimestampNanosecondArray, UInt32Builder, UInt64Builder, + ArrayData, ArrayRef, BooleanArray, LargeStringArray, PrimitiveArray, + TimestampMicrosecondArray, TimestampNanosecondArray, UInt32BufferBuilder, + UInt32Builder, UInt64BufferBuilder, UInt64Builder, }, compute, - datatypes::TimeUnit, + datatypes::{TimeUnit, UInt32Type, UInt64Type}, }; use std::time::Instant; use std::{any::Any, collections::HashSet}; @@ -237,19 +238,26 @@ impl ExecutionPlan for HashJoinExec { // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream // 2. stores the batches in a vector. - let initial = - (JoinHashMap::with_hasher(IdHashBuilder {}), Vec::new(), 0); - let (hashmap, batches, num_rows) = stream + let initial = ( + JoinHashMap::with_hasher(IdHashBuilder {}), + Vec::new(), + 0, + Vec::new(), + ); + let (hashmap, batches, num_rows, _) = stream .try_fold(initial, |mut acc, batch| async { let hash = &mut acc.0; let values = &mut acc.1; let offset = acc.2; + acc.3.clear(); + acc.3.resize(batch.num_rows(), 0); update_hash( &on_left, &batch, hash, offset, &self.random_state, + &mut acc.3, ) .unwrap(); acc.2 += batch.num_rows(); @@ -311,6 +319,7 @@ fn update_hash( hash: &mut JoinHashMap, offset: usize, random_state: &RandomState, + hashes_buffer: &mut Vec, ) -> Result<()> { // evaluate the keys let keys_values = on @@ -319,7 +328,7 @@ fn update_hash( .collect::>>()?; // update the hash map - let hash_values = create_hashes(&keys_values, &random_state)?; + let hash_values = create_hashes(&keys_values, &random_state, hashes_buffer)?; // insert hashes to key of the hashmap for (row, hash_value) in hash_values.iter().enumerate() { @@ -476,15 +485,16 @@ fn build_join_indexes( .into_array(left_data.1.num_rows())) }) .collect::>>()?; - - let hash_values = create_hashes(&keys_values, &random_state)?; + let hashes_buffer = &mut vec![0; keys_values[0].len()]; + let hash_values = create_hashes(&keys_values, &random_state, hashes_buffer)?; let left = &left_data.0; - let mut left_indices = UInt64Builder::new(0); - let mut right_indices = UInt32Builder::new(0); - match join_type { JoinType::Inner => { + // Using a buffer builder to avoid slower normal builder + let mut left_indices = UInt64BufferBuilder::new(0); + let mut right_indices = UInt32BufferBuilder::new(0); + // Visit all of the right rows for (row, hash_value) in hash_values.iter().enumerate() { // Get the hash and find it in the build index @@ -496,15 +506,30 @@ fn build_join_indexes( for &i in indices { // Check hash collisions if equal_rows(i as usize, row, &left_join_values, &keys_values)? { - left_indices.append_value(i)?; - right_indices.append_value(row as u32)?; + left_indices.append(i); + right_indices.append(row as u32); } } } } - Ok((left_indices.finish(), right_indices.finish())) + let left = ArrayData::builder(DataType::UInt64) + .len(left_indices.len()) + .add_buffer(left_indices.finish()) + .build(); + let right = ArrayData::builder(DataType::UInt32) + .len(right_indices.len()) + .add_buffer(right_indices.finish()) + .build(); + + Ok(( + PrimitiveArray::::from(left), + PrimitiveArray::::from(right), + )) } JoinType::Left => { + let mut left_indices = UInt64Builder::new(0); + let mut right_indices = UInt32Builder::new(0); + // Keep track of which item is visited in the build input // TODO: this can be stored more efficiently with a marker // https://issues.apache.org/jira/browse/ARROW-11116 @@ -534,10 +559,12 @@ fn build_join_indexes( } } } - Ok((left_indices.finish(), right_indices.finish())) } JoinType::Right => { + let mut left_indices = UInt64Builder::new(0); + let mut right_indices = UInt32Builder::new(0); + for (row, hash_value) in hash_values.iter().enumerate() { match left.get(hash_value) { Some(indices) => { @@ -699,50 +726,60 @@ macro_rules! hash_array { } /// Creates hash values for every element in the row based on the values in the columns -pub fn create_hashes( +pub fn create_hashes<'a>( arrays: &[ArrayRef], random_state: &RandomState, -) -> Result> { - let rows = arrays[0].len(); - let mut hashes = vec![0; rows]; - + hashes_buffer: &'a mut Vec, +) -> Result<&'a mut Vec> { for col in arrays { match col.data_type() { DataType::UInt8 => { - hash_array!(UInt8Array, col, u8, hashes, random_state); + hash_array!(UInt8Array, col, u8, hashes_buffer, random_state); } DataType::UInt16 => { - hash_array!(UInt16Array, col, u16, hashes, random_state); + hash_array!(UInt16Array, col, u16, hashes_buffer, random_state); } DataType::UInt32 => { - hash_array!(UInt32Array, col, u32, hashes, random_state); + hash_array!(UInt32Array, col, u32, hashes_buffer, random_state); } DataType::UInt64 => { - hash_array!(UInt64Array, col, u64, hashes, random_state); + hash_array!(UInt64Array, col, u64, hashes_buffer, random_state); } DataType::Int8 => { - hash_array!(Int8Array, col, i8, hashes, random_state); + hash_array!(Int8Array, col, i8, hashes_buffer, random_state); } DataType::Int16 => { - hash_array!(Int16Array, col, i16, hashes, random_state); + hash_array!(Int16Array, col, i16, hashes_buffer, random_state); } DataType::Int32 => { - hash_array!(Int32Array, col, i32, hashes, random_state); + hash_array!(Int32Array, col, i32, hashes_buffer, random_state); } DataType::Int64 => { - hash_array!(Int64Array, col, i64, hashes, random_state); + hash_array!(Int64Array, col, i64, hashes_buffer, random_state); } DataType::Timestamp(TimeUnit::Microsecond, None) => { - hash_array!(TimestampMicrosecondArray, col, i64, hashes, random_state); + hash_array!( + TimestampMicrosecondArray, + col, + i64, + hashes_buffer, + random_state + ); } DataType::Timestamp(TimeUnit::Nanosecond, None) => { - hash_array!(TimestampNanosecondArray, col, i64, hashes, random_state); + hash_array!( + TimestampNanosecondArray, + col, + i64, + hashes_buffer, + random_state + ); } DataType::Boolean => { - hash_array!(BooleanArray, col, u8, hashes, random_state); + hash_array!(BooleanArray, col, u8, hashes_buffer, random_state); } DataType::Utf8 => { - hash_array!(StringArray, col, str, hashes, random_state); + hash_array!(StringArray, col, str, hashes_buffer, random_state); } _ => { // This is internal because we should have caught this before. @@ -752,7 +789,7 @@ pub fn create_hashes( } } } - Ok(hashes) + Ok(hashes_buffer) } impl Stream for HashJoinStream { @@ -1136,8 +1173,9 @@ mod tests { ); let random_state = RandomState::new(); - - let hashes = create_hashes(&[left.columns()[0].clone()], &random_state)?; + let hashes_buff = &mut vec![0; left.num_rows()]; + let hashes = + create_hashes(&[left.columns()[0].clone()], &random_state, hashes_buff)?; // Create hash collisions hashmap_left.insert(hashes[0], vec![0, 1]); diff --git a/rust/datafusion/src/physical_plan/repartition.rs b/rust/datafusion/src/physical_plan/repartition.rs index 94c3aab64e1..16426f246f9 100644 --- a/rust/datafusion/src/physical_plan/repartition.rs +++ b/rust/datafusion/src/physical_plan/repartition.rs @@ -151,7 +151,9 @@ impl ExecutionPlan for RepartitionExec { }) .collect::>>()?; // Hash arrays and compute buckets based on number of partitions - let hashes = create_hashes(&arrays, &random_state)?; + let hashes_buf = &mut vec![0; arrays[0].len()]; + let hashes = + create_hashes(&arrays, &random_state, hashes_buf)?; let mut indices = vec![vec![]; num_output_partitions]; for (index, hash) in hashes.iter().enumerate() { indices