diff --git a/datafusion/physical-plan/src/joins/hash_join.rs b/datafusion/physical-plan/src/joins/hash_join.rs index 1dfc543ebb69..9c59265b3b4c 100644 --- a/datafusion/physical-plan/src/joins/hash_join.rs +++ b/datafusion/physical-plan/src/joins/hash_join.rs @@ -77,8 +77,26 @@ use arrow_buffer::BooleanBuffer; use datafusion_expr::Operator; use datafusion_physical_expr_common::datum::compare_op_for_nested; use futures::{ready, Stream, StreamExt, TryStreamExt}; +use log::debug; use parking_lot::Mutex; +pub const RANDOM_STATE: RandomState = RandomState::with_seeds(0, 0, 0, 0); + +#[derive(Default)] +pub struct JoinContext { + build_state: Mutex>>, +} + +impl JoinContext { + pub fn set_build_state(&self, state: Arc) { + self.build_state.lock().replace(state); + } + + pub fn get_build_state(&self) -> Option> { + self.build_state.lock().clone() + } +} + pub struct SharedJoinState { state_impl: Arc, } @@ -128,7 +146,7 @@ pub trait SharedJoinStateImpl: Send + Sync + 'static { type SharedBitmapBuilder = Mutex; /// HashTable and input data for the left (build side) of a join -struct JoinLeftData { +pub struct JoinLeftData { /// The hash table with indices into `batch` hash_map: JoinHashMap, /// The input rows for the build side @@ -165,6 +183,10 @@ impl JoinLeftData { } } + pub fn contains_hash(&self, hash: u64) -> bool { + self.hash_map.contains_hash(hash) + } + /// return a reference to the hash map fn hash_map(&self) -> &JoinHashMap { &self.hash_map @@ -768,6 +790,7 @@ impl ExecutionPlan for HashJoinExec { let distributed_state = context.session_config().get_extension::(); + let join_context = context.session_config().get_extension::(); let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics); let left_fut = match self.mode { @@ -855,6 +878,7 @@ impl ExecutionPlan for HashJoinExec { batch_size, hashes_buffer: vec![], right_side_ordered: self.right.output_ordering().is_some(), + join_context, })) } @@ -1187,6 +1211,7 @@ struct HashJoinStream { hashes_buffer: Vec, /// Specifies whether the right side has an ordering to potentially preserve right_side_ordered: bool, + join_context: Option>, } impl RecordBatchStream for HashJoinStream { @@ -1399,6 +1424,11 @@ impl HashJoinStream { .get_shared(cx))?; build_timer.done(); + if let Some(ctx) = self.join_context.as_ref() { + debug!("setting join left data in join context"); + ctx.set_build_state(Arc::clone(&left_data)); + } + self.state = HashJoinStreamState::FetchProbeBatch; self.build_side = BuildSide::Ready(BuildSideReadyState { left_data }); diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs index 221f664f0e34..82cb5100d3ae 100644 --- a/datafusion/physical-plan/src/joins/mod.rs +++ b/datafusion/physical-plan/src/joins/mod.rs @@ -19,7 +19,8 @@ pub use cross_join::CrossJoinExec; pub use hash_join::{ - HashJoinExec, SharedJoinState, SharedJoinStateImpl, SharedProbeState, + HashJoinExec, JoinContext, JoinLeftData, SharedJoinState, SharedJoinStateImpl, + SharedProbeState, RANDOM_STATE, }; pub use nested_loop_join::NestedLoopJoinExec; // Note: SortMergeJoin is not used in plans yet @@ -33,6 +34,8 @@ mod stream_join_utils; mod symmetric_hash_join; pub mod utils; +pub type RandomState = ahash::RandomState; + #[cfg(test)] pub mod test_utils; diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 17a32a67c743..a73154d64f5f 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -139,6 +139,10 @@ impl JoinHashMap { next: vec![0; capacity], } } + + pub fn contains_hash(&self, hash: u64) -> bool { + self.map.find(hash, |(h, _)| *h == hash).is_some() + } } // Type of offsets for obtaining indices from JoinHashMap.