Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions datafusion/core/src/physical_planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1256,15 +1256,7 @@ impl DefaultPhysicalPlanner {
Arc::new(CrossJoinExec::new(physical_left, physical_right))
} else if num_range_filters == 1
&& total_filters == 1
&& !matches!(
join_type,
JoinType::LeftSemi
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftMark
| JoinType::RightMark
)
&& !matches!(join_type, JoinType::LeftMark | JoinType::RightMark)
&& session_state
.config_options()
.optimizer
Expand Down Expand Up @@ -1366,7 +1358,6 @@ impl DefaultPhysicalPlanner {
(on_left, on_right),
op,
*join_type,
session_state.config().target_partitions(),
)?)
} else {
// there is no equal join condition, use the nested loop join
Expand Down
108 changes: 62 additions & 46 deletions datafusion/physical-optimizer/src/join_selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use datafusion_physical_plan::execution_plan::EmissionType;
use datafusion_physical_plan::joins::utils::ColumnIndex;
use datafusion_physical_plan::joins::{
CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode,
StreamJoinPartitionMode, SymmetricHashJoinExec,
PiecewiseMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec,
};
use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties};
use std::sync::Arc;
Expand Down Expand Up @@ -256,59 +256,75 @@ fn statistical_join_selection_subrule(
collect_threshold_byte_size: usize,
collect_threshold_num_rows: usize,
) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
let transformed =
if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() {
match hash_join.partition_mode() {
PartitionMode::Auto => try_collect_left(
hash_join,
false,
collect_threshold_byte_size,
collect_threshold_num_rows,
)?
let transformed = if let Some(hash_join) =
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why these lines show up in the diff; nothing here has changed except for the addition of the `PiecewiseMergeJoinExec` branch for swapping inputs.

plan.as_any().downcast_ref::<HashJoinExec>()
{
match hash_join.partition_mode() {
PartitionMode::Auto => try_collect_left(
hash_join,
false,
collect_threshold_byte_size,
collect_threshold_num_rows,
)?
.map_or_else(
|| partitioned_hash_join(hash_join).map(Some),
|v| Ok(Some(v)),
)?,
PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)?
.map_or_else(
|| partitioned_hash_join(hash_join).map(Some),
|v| Ok(Some(v)),
)?,
PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)?
.map_or_else(
|| partitioned_hash_join(hash_join).map(Some),
|v| Ok(Some(v)),
)?,
PartitionMode::Partitioned => {
let left = hash_join.left();
let right = hash_join.right();
if hash_join.join_type().supports_swap()
&& should_swap_join_order(&**left, &**right)?
{
hash_join
.swap_inputs(PartitionMode::Partitioned)
.map(Some)?
} else {
None
}
PartitionMode::Partitioned => {
let left = hash_join.left();
let right = hash_join.right();
if hash_join.join_type().supports_swap()
&& should_swap_join_order(&**left, &**right)?
{
hash_join
.swap_inputs(PartitionMode::Partitioned)
.map(Some)?
} else {
None
}
}
} else if let Some(cross_join) = plan.as_any().downcast_ref::<CrossJoinExec>() {
let left = cross_join.left();
let right = cross_join.right();
if should_swap_join_order(&**left, &**right)? {
cross_join.swap_inputs().map(Some)?
} else {
None
}
} else if let Some(nl_join) = plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
let left = nl_join.left();
let right = nl_join.right();
if nl_join.join_type().supports_swap()
&& should_swap_join_order(&**left, &**right)?
{
nl_join.swap_inputs().map(Some)?
} else {
None
}
}
} else if let Some(cross_join) = plan.as_any().downcast_ref::<CrossJoinExec>() {
let left = cross_join.left();
let right = cross_join.right();
if should_swap_join_order(&**left, &**right)? {
cross_join.swap_inputs().map(Some)?
} else {
None
};
}
} else if let Some(nl_join) = plan.as_any().downcast_ref::<NestedLoopJoinExec>() {
let left = nl_join.left();
let right = nl_join.right();
if nl_join.join_type().supports_swap()
&& should_swap_join_order(&**left, &**right)?
{
nl_join.swap_inputs().map(Some)?
} else {
None
}
} else if let Some(pwmj) = plan.as_any().downcast_ref::<PiecewiseMergeJoinExec>() {
let left = pwmj.buffered();
let right = pwmj.streamed();
if pwmj.join_type().supports_swap()
// Put ! here because should_swap_join_order returns true if left > right but
// PiecewiseMergeJoin wants the left side to be the larger one, so only swap if
// left < right
&& (!should_swap_join_order(&**left, &**right)?
|| matches!(pwmj.join_type(), JoinType::RightSemi | JoinType::RightAnti))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`LeftSemi` and `LeftAnti` joins are never swapped; only the right existence joins (`RightSemi` and `RightAnti`) are. This is explained in the comment above `swap_inputs` in `PiecewiseMergeJoinExec`.

&& !matches!(pwmj.join_type(), JoinType::LeftSemi | JoinType::LeftAnti)
{
pwmj.swap_inputs().map(Some)?
} else {
None
}
} else {
None
};

