From 0a99e70c38322e456f0262a3d4bcf25972f47259 Mon Sep 17 00:00:00 2001 From: Adam Curtis Date: Tue, 16 Dec 2025 12:21:06 -0500 Subject: [PATCH] fix: wrap join operators with cooperative() for cancellation support --- datafusion/physical-plan/src/joins/cross_join.rs | 9 +++++---- datafusion/physical-plan/src/joins/hash_join/exec.rs | 5 +++-- datafusion/physical-plan/src/joins/nested_loop_join.rs | 5 +++-- .../physical-plan/src/joins/sort_merge_join/exec.rs | 5 +++-- .../physical-plan/src/joins/symmetric_hash_join.rs | 9 +++++---- 5 files changed, 19 insertions(+), 14 deletions(-) diff --git a/datafusion/physical-plan/src/joins/cross_join.rs b/datafusion/physical-plan/src/joins/cross_join.rs index 7e43deb9bfdf5..41b462be2ca02 100644 --- a/datafusion/physical-plan/src/joins/cross_join.rs +++ b/datafusion/physical-plan/src/joins/cross_join.rs @@ -25,6 +25,7 @@ use super::utils::{ OnceAsync, OnceFut, StatefulStreamResult, adjust_right_output_partitioning, reorder_output_after_swap, }; +use crate::coop::cooperative; use crate::execution_plan::{EmissionType, boundedness_from_children}; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; use crate::projection::{ @@ -332,7 +333,7 @@ impl ExecutionPlan for CrossJoinExec { })?; if enforce_batch_size_in_joins { - Ok(Box::pin(CrossJoinStream { + Ok(Box::pin(cooperative(CrossJoinStream { schema: Arc::clone(&self.schema), left_fut, right: stream, @@ -341,9 +342,9 @@ impl ExecutionPlan for CrossJoinExec { state: CrossJoinStreamState::WaitBuildSide, left_data: RecordBatch::new_empty(self.left().schema()), batch_transformer: BatchSplitter::new(batch_size), - })) + }))) } else { - Ok(Box::pin(CrossJoinStream { + Ok(Box::pin(cooperative(CrossJoinStream { schema: Arc::clone(&self.schema), left_fut, right: stream, @@ -352,7 +353,7 @@ impl ExecutionPlan for CrossJoinExec { state: CrossJoinStreamState::WaitBuildSide, left_data: RecordBatch::new_empty(self.left().schema()), batch_transformer: NoopBatchTransformer::new(), - })) + }))) } } diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs index 26447847631fa..6bcd02be54d1f 100644 --- a/datafusion/physical-plan/src/joins/hash_join/exec.rs +++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs @@ -22,6 +22,7 @@ use std::sync::{Arc, OnceLock}; use std::{any::Any, vec}; use crate::ExecutionPlanProperties; +use crate::coop::cooperative; use crate::execution_plan::{EmissionType, boundedness_from_children}; use crate::filter_pushdown::{ ChildPushdownResult, FilterDescription, FilterPushdownPhase, @@ -1034,7 +1035,7 @@ impl ExecutionPlan for HashJoinExec { .map(|(_, right_expr)| Arc::clone(right_expr)) .collect::>(); - Ok(Box::pin(HashJoinStream::new( + Ok(Box::pin(cooperative(HashJoinStream::new( partition, self.schema(), on_right, @@ -1052,7 +1053,7 @@ impl ExecutionPlan for HashJoinExec { self.right.output_ordering().is_some(), build_accumulator, self.mode, - ))) + )))) } fn metrics(&self) -> Option { diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index c11ee0f378942..047ceb6074eee 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -29,6 +29,7 @@ use super::utils::{ reorder_output_after_swap, swap_join_projection, }; use crate::common::can_project; +use crate::coop::cooperative; use crate::execution_plan::{EmissionType, boundedness_from_children}; use crate::joins::SharedBitmapBuilder; use crate::joins::utils::{ @@ -529,7 +530,7 @@ impl ExecutionPlan for NestedLoopJoinExec { None => self.column_indices.clone(), }; - Ok(Box::pin(NestedLoopJoinStream::new( + Ok(Box::pin(cooperative(NestedLoopJoinStream::new( self.schema(), self.filter.clone(), self.join_type, @@ -538,7 +539,7 @@ impl ExecutionPlan for NestedLoopJoinExec { column_indices_after_projection, metrics, batch_size, - ))) + )))) } fn metrics(&self) -> Option { diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs index 5362259d22ea8..37b4264d0eae6 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/exec.rs @@ -23,6 +23,7 @@ use std::any::Any; use std::fmt::Formatter; use std::sync::Arc; +use crate::coop::cooperative; use crate::execution_plan::{EmissionType, boundedness_from_children}; use crate::expressions::PhysicalSortExpr; use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics; @@ -496,7 +497,7 @@ impl ExecutionPlan for SortMergeJoinExec { .register(context.memory_pool()); // create join stream - Ok(Box::pin(SortMergeJoinStream::try_new( + Ok(Box::pin(cooperative(SortMergeJoinStream::try_new( context.session_config().spill_compression(), Arc::clone(&self.schema), self.sort_options.clone(), @@ -511,7 +512,7 @@ impl ExecutionPlan for SortMergeJoinExec { SortMergeJoinMetrics::new(partition, &self.metrics), reservation, context.runtime_env(), - )?)) + )?))) } fn metrics(&self) -> Option { diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs index 1f6bc703a0300..a54f930114c60 100644 --- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs +++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs @@ -33,6 +33,7 @@ use std::task::{Context, Poll}; use std::vec; use crate::common::SharedMemoryReservation; +use crate::coop::cooperative; use crate::execution_plan::{boundedness_from_children, emission_type_from_children}; use crate::joins::stream_join_utils::{ PruningJoinHashMap, SortedFilterExpr, StreamJoinMetrics, @@ -534,7 +535,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { } if enforce_batch_size_in_joins { - Ok(Box::pin(SymmetricHashJoinStream { + Ok(Box::pin(cooperative(SymmetricHashJoinStream { left_stream, right_stream, schema: self.schema(), @@ -552,9 +553,9 @@ impl ExecutionPlan for SymmetricHashJoinExec { state: SHJStreamState::PullRight, reservation, batch_transformer: BatchSplitter::new(batch_size), - })) + }))) } else { - Ok(Box::pin(SymmetricHashJoinStream { + Ok(Box::pin(cooperative(SymmetricHashJoinStream { left_stream, right_stream, schema: self.schema(), @@ -572,7 +573,7 @@ impl ExecutionPlan for SymmetricHashJoinExec { state: SHJStreamState::PullRight, reservation, batch_transformer: NoopBatchTransformer::new(), - })) + }))) } }