-
Notifications
You must be signed in to change notification settings - Fork 1.7k
feat: Add Semi/Anti join to PiecewiseMergeJoin #18392
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
af4f5e7
233d7a3
6861080
ca86eaa
8bd2a1b
9055043
89a8b2b
97e8cfe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -35,7 +35,7 @@ use datafusion_physical_plan::execution_plan::EmissionType; | |
| use datafusion_physical_plan::joins::utils::ColumnIndex; | ||
| use datafusion_physical_plan::joins::{ | ||
| CrossJoinExec, HashJoinExec, NestedLoopJoinExec, PartitionMode, | ||
| StreamJoinPartitionMode, SymmetricHashJoinExec, | ||
| PiecewiseMergeJoinExec, StreamJoinPartitionMode, SymmetricHashJoinExec, | ||
| }; | ||
| use datafusion_physical_plan::{ExecutionPlan, ExecutionPlanProperties}; | ||
| use std::sync::Arc; | ||
|
|
@@ -256,59 +256,75 @@ fn statistical_join_selection_subrule( | |
| collect_threshold_byte_size: usize, | ||
| collect_threshold_num_rows: usize, | ||
| ) -> Result<Transformed<Arc<dyn ExecutionPlan>>> { | ||
| let transformed = | ||
| if let Some(hash_join) = plan.as_any().downcast_ref::<HashJoinExec>() { | ||
| match hash_join.partition_mode() { | ||
| PartitionMode::Auto => try_collect_left( | ||
| hash_join, | ||
| false, | ||
| collect_threshold_byte_size, | ||
| collect_threshold_num_rows, | ||
| )? | ||
| let transformed = if let Some(hash_join) = | ||
| plan.as_any().downcast_ref::<HashJoinExec>() | ||
| { | ||
| match hash_join.partition_mode() { | ||
| PartitionMode::Auto => try_collect_left( | ||
| hash_join, | ||
| false, | ||
| collect_threshold_byte_size, | ||
| collect_threshold_num_rows, | ||
| )? | ||
| .map_or_else( | ||
| || partitioned_hash_join(hash_join).map(Some), | ||
| |v| Ok(Some(v)), | ||
| )?, | ||
| PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? | ||
| .map_or_else( | ||
| || partitioned_hash_join(hash_join).map(Some), | ||
| |v| Ok(Some(v)), | ||
| )?, | ||
| PartitionMode::CollectLeft => try_collect_left(hash_join, true, 0, 0)? | ||
| .map_or_else( | ||
| || partitioned_hash_join(hash_join).map(Some), | ||
| |v| Ok(Some(v)), | ||
| )?, | ||
| PartitionMode::Partitioned => { | ||
| let left = hash_join.left(); | ||
| let right = hash_join.right(); | ||
| if hash_join.join_type().supports_swap() | ||
| && should_swap_join_order(&**left, &**right)? | ||
| { | ||
| hash_join | ||
| .swap_inputs(PartitionMode::Partitioned) | ||
| .map(Some)? | ||
| } else { | ||
| None | ||
| } | ||
| PartitionMode::Partitioned => { | ||
| let left = hash_join.left(); | ||
| let right = hash_join.right(); | ||
| if hash_join.join_type().supports_swap() | ||
| && should_swap_join_order(&**left, &**right)? | ||
| { | ||
| hash_join | ||
| .swap_inputs(PartitionMode::Partitioned) | ||
| .map(Some)? | ||
| } else { | ||
| None | ||
| } | ||
| } | ||
| } else if let Some(cross_join) = plan.as_any().downcast_ref::<CrossJoinExec>() { | ||
| let left = cross_join.left(); | ||
| let right = cross_join.right(); | ||
| if should_swap_join_order(&**left, &**right)? { | ||
| cross_join.swap_inputs().map(Some)? | ||
| } else { | ||
| None | ||
| } | ||
| } else if let Some(nl_join) = plan.as_any().downcast_ref::<NestedLoopJoinExec>() { | ||
| let left = nl_join.left(); | ||
| let right = nl_join.right(); | ||
| if nl_join.join_type().supports_swap() | ||
| && should_swap_join_order(&**left, &**right)? | ||
| { | ||
| nl_join.swap_inputs().map(Some)? | ||
| } else { | ||
| None | ||
| } | ||
| } | ||
| } else if let Some(cross_join) = plan.as_any().downcast_ref::<CrossJoinExec>() { | ||
| let left = cross_join.left(); | ||
| let right = cross_join.right(); | ||
| if should_swap_join_order(&**left, &**right)? { | ||
| cross_join.swap_inputs().map(Some)? | ||
| } else { | ||
| None | ||
| }; | ||
| } | ||
| } else if let Some(nl_join) = plan.as_any().downcast_ref::<NestedLoopJoinExec>() { | ||
| let left = nl_join.left(); | ||
| let right = nl_join.right(); | ||
| if nl_join.join_type().supports_swap() | ||
| && should_swap_join_order(&**left, &**right)? | ||
| { | ||
| nl_join.swap_inputs().map(Some)? | ||
| } else { | ||
| None | ||
| } | ||
| } else if let Some(pwmj) = plan.as_any().downcast_ref::<PiecewiseMergeJoinExec>() { | ||
| let left = pwmj.buffered(); | ||
| let right = pwmj.streamed(); | ||
| if pwmj.join_type().supports_swap() | ||
| // Put ! here because should_swap_join_order returns true if left > right but | ||
| // PiecewiseMergeJoin wants the left side to be the larger one, so only swap if | ||
| // left < right | ||
| && (!should_swap_join_order(&**left, &**right)? | ||
| || matches!(pwmj.join_type(), JoinType::RightSemi | JoinType::RightAnti)) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LeftSemi and Left Anti do not swap and only right existence joins do. This is explained above |
||
| && !matches!(pwmj.join_type(), JoinType::LeftSemi | JoinType::LeftAnti) | ||
| { | ||
| pwmj.swap_inputs().map(Some)? | ||
| } else { | ||
| None | ||
| } | ||
| } else { | ||
| None | ||
| }; | ||
|
|
||
| Ok(if let Some(transformed) = transformed { | ||
| Transformed::yes(transformed) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,19 +40,19 @@ use crate::joins::piecewise_merge_join::utils::need_produce_result_in_final; | |
| use crate::joins::utils::{compare_join_arrays, get_final_indices_from_shared_bitmap}; | ||
| use crate::joins::utils::{BuildProbeJoinMetrics, StatefulStreamResult}; | ||
|
|
||
| pub(super) enum PiecewiseMergeJoinStreamState { | ||
| pub(super) enum ClassicPWMJStreamState { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Renamed to make it more distinct between Existence join streamstate and Classic join stream state. |
||
| WaitBufferedSide, | ||
| FetchStreamBatch, | ||
| ProcessStreamBatch(SortedStreamBatch), | ||
| ProcessUnmatched, | ||
| Completed, | ||
| } | ||
|
|
||
| impl PiecewiseMergeJoinStreamState { | ||
| impl ClassicPWMJStreamState { | ||
| // Grab mutable reference to the current stream batch | ||
| fn try_as_process_stream_batch_mut(&mut self) -> Result<&mut SortedStreamBatch> { | ||
| match self { | ||
| PiecewiseMergeJoinStreamState::ProcessStreamBatch(state) => Ok(state), | ||
| ClassicPWMJStreamState::ProcessStreamBatch(state) => Ok(state), | ||
| _ => internal_err!("Expected streamed batch in StreamBatch"), | ||
| } | ||
| } | ||
|
|
@@ -103,7 +103,7 @@ pub(super) struct ClassicPWMJStream { | |
| // Buffered side data | ||
| buffered_side: BufferedSide, | ||
| // Tracks the state of the `PiecewiseMergeJoin` | ||
| state: PiecewiseMergeJoinStreamState, | ||
| state: ClassicPWMJStreamState, | ||
| // Sort option for streamed side (specifies whether | ||
| // the sort is ascending or descending) | ||
| sort_option: SortOptions, | ||
|
|
@@ -119,7 +119,7 @@ impl RecordBatchStream for ClassicPWMJStream { | |
| } | ||
| } | ||
|
|
||
| // `PiecewiseMergeJoinStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`, | ||
| // `ClassicPWMJStreamState` is separated into `WaitBufferedSide`, `FetchStreamBatch`, | ||
| // `ProcessStreamBatch`, `ProcessUnmatched` and `Completed`. | ||
| // | ||
| // Classic Joins | ||
|
|
@@ -140,7 +140,7 @@ impl ClassicPWMJStream { | |
| operator: Operator, | ||
| streamed: SendableRecordBatchStream, | ||
| buffered_side: BufferedSide, | ||
| state: PiecewiseMergeJoinStreamState, | ||
| state: ClassicPWMJStreamState, | ||
| sort_option: SortOptions, | ||
| join_metrics: BuildProbeJoinMetrics, | ||
| batch_size: usize, | ||
|
|
@@ -166,19 +166,19 @@ impl ClassicPWMJStream { | |
| ) -> Poll<Option<Result<RecordBatch>>> { | ||
| loop { | ||
| return match self.state { | ||
| PiecewiseMergeJoinStreamState::WaitBufferedSide => { | ||
| ClassicPWMJStreamState::WaitBufferedSide => { | ||
| handle_state!(ready!(self.collect_buffered_side(cx))) | ||
| } | ||
| PiecewiseMergeJoinStreamState::FetchStreamBatch => { | ||
| ClassicPWMJStreamState::FetchStreamBatch => { | ||
| handle_state!(ready!(self.fetch_stream_batch(cx))) | ||
| } | ||
| PiecewiseMergeJoinStreamState::ProcessStreamBatch(_) => { | ||
| ClassicPWMJStreamState::ProcessStreamBatch(_) => { | ||
| handle_state!(self.process_stream_batch()) | ||
| } | ||
| PiecewiseMergeJoinStreamState::ProcessUnmatched => { | ||
| ClassicPWMJStreamState::ProcessUnmatched => { | ||
| handle_state!(self.process_unmatched_buffered_batch()) | ||
| } | ||
| PiecewiseMergeJoinStreamState::Completed => Poll::Ready(None), | ||
| ClassicPWMJStreamState::Completed => Poll::Ready(None), | ||
| }; | ||
| } | ||
| } | ||
|
|
@@ -197,7 +197,7 @@ impl ClassicPWMJStream { | |
| build_timer.done(); | ||
|
|
||
| // We will start fetching stream batches for classic joins | ||
| self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; | ||
| self.state = ClassicPWMJStreamState::FetchStreamBatch; | ||
|
|
||
| self.buffered_side = | ||
| BufferedSide::Ready(BufferedSideReadyState { buffered_data }); | ||
|
|
@@ -221,9 +221,9 @@ impl ClassicPWMJStream { | |
| == 1 | ||
| { | ||
| self.batch_process_state.reset(); | ||
| self.state = PiecewiseMergeJoinStreamState::ProcessUnmatched; | ||
| self.state = ClassicPWMJStreamState::ProcessUnmatched; | ||
| } else { | ||
| self.state = PiecewiseMergeJoinStreamState::Completed; | ||
| self.state = ClassicPWMJStreamState::Completed; | ||
| } | ||
| } | ||
| Some(Ok(batch)) => { | ||
|
|
@@ -247,12 +247,11 @@ impl ClassicPWMJStream { | |
|
|
||
| // Reset BatchProcessState before processing a new stream batch | ||
| self.batch_process_state.reset(); | ||
| self.state = PiecewiseMergeJoinStreamState::ProcessStreamBatch( | ||
| SortedStreamBatch { | ||
| self.state = | ||
| ClassicPWMJStreamState::ProcessStreamBatch(SortedStreamBatch { | ||
| batch: stream_batch, | ||
| compare_key_values: vec![stream_values], | ||
| }, | ||
| ); | ||
| }); | ||
| } | ||
| Some(Err(err)) => return Poll::Ready(Err(err)), | ||
| }; | ||
|
|
@@ -297,13 +296,13 @@ impl ClassicPWMJStream { | |
| .output_batches | ||
| .next_completed_batch() | ||
| { | ||
| self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; | ||
| self.state = ClassicPWMJStreamState::FetchStreamBatch; | ||
| return Ok(StatefulStreamResult::Ready(Some(b))); | ||
| } | ||
|
|
||
| // Nothing pending; hand back whatever `resolve` returned (often empty) and move on. | ||
| if self.batch_process_state.output_batches.is_empty() { | ||
| self.state = PiecewiseMergeJoinStreamState::FetchStreamBatch; | ||
| self.state = ClassicPWMJStreamState::FetchStreamBatch; | ||
|
|
||
| return Ok(StatefulStreamResult::Ready(Some(batch))); | ||
| } | ||
|
|
@@ -318,7 +317,7 @@ impl ClassicPWMJStream { | |
| ) -> Result<StatefulStreamResult<Option<RecordBatch>>> { | ||
| // Return early for `JoinType::Right` and `JoinType::Inner` | ||
| if matches!(self.join_type, JoinType::Right | JoinType::Inner) { | ||
| self.state = PiecewiseMergeJoinStreamState::Completed; | ||
| self.state = ClassicPWMJStreamState::Completed; | ||
| return Ok(StatefulStreamResult::Ready(None)); | ||
| } | ||
|
|
||
|
|
@@ -339,7 +338,7 @@ impl ClassicPWMJStream { | |
| .output_batches | ||
| .next_completed_batch() | ||
| { | ||
| self.state = PiecewiseMergeJoinStreamState::Completed; | ||
| self.state = ClassicPWMJStreamState::Completed; | ||
| return Ok(StatefulStreamResult::Ready(Some(batch))); | ||
| } | ||
| } | ||
|
|
@@ -387,11 +386,11 @@ impl ClassicPWMJStream { | |
| .output_batches | ||
| .next_completed_batch() | ||
| { | ||
| self.state = PiecewiseMergeJoinStreamState::Completed; | ||
| self.state = ClassicPWMJStreamState::Completed; | ||
| return Ok(StatefulStreamResult::Ready(Some(batch))); | ||
| } | ||
|
|
||
| self.state = PiecewiseMergeJoinStreamState::Completed; | ||
| self.state = ClassicPWMJStreamState::Completed; | ||
| self.batch_process_state.reset(); | ||
| Ok(StatefulStreamResult::Ready(None)) | ||
| } | ||
|
|
@@ -743,7 +742,7 @@ mod tests { | |
| operator: Operator, | ||
| join_type: JoinType, | ||
| ) -> Result<PiecewiseMergeJoinExec> { | ||
| PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type, 1) | ||
| PiecewiseMergeJoinExec::try_new(left, right, on, operator, join_type) | ||
| } | ||
|
|
||
| async fn join_collect( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know why these showed a diff but there isn't anything different here except for the addition of the PiecewiseMergeJoin branch for swapping.