Ok(if let Some(transformed) = transformed {
Transformed::yes(transformed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,19 @@ use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final;
use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap};
use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult};

pub(super) enum PiecewiseMergeJoinStreamState {
pub(super) enum ClassicPWMJStreamState {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to make the distinction between the existence-join stream state and the classic-join stream state clearer.

WaitBufferedSide,
FetchStreamBatch,
ProcessStreamBatch(SortedStreamBatch),
ProcessUnmatched,
Completed,
}

impl PiecewiseMergeJoinStreamState {
impl ClassicPWMJStreamState {
// Grab mutable reference to the current stream batch
fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut SortedStreamBatch> {
match self {
PiecewiseMergeJoinStreamState::ProcessStreamBatch(state) => Ok(state),
ClassicPWMJStreamState::ProcessStreamBatch(state) => Ok(state),
_ => internal_err!("Expected streamed batch in StreamBatch"),
}
}
Expand Down Expand Up @@ -103,7 +103,7 @@ pub(super) struct ClassicPWMJStream {
// Buffered side data
buffered_side: BufferedSide,
// Tracks the state of the `PiecewiseMergeJoin`
state: PiecewiseMergeJoinStreamState,
state: ClassicPWMJStreamState,
// Sort option for streamed side (specifies whether
// the sort is ascending or descending)
sort_option: SortOptions,
Expand All @@ -119,7 +119,7 @@ impl RecordBatchStream for ClassicPWMJStream {
}
}

// `PiecewiseMergeJoinStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`,
// `ClassicPWMJStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`,
// `ProcessStreamBatch`, `ProcessUnmatched` and `Completed`.
//
// Classic Joins
Expand All @@ -140,7 +140,7 @@ impl ClassicPWMJStream {
operator: Operator,
streamed: SendableRecordBatchStream,
buffered_side: BufferedSide,
state: PiecewiseMergeJoinStreamState,
state: ClassicPWMJStreamState,
sort_option: SortOptions,
join_metrics: BuildProbeJoinMetrics,
batch_size: usize,
Expand All @@ -166,19 +166,19 @@ impl ClassicPWMJStream {
) -> Poll<Option<Result<RecordBatch>>> {
loop {
return match self.state {
PiecewiseMergeJoinStreamState::WaitBufferedSide => {
ClassicPWMJStreamState::WaitBufferedSide => {
handle_state!(ready!(self.collect_buffered_side(cx)))
}
PiecewiseMergeJoinStreamState::FetchStreamBatch => {
ClassicPWMJStreamState::FetchStreamBatch => {
handle_state!(ready!(self.fetch_stream_batch(cx)))
}
PiecewiseMergeJoinStreamState::ProcessStreamBatch(_) => {
ClassicPWMJStreamState::ProcessStreamBatch(_) => {
handle_state!(self.process_stream_batch())
}
PiecewiseMergeJoinStreamState::ProcessUnmatched => {
ClassicPWMJStreamState::ProcessUnmatched => {
handle_state!(self.process_unmatched_buffered_batch())
}
PiecewiseMergeJoinStreamState::Completed => Poll::Ready(None),
ClassicPWMJStreamState::Completed => Poll::Ready(None),
};
}
}
Expand All @@ -197,7 +197,7 @@ impl ClassicPWMJStream {
build_timer.done();

// We will start fetching stream batches for classic joins
self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch;
self.state = ClassicPWMJStreamState::FetchStreamBatch;

self.buffered_side =
BufferedSide::Ready(BufferedSideReadyState { buffered_data });
Expand All @@ -221,9 +221,9 @@ impl ClassicPWMJStream {
== 1
{
self.batch_process_state.reset();
self.state = PiecewiseMergeJoinStreamState::ProcessUnmatched;
self.state = ClassicPWMJStreamState::ProcessUnmatched;
} else {
self.state = PiecewiseMergeJoinStreamState::Completed;
self.state = ClassicPWMJStreamState::Completed;
}
}
Some(Ok(batch)) => {
Expand All @@ -247,12 +247,11 @@ impl ClassicPWMJStream {

// Reset BatchProcessState before processing a new stream batch
self.batch_process_state.reset();
self.state = PiecewiseMergeJoinStreamState::ProcessStreamBatch(
SortedStreamBatch {
self.state =
ClassicPWMJStreamState::ProcessStreamBatch(SortedStreamBatch {
batch: stream_batch,
compare_key_values: vec![stream_values],
},
);
});
}
Some(Err(err)) => return Poll::Ready(Err(err)),
};
Expand Down Expand Up @@ -297,13 +296,13 @@ impl ClassicPWMJStream {
.output_batches
.next_completed_batch()
{
self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch;
self.state = ClassicPWMJStreamState::FetchStreamBatch;
return Ok(StatefulStreamResult::Ready(Some(b)));
}

// Nothing pending; hand back whatever `resolve` returned (often empty) and move on.
if self.batch_process_state.output_batches.is_empty() {
self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch;
self.state = ClassicPWMJStreamState::FetchStreamBatch;

return Ok(StatefulStreamResult::Ready(Some(batch)));
}
Expand All @@ -318,7 +317,7 @@ impl ClassicPWMJStream {
) -> Result<StatefulStreamResult<Option<RecordBatch>>> {
// Return early for `JoinType::Right` and `JoinType::Inner`
if matches!(self.join_type, JoinType::Right | JoinType::Inner) {
self.state = PiecewiseMergeJoinStreamState::Completed;
self.state = ClassicPWMJStreamState::Completed;
return Ok(StatefulStreamResult::Ready(None));
}

Expand All @@ -339,7 +338,7 @@ impl ClassicPWMJStream {
.output_batches
.next_completed_batch()
{
self.state = PiecewiseMergeJoinStreamState::Completed;
self.state = ClassicPWMJStreamState::Completed;
return Ok(StatefulStreamResult::Ready(Some(batch)));
}
}
Expand Down Expand Up @@ -387,11 +386,11 @@ impl ClassicPWMJStream {
.output_batches
.next_completed_batch()
{
self.state = PiecewiseMergeJoinStreamState::Completed;
self.state = ClassicPWMJStreamState::Completed;
return Ok(StatefulStreamResult::Ready(Some(batch)));
}

self.state = PiecewiseMergeJoinStreamState::Completed;
self.state = ClassicPWMJStreamState::Completed;
self.batch_process_state.reset();
Ok(StatefulStreamResult::Ready(None))
}
Expand Down Expand Up @@ -743,7 +742,7 @@ mod tests {
operator: Operator,
join_type: JoinType,
) -> Result<PiecewiseMergeJoinExec> {
PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type, 1)
PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type)
}

async fn join_collect(
Expand Down
Loading
Loading