Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,26 @@ use arrow_buffer::BooleanBuffer;
use datafusion_expr::Operator;
use datafusion_physical_expr_common::datum::compare_op_for_nested;
use futures::{ready, Stream, StreamExt, TryStreamExt};
use log::debug;
use parking_lot::Mutex;

pub const RANDOM_STATE: RandomState = RandomState::with_seeds(0, 0, 0, 0);

#[derive(Default)]
pub struct JoinContext {
build_state: Mutex<Option<Arc<JoinLeftData>>>,
}

impl JoinContext {
pub fn set_build_state(&self, state: Arc<JoinLeftData>) {
self.build_state.lock().replace(state);
}

pub fn get_build_state(&self) -> Option<Arc<JoinLeftData>> {
self.build_state.lock().clone()
}
}

pub struct SharedJoinState {
state_impl: Arc<dyn SharedJoinStateImpl>,
}
Expand Down Expand Up @@ -128,7 +146,7 @@ pub trait SharedJoinStateImpl: Send + Sync + 'static {
type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;

/// HashTable and input data for the left (build side) of a join
struct JoinLeftData {
pub struct JoinLeftData {
/// The hash table with indices into `batch`
hash_map: JoinHashMap,
/// The input rows for the build side
Expand Down Expand Up @@ -165,6 +183,10 @@ impl JoinLeftData {
}
}

pub fn contains_hash(&self, hash: u64) -> bool {
self.hash_map.contains_hash(hash)
}

/// return a reference to the hash map
fn hash_map(&self) -> &JoinHashMap {
&self.hash_map
Expand Down Expand Up @@ -768,6 +790,7 @@ impl ExecutionPlan for HashJoinExec {

let distributed_state =
context.session_config().get_extension::<SharedJoinState>();
let join_context = context.session_config().get_extension::<JoinContext>();

let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
let left_fut = match self.mode {
Expand Down Expand Up @@ -855,6 +878,7 @@ impl ExecutionPlan for HashJoinExec {
batch_size,
hashes_buffer: vec![],
right_side_ordered: self.right.output_ordering().is_some(),
join_context,
}))
}

Expand Down Expand Up @@ -1187,6 +1211,7 @@ struct HashJoinStream {
hashes_buffer: Vec<u64>,
/// Specifies whether the right side has an ordering to potentially preserve
right_side_ordered: bool,
join_context: Option<Arc<JoinContext>>,
}

impl RecordBatchStream for HashJoinStream {
Expand Down Expand Up @@ -1399,6 +1424,11 @@ impl HashJoinStream {
.get_shared(cx))?;
build_timer.done();

if let Some(ctx) = self.join_context.as_ref() {
debug!("setting join left data in join context");
ctx.set_build_state(Arc::clone(&left_data));
}

self.state = HashJoinStreamState::FetchProbeBatch;
self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });

Expand Down
5 changes: 4 additions & 1 deletion datafusion/physical-plan/src/joins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

pub use cross_join::CrossJoinExec;
pub use hash_join::{
HashJoinExec, SharedJoinState, SharedJoinStateImpl, SharedProbeState,
HashJoinExec, JoinContext, JoinLeftData, SharedJoinState, SharedJoinStateImpl,
SharedProbeState, RANDOM_STATE,
};
pub use nested_loop_join::NestedLoopJoinExec;
// Note: SortMergeJoin is not used in plans yet
Expand All @@ -33,6 +34,8 @@ mod stream_join_utils;
mod symmetric_hash_join;
pub mod utils;

pub type RandomState = ahash::RandomState;

#[cfg(test)]
pub mod test_utils;

Expand Down
4 changes: 4 additions & 0 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ impl JoinHashMap {
next: vec![0; capacity],
}
}

pub fn contains_hash(&self, hash: u64) -> bool {
self.map.find(hash, |(h, _)| *h == hash).is_some()
}
}

// Type of offsets for obtaining indices from JoinHashMap.
Expand Down
Loading