diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 99708249fc6a7..77d86d4b27bd9 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -21,12 +21,8 @@ use ahash::RandomState; use arrow::{ - array::{ - ArrayData, ArrayRef, BooleanArray, LargeStringArray, PrimitiveArray, - UInt32BufferBuilder, UInt32Builder, UInt64BufferBuilder, UInt64Builder, - }, + array::{ArrayRef, BooleanArray, LargeStringArray, UInt32Builder, UInt64Builder}, compute, - datatypes::{UInt32Type, UInt64Type}, }; use smallvec::{smallvec, SmallVec}; use std::sync::Arc; @@ -674,8 +670,8 @@ fn build_join_indexes( match join_type { JoinType::Inner | JoinType::Semi | JoinType::Anti => { // Using a buffer builder to avoid slower normal builder - let mut left_indices = UInt64BufferBuilder::new(0); - let mut right_indices = UInt32BufferBuilder::new(0); + let mut left_indices = vec![]; + let mut right_indices = vec![]; // Visit all of the right rows for (row, hash_value) in hash_values.iter().enumerate() { @@ -687,48 +683,31 @@ fn build_join_indexes( if let Some((_, indices)) = left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) { - for &i in indices { - // Check hash collisions - if equal_rows(i as usize, row, &left_join_values, &keys_values)? { - left_indices.append(i); - right_indices.append(row as u32); - } - } + left_indices.extend(indices); + + right_indices + .extend(std::iter::repeat(row as u32).take(indices.len())); } } - let left = ArrayData::builder(DataType::UInt64) - .len(left_indices.len()) - .add_buffer(left_indices.finish()) - .build(); - let right = ArrayData::builder(DataType::UInt32) - .len(right_indices.len()) - .add_buffer(right_indices.finish()) - .build(); - - Ok(( - PrimitiveArray::::from(left), - PrimitiveArray::::from(right), - )) + + equal_array_rows(left_indices, right_indices, &left_join_values, &keys_values) } JoinType::Left => { - let mut left_indices = UInt64Builder::new(0); - let mut right_indices = UInt32Builder::new(0); + let mut left_indices = vec![]; + let mut right_indices = vec![]; // First visit all of the rows for (row, hash_value) in hash_values.iter().enumerate() { if let Some((_, indices)) = left.0.get(*hash_value, |(hash, _)| *hash_value == *hash) { - for &i in indices { - // Collision check - if equal_rows(i as usize, row, &left_join_values, &keys_values)? { - left_indices.append_value(i)?; - right_indices.append_value(row as u32)?; - } - } + left_indices.extend(indices); + + right_indices + .extend(std::iter::repeat(row as u32).take(indices.len())); }; } - Ok((left_indices.finish(), right_indices.finish())) + equal_array_rows(left_indices, right_indices, &left_join_values, &keys_values) } JoinType::Right | JoinType::Full => { let mut left_indices = UInt64Builder::new(0); @@ -820,6 +799,150 @@ fn equal_rows( err.unwrap_or(Ok(res)) } +// Check whether left and right values are equal +// Returns the indices that matched on the left/right side +fn equal_array_rows( + left: Vec, + right: Vec, + left_arrays: &[ArrayRef], + right_arrays: &[ArrayRef], +) -> Result<(UInt64Array, UInt32Array)> { + // TODO optimize single left/right array + // TODO optimize non-null arrays + let mut res = left_arrays + .iter() + .zip(right_arrays) + .map(|(left_array, right_array)| match left_array.data_type() { + DataType::UInt32 => { + let left_array = + left_array.as_any().downcast_ref::().unwrap(); + let right_array = + right_array.as_any().downcast_ref::().unwrap(); + left.iter() + .zip(right.iter()) + .map(|(l, r)| { + Some( + !left_array.is_null(*l as usize) + && !right_array.is_null(*r as usize) + && left_array.value(*l as usize) + == right_array.value(*r as usize), + ) + }) + .collect::() + } + DataType::UInt64 => { + let left_array = + left_array.as_any().downcast_ref::().unwrap(); + let right_array = + right_array.as_any().downcast_ref::().unwrap(); + left.iter() + .zip(right.iter()) + .map(|(l, r)| { + Some( + !left_array.is_null(*l as usize) + && !right_array.is_null(*r as usize) + && left_array.value(*l as usize) + == right_array.value(*r as usize), + ) + }) + .collect::() + } + DataType::Int64 => { + let left_array = + left_array.as_any().downcast_ref::().unwrap(); + let right_array = + right_array.as_any().downcast_ref::().unwrap(); + left.iter() + .zip(right.iter()) + .map(|(l, r)| { + Some( + !left_array.is_null(*l as usize) + && !right_array.is_null(*r as usize) + && left_array.value(*l as usize) + == right_array.value(*r as usize), + ) + }) + .collect::() + } + + DataType::Int32 => { + let left_array = + left_array.as_any().downcast_ref::().unwrap(); + let right_array = + right_array.as_any().downcast_ref::().unwrap(); + left.iter() + .zip(right.iter()) + .map(|(l, r)| { + Some( + !left_array.is_null(*l as usize) + && !right_array.is_null(*r as usize) + && left_array.value(*l as usize) + == right_array.value(*r as usize), + ) + }) + .collect::() + } + DataType::Utf8 => { + let left_array = + left_array.as_any().downcast_ref::().unwrap(); + let right_array = + right_array.as_any().downcast_ref::().unwrap(); + left.iter() + .zip(right.iter()) + .map(|(l, r)| { + Some( + !left_array.is_null(*l as usize) + && !right_array.is_null(*r as usize) + && left_array.value(*l as usize) + == right_array.value(*r as usize), + ) + }) + .collect::() + } + DataType::LargeUtf8 => { + let left_array = left_array + .as_any() + .downcast_ref::() + .unwrap(); + let right_array = right_array + .as_any() + .downcast_ref::() + .unwrap(); + left.iter() + .zip(right.iter()) + .map(|(l, r)| { + Some( + !left_array.is_null(*l as usize) + && !right_array.is_null(*r as usize) + && left_array.value(*l as usize) + == right_array.value(*r as usize), + ) + }) + .collect::() + } + + _ => { + panic!("") + } + }) + .collect::>(); + // combine masks + let mut mask = res.drain(0..1).next().unwrap(); + for m in res { + mask = compute::and(&mask, &m)?; + } + + let (lefts, rights): (Vec, Vec) = left + .iter() + .zip(right.iter()) + .zip(mask.iter()) + .filter(|(_, s)| *s == Some(true)) + .map(|((l, r), _)| (*l, *r)) + .unzip(); + + Ok((UInt64Array::from(lefts), UInt32Array::from(rights))) +} + // Produces a batch for left-side rows that have/have not been matched during the whole join fn produce_from_matched( visited_left_side: &[bool],