diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs
index 2c23a9c8c5d14..af037f5c93748 100644
--- a/datafusion/physical-plan/src/joins/hash_join/exec.rs
+++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use std::fmt;
+use std::marker::PhantomData;
 use std::mem::size_of;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::{Arc, OnceLock};
@@ -26,7 +27,9 @@ use crate::filter_pushdown::{
     ChildPushdownResult, FilterDescription, FilterPushdownPhase,
     FilterPushdownPropagation,
 };
-use crate::joins::hash_join::shared_bounds::{ColumnBounds, SharedBoundsAccumulator};
+use crate::joins::hash_join::shared_bounds::{
+    ColumnBounds, MinMaxColumnBounds, SharedBoundsAccumulator,
+};
 use crate::joins::hash_join::stream::{
     BuildSide, BuildSideInitialState, HashJoinStream, HashJoinStreamState,
 };
@@ -103,7 +106,7 @@ pub(super) struct JoinLeftData {
     /// The MemoryReservation ensures proper tracking of memory resources throughout the join operation's lifecycle.
     _reservation: MemoryReservation,
     /// Bounds computed from the build side for dynamic filter pushdown
-    pub(super) bounds: Option<Vec<ColumnBounds>>,
+    pub(super) bounds: Option<Vec<Arc<dyn ColumnBounds>>>,
 }
 
 impl JoinLeftData {
@@ -115,7 +118,7 @@ impl JoinLeftData {
         visited_indices_bitmap: SharedBitmapBuilder,
         probe_threads_counter: AtomicUsize,
         reservation: MemoryReservation,
-        bounds: Option<Vec<ColumnBounds>>,
+        bounds: Option<Vec<Arc<dyn ColumnBounds>>>,
     ) -> Self {
         Self {
             hash_map,
@@ -319,7 +322,7 @@ impl JoinLeftData {
 /// Note this structure includes a [`OnceAsync`] that is used to coordinate the
 /// loading of the left side with the processing in each output stream.
 /// Therefore it can not be [`Clone`]
-pub struct HashJoinExec {
+pub struct HashJoinExec<A: CollectLeftAccumulator = MinMaxLeftAccumulator> {
     /// left (build) side which gets hashed
     pub left: Arc<dyn ExecutionPlan>,
     /// right (probe) side which are filtered by the hash table
@@ -358,6 +361,8 @@ pub struct HashJoinExec {
     /// Set when dynamic filter pushdown is detected in handle_child_pushdown_result.
     /// HashJoinExec also needs to keep a shared bounds accumulator for coordinating updates.
     dynamic_filter: Option<HashJoinExecDynamicFilter>,
+    /// Phantom data for the bounds accumulator type
+    _phantom_accumulator: PhantomData<A>,
 }
 
 #[derive(Clone)]
@@ -369,7 +374,7 @@ struct HashJoinExecDynamicFilter {
     bounds_accumulator: OnceLock<Arc<SharedBoundsAccumulator>>,
 }
 
-impl fmt::Debug for HashJoinExec {
+impl<A: CollectLeftAccumulator> fmt::Debug for HashJoinExec<A> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("HashJoinExec")
             .field("left", &self.left)
@@ -391,14 +396,14 @@ impl fmt::Debug for HashJoinExec {
     }
 }
 
-impl EmbeddedProjection for HashJoinExec {
+impl<A: CollectLeftAccumulator> EmbeddedProjection for HashJoinExec<A> {
     fn with_projection(&self, projection: Option<Vec<usize>>) -> Result<Self> {
         self.with_projection(projection)
     }
 }
 
-impl HashJoinExec {
-    /// Tries to create a new [HashJoinExec].
+impl HashJoinExec<MinMaxLeftAccumulator> {
+    /// Tries to create a new [HashJoinExec] with a default `MinMaxLeftAccumulator` bounds accumulator.
     ///
     /// # Error
     /// This function errors when it is not possible to join the left and right sides on keys `on`.
@@ -460,6 +465,75 @@ impl HashJoinExec {
             null_equality,
             cache,
             dynamic_filter: None,
+            _phantom_accumulator: PhantomData,
         })
     }
+}
+
+impl<A: CollectLeftAccumulator> HashJoinExec<A> {
+    /// Tries to create a new [HashJoinExec] with a custom bounds accumulator.
+    ///
+    /// # Error
+    /// This function errors when it is not possible to join the left and right sides on keys `on`.
+    #[allow(clippy::too_many_arguments)]
+    pub fn try_new_with_accumulator(
+        left: Arc<dyn ExecutionPlan>,
+        right: Arc<dyn ExecutionPlan>,
+        on: JoinOn,
+        filter: Option<JoinFilter>,
+        join_type: &JoinType,
+        projection: Option<Vec<usize>>,
+        partition_mode: PartitionMode,
+        null_equality: NullEquality,
+    ) -> Result<Self> {
+        let left_schema = left.schema();
+        let right_schema = right.schema();
+        if on.is_empty() {
+            return plan_err!("On constraints in HashJoinExec should be non-empty");
+        }
+
+        check_join_is_valid(&left_schema, &right_schema, &on)?;
+
+        let (join_schema, column_indices) =
+            build_join_schema(&left_schema, &right_schema, join_type);
+
+        let random_state = HASH_JOIN_SEED;
+
+        let join_schema = Arc::new(join_schema);
+
+        // check if the projection is valid
+        can_project(&join_schema, projection.as_ref())?;
+
+        let cache = Self::compute_properties(
+            &left,
+            &right,
+            Arc::clone(&join_schema),
+            *join_type,
+            &on,
+            partition_mode,
+            projection.as_ref(),
+        )?;
+
+        // Initialize both dynamic filter and bounds accumulator to None
+        // They will be set later if dynamic filtering is enabled
+
+        Ok(HashJoinExec {
+            left,
+            right,
+            on,
+            filter,
+            join_type: *join_type,
+            join_schema,
+            left_fut: Default::default(),
+            random_state,
+            mode: partition_mode,
+            metrics: ExecutionPlanMetricsSet::new(),
+            projection,
+            column_indices,
+            null_equality,
+            cache,
+            dynamic_filter: None,
+            _phantom_accumulator: PhantomData,
+        })
+    }
 
@@ -549,7 +623,7 @@ impl HashJoinExec {
             },
             None => None,
         };
-        Self::try_new(
+        Self::try_new_with_accumulator(
             Arc::clone(&self.left),
             Arc::clone(&self.right),
             self.on.clone(),
@@ -665,7 +739,7 @@ impl HashJoinExec {
     ) -> Result<Arc<dyn ExecutionPlan>> {
         let left = self.left();
         let right = self.right();
-        let new_join = HashJoinExec::try_new(
+        let new_join = HashJoinExec::<A>::try_new_with_accumulator(
             Arc::clone(right),
             Arc::clone(left),
             self.on()
@@ -699,7 +773,7 @@ impl HashJoinExec {
     }
 }
 
-impl DisplayAs for HashJoinExec {
+impl<A: CollectLeftAccumulator> DisplayAs for HashJoinExec<A> {
     fn fmt_as(&self, t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
@@ -763,7 +837,7 @@ impl DisplayAs for HashJoinExec {
     }
 }
 
-impl ExecutionPlan for HashJoinExec {
+impl<A: CollectLeftAccumulator> ExecutionPlan for HashJoinExec<A> {
     fn name(&self) -> &'static str {
         "HashJoinExec"
     }
@@ -833,7 +907,7 @@ impl ExecutionPlan for HashJoinExec {
         self: Arc<Self>,
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        Ok(Arc::new(HashJoinExec {
+        Ok(Arc::new(HashJoinExec::<A> {
             left: Arc::clone(&children[0]),
             right: Arc::clone(&children[1]),
             on: self.on.clone(),
@@ -858,11 +932,12 @@ impl ExecutionPlan for HashJoinExec {
             )?,
             // Keep the dynamic filter, bounds accumulator will be reset
             dynamic_filter: self.dynamic_filter.clone(),
+            _phantom_accumulator: PhantomData,
         }))
     }
 
     fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
-        Ok(Arc::new(HashJoinExec {
+        Ok(Arc::new(HashJoinExec::<A> {
             left: Arc::clone(&self.left),
             right: Arc::clone(&self.right),
             on: self.on.clone(),
@@ -880,6 +955,7 @@ impl ExecutionPlan for HashJoinExec {
             cache: self.cache.clone(),
             // Reset dynamic filter and bounds accumulator to initial state
             dynamic_filter: None,
+            _phantom_accumulator: PhantomData,
         }))
     }
 
@@ -921,7 +997,7 @@ impl ExecutionPlan for HashJoinExec {
                 let reservation =
                     MemoryConsumer::new("HashJoinInput").register(context.memory_pool());
 
-                Ok(collect_left_input(
+                Ok(collect_left_input::<A>(
                     self.random_state.clone(),
                     left_stream,
                     on_left.clone(),
@@ -939,7 +1015,7 @@ impl ExecutionPlan for HashJoinExec {
                     MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
                         .register(context.memory_pool());
 
-                OnceFut::new(collect_left_input(
+                OnceFut::new(collect_left_input::<A>(
                     self.random_state.clone(),
                     left_stream,
                     on_left.clone(),
@@ -1162,7 +1238,7 @@ impl ExecutionPlan for HashJoinExec {
                 Arc::downcast::<DynamicFilterPhysicalExpr>(predicate)
             {
                 // We successfully pushed down our self filter - we need to make a new node with the dynamic filter
-                let new_node = Arc::new(HashJoinExec {
+                let new_node = Arc::new(HashJoinExec::<A> {
                     left: Arc::clone(&self.left),
                     right: Arc::clone(&self.right),
                     on: self.on.clone(),
@@ -1181,6 +1257,7 @@ impl ExecutionPlan for HashJoinExec {
                         filter: dynamic_filter,
                         bounds_accumulator: OnceLock::new(),
                     }),
+                    _phantom_accumulator: PhantomData,
                 });
                 result = result.with_updated_node(new_node as Arc<dyn ExecutionPlan>);
             }
@@ -1189,6 +1266,45 @@ impl ExecutionPlan for HashJoinExec {
     }
 }
 
+/// Trait defining an accumulator for collecting build-side data during hash joins.
+///
+/// The accumulator is responsible for processing batches of data from the build
+/// side and for computing or storing the intermediate results needed for dynamic
+/// filtering.
+///
+/// For example, the [`MinMaxLeftAccumulator`] implementation collects minimum and
+/// maximum values for the join key expressions across all build-side batches.
+pub trait CollectLeftAccumulator: Send + Sync {
+    /// Creates a new accumulator for the given expression and schema.
+    ///
+    /// # Arguments
+    /// * `expr` - The physical expression to track bounds for
+    /// * `schema` - The schema of the input data
+    ///
+    /// # Returns
+    /// A new `CollectLeftAccumulator` instance configured for the expression's data type
+    fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> Result<Self>
+    where
+        Self: Sized;
+
+    /// Updates the accumulator with values from a new batch.
+    ///
+    /// Evaluates the expression on the batch and folds the resulting values into
+    /// the accumulator's intermediate state.
+    ///
+    /// # Arguments
+    /// * `batch` - The record batch to process
+    ///
+    /// # Returns
+    /// Ok(()) if the update succeeds, or an error if updating fails.
+    fn update_batch(&mut self, batch: &RecordBatch) -> Result<()>;
+
+    /// Finalizes the accumulation and returns the computed bounds.
+    ///
+    /// Consumes self to extract the final bounds from the accumulator.
+    ///
+    /// # Returns
+    /// The `ColumnBounds` containing the bounds observed
+    fn evaluate(self) -> Result<Arc<dyn ColumnBounds>>;
+}
+
 /// Accumulator for collecting min/max bounds from build-side data during hash join.
 ///
 /// This struct encapsulates the logic for progressively computing column bounds
@@ -1198,7 +1314,7 @@ impl ExecutionPlan for HashJoinExec {
 /// The bounds are used for dynamic filter pushdown optimization, where filters
 /// based on the actual data ranges can be pushed down to the probe side to
 /// eliminate unnecessary data early.
-struct CollectLeftAccumulator {
+pub struct MinMaxLeftAccumulator {
     /// The physical expression to evaluate for each batch
     expr: Arc<dyn PhysicalExpr>,
     /// Accumulator for tracking the minimum value across all batches
     min: MinAccumulator,
@@ -1207,15 +1323,7 @@ struct CollectLeftAccumulator {
     /// Accumulator for tracking the maximum value across all batches
     max: MaxAccumulator,
 }
 
-impl CollectLeftAccumulator {
-    /// Creates a new accumulator for tracking bounds of a join key expression.
-    ///
-    /// # Arguments
-    /// * `expr` - The physical expression to track bounds for
-    /// * `schema` - The schema of the input data
-    ///
-    /// # Returns
-    /// A new `CollectLeftAccumulator` instance configured for the expression's data type
+impl CollectLeftAccumulator for MinMaxLeftAccumulator {
     fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> Result<Self> {
         /// Recursively unwraps dictionary types to get the underlying value type.
        fn dictionary_value_type(data_type: &DataType) -> DataType {
@@ -1238,16 +1346,8 @@ impl CollectLeftAccumulator {
         })
     }
 
-    /// Updates the accumulators with values from a new batch.
-    ///
     /// Evaluates the expression on the batch and updates both min and max
     /// accumulators with the resulting values.
-    ///
-    /// # Arguments
-    /// * `batch` - The record batch to process
-    ///
-    /// # Returns
-    /// Ok(()) if the update succeeds, or an error if expression evaluation fails
     fn update_batch(&mut self, batch: &RecordBatch) -> Result<()> {
         let array = self.expr.evaluate(batch)?.into_array(batch.num_rows())?;
         self.min.update_batch(std::slice::from_ref(&array))?;
@@ -1255,30 +1355,24 @@ impl CollectLeftAccumulator {
         Ok(())
     }
 
-    /// Finalizes the accumulation and returns the computed bounds.
-    ///
-    /// Consumes self to extract the final min and max values from the accumulators.
-    ///
-    /// # Returns
-    /// The `ColumnBounds` containing the minimum and maximum values observed
-    fn evaluate(mut self) -> Result<ColumnBounds> {
-        Ok(ColumnBounds::new(
+    fn evaluate(mut self) -> Result<Arc<dyn ColumnBounds>> {
+        Ok(Arc::new(MinMaxColumnBounds::new(
             self.min.evaluate()?,
             self.max.evaluate()?,
-        ))
+        )))
     }
 }
 
 /// State for collecting the build-side data during hash join
-struct BuildSideState {
+struct BuildSideState<A: CollectLeftAccumulator> {
     batches: Vec<RecordBatch>,
     num_rows: usize,
     metrics: BuildProbeJoinMetrics,
     reservation: MemoryReservation,
-    bounds_accumulators: Option<Vec<CollectLeftAccumulator>>,
+    bounds_accumulators: Option<Vec<A>>,
 }
 
-impl BuildSideState {
+impl<A: CollectLeftAccumulator> BuildSideState<A> {
     /// Create a new BuildSideState with optional accumulators for bounds computation
     fn try_new(
         metrics: BuildProbeJoinMetrics,
@@ -1296,9 +1390,7 @@ impl BuildSideState {
                 .then(|| {
                     on_left
                         .iter()
-                        .map(|expr| {
-                            CollectLeftAccumulator::try_new(Arc::clone(expr), schema)
-                        })
+                        .map(|expr| A::try_new(Arc::clone(expr), schema))
                         .collect::<Result<Vec<_>>>()
                 })
                 .transpose()?,
@@ -1335,7 +1427,7 @@ impl BuildSideState {
 /// `JoinLeftData` containing the hash map, consolidated batch, join key values,
 /// visited indices bitmap, and computed bounds (if requested).
 #[allow(clippy::too_many_arguments)]
-async fn collect_left_input(
+async fn collect_left_input<A: CollectLeftAccumulator>(
     random_state: RandomState,
     left_stream: SendableRecordBatchStream,
     on_left: Vec<PhysicalExprRef>,
@@ -1350,7 +1442,7 @@ async fn collect_left_input(
     // This operation performs 2 steps at once:
     // 1. creates a [JoinHashMap] of all batches from the stream
     // 2. stores the batches in a vector.
-    let initial = BuildSideState::try_new(
+    let initial = BuildSideState::<A>::try_new(
         metrics,
         reservation,
         on_left.clone(),
@@ -1384,7 +1476,7 @@ async fn collect_left_input(
     .await?;
 
     // Extract fields from state
-    let BuildSideState {
+    let BuildSideState::<A> {
         batches,
         num_rows,
         metrics,
@@ -1459,13 +1551,12 @@ async fn collect_left_input(
 
     // Compute bounds for dynamic filter if enabled
     let bounds = match bounds_accumulators {
-        Some(accumulators) if num_rows > 0 => {
-            let bounds = accumulators
+        Some(accumulators) if num_rows > 0 => Some(
+            accumulators
                 .into_iter()
-                .map(CollectLeftAccumulator::evaluate)
-                .collect::<Result<Vec<_>>>()?;
-            Some(bounds)
-        }
+                .map(|a| a.evaluate())
+                .collect::<Result<Vec<_>>>()?,
+        ),
         _ => None,
     };
 
diff --git a/datafusion/physical-plan/src/joins/hash_join/mod.rs b/datafusion/physical-plan/src/joins/hash_join/mod.rs
index 7f1e5cae13a3e..c2cb148786706 100644
--- a/datafusion/physical-plan/src/joins/hash_join/mod.rs
+++ b/datafusion/physical-plan/src/joins/hash_join/mod.rs
@@ -18,6 +18,8 @@
 //! [`HashJoinExec`] Partitioned Hash Join Operator
 
 pub use exec::HashJoinExec;
+pub use exec::{CollectLeftAccumulator, MinMaxLeftAccumulator};
+pub use shared_bounds::ColumnBounds;
 
 mod exec;
 mod shared_bounds;
diff --git a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs
index 73e65be686833..f1f1d08d72812 100644
--- a/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs
+++ b/datafusion/physical-plan/src/joins/hash_join/shared_bounds.rs
@@ -18,7 +18,7 @@
 //! Utilities for shared bounds. Used in dynamic filter pushdown in Hash Joins.
 // TODO: include the link to the Dynamic Filter blog post.
 
-use std::fmt;
+use std::fmt::{self, Debug};
 use std::sync::Arc;
 
 use crate::joins::PartitionMode;
@@ -33,22 +33,64 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
 use itertools::Itertools;
 use parking_lot::Mutex;
 
+/// Trait representing a set of bounds for a column used in join dynamic filtering.
+///
+/// Bounds may be min/max values or any other summary of the build-side values
+/// that can be represented as a physical expression.
+///
+/// Refer to the [`MinMaxColumnBounds`] implementation for an example of min/max bounds.
+pub trait ColumnBounds: Send + Sync + Debug {
+    /// Creates a physical expression that applies these bounds to the given
+    /// column expression.
+    ///
+    /// # Arguments
+    ///
+    /// * `expr` - The join key expression to constrain; at the call site this is
+    ///   the probe-side expression the bounds predicate is applied to.
+    ///
+    /// # Returns
+    /// `Ok(Arc<dyn PhysicalExpr>)` if creating the bounds expression succeeds, or an error otherwise.
+    fn physical_expr(
+        &self,
+        expr: Arc<dyn PhysicalExpr>,
+    ) -> Result<Arc<dyn PhysicalExpr>>;
+}
+
 /// Represents the minimum and maximum values for a specific column.
 /// Used in dynamic filter pushdown to establish value boundaries.
 #[derive(Debug, Clone, PartialEq)]
-pub(crate) struct ColumnBounds {
+pub(crate) struct MinMaxColumnBounds {
     /// The minimum value observed for this column
     min: ScalarValue,
     /// The maximum value observed for this column
     max: ScalarValue,
 }
 
-impl ColumnBounds {
+impl MinMaxColumnBounds {
     pub(crate) fn new(min: ScalarValue, max: ScalarValue) -> Self {
         Self { min, max }
     }
 }
 
+impl ColumnBounds for MinMaxColumnBounds {
+    fn physical_expr(
+        &self,
+        expr: Arc<dyn PhysicalExpr>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        // Create predicate: col >= min AND col <= max
+        let min_expr = Arc::new(BinaryExpr::new(
+            Arc::clone(&expr),
+            Operator::GtEq,
+            lit(self.min.clone()),
+        )) as Arc<dyn PhysicalExpr>;
+        let max_expr = Arc::new(BinaryExpr::new(
+            Arc::clone(&expr),
+            Operator::LtEq,
+            lit(self.max.clone()),
+        )) as Arc<dyn PhysicalExpr>;
+        let range_expr = Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr))
+            as Arc<dyn PhysicalExpr>;
+        Ok(range_expr)
+    }
+}
+
 /// Represents the bounds for all join key columns from a single partition.
 /// This contains the min/max values computed from one partition's build-side data.
 #[derive(Debug, Clone)]
@@ -57,11 +99,14 @@ pub(crate) struct PartitionBounds {
     partition: usize,
     /// Min/max bounds for each join key column in this partition.
     /// Index corresponds to the join key expression index.
-    column_bounds: Vec<ColumnBounds>,
+    column_bounds: Vec<Arc<dyn ColumnBounds>>,
 }
 
 impl PartitionBounds {
-    pub(crate) fn new(partition: usize, column_bounds: Vec<ColumnBounds>) -> Self {
+    pub(crate) fn new(
+        partition: usize,
+        column_bounds: Vec<Arc<dyn ColumnBounds>>,
+    ) -> Self {
         Self {
             partition,
             column_bounds,
@@ -72,7 +117,10 @@ impl PartitionBounds {
         self.column_bounds.len()
     }
 
-    pub(crate) fn get_column_bounds(&self, index: usize) -> Option<&ColumnBounds> {
+    pub(crate) fn get_column_bounds(
+        &self,
+        index: usize,
+    ) -> Option<&Arc<dyn ColumnBounds>> {
         self.column_bounds.get(index)
     }
 }
@@ -204,21 +252,9 @@ impl SharedBoundsAccumulator {
 
         for (col_idx, right_expr) in self.on_right.iter().enumerate() {
             if let Some(column_bounds) = partition_bounds.get_column_bounds(col_idx) {
-                // Create predicate: col >= min AND col <= max
-                let min_expr = Arc::new(BinaryExpr::new(
-                    Arc::clone(right_expr),
-                    Operator::GtEq,
-                    lit(column_bounds.min.clone()),
-                )) as Arc<dyn PhysicalExpr>;
-                let max_expr = Arc::new(BinaryExpr::new(
-                    Arc::clone(right_expr),
-                    Operator::LtEq,
-                    lit(column_bounds.max.clone()),
-                )) as Arc<dyn PhysicalExpr>;
-                let range_expr =
-                    Arc::new(BinaryExpr::new(min_expr, Operator::And, max_expr))
-                        as Arc<dyn PhysicalExpr>;
-                column_predicates.push(range_expr);
+                let bounds_expr =
+                    column_bounds.physical_expr(Arc::clone(right_expr))?;
+                column_predicates.push(bounds_expr);
             }
         }
 
@@ -261,7 +297,7 @@ impl SharedBoundsAccumulator {
     pub(crate) fn report_partition_bounds(
         &self,
         partition: usize,
-        partition_bounds: Option<Vec<ColumnBounds>>,
+        partition_bounds: Option<Vec<Arc<dyn ColumnBounds>>>,
     ) -> Result<()> {
         let mut inner = self.inner.lock();
 
@@ -289,7 +325,7 @@ impl SharedBoundsAccumulator {
     }
 }
 
-impl fmt::Debug for SharedBoundsAccumulator {
+impl Debug for SharedBoundsAccumulator {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(f, "SharedBoundsAccumulator")
     }
diff --git a/datafusion/physical-plan/src/joins/mod.rs b/datafusion/physical-plan/src/joins/mod.rs
index 1d36db996434e..8cd7bd34d251f 100644
--- a/datafusion/physical-plan/src/joins/mod.rs
+++ b/datafusion/physical-plan/src/joins/mod.rs
@@ -20,7 +20,9 @@
 use arrow::array::BooleanBufferBuilder;
 pub use cross_join::CrossJoinExec;
 use datafusion_physical_expr::PhysicalExprRef;
-pub use hash_join::HashJoinExec;
+pub use hash_join::{
+    CollectLeftAccumulator, ColumnBounds, HashJoinExec, MinMaxLeftAccumulator,
+};
 pub use nested_loop_join::NestedLoopJoinExec;
 use parking_lot::Mutex;
 // Note: SortMergeJoin is not used in plans yet
diff --git a/datafusion/pruning/src/lib.rs b/datafusion/pruning/src/lib.rs
index cec4fab2262f8..8cb91f01faffd 100644
--- a/datafusion/pruning/src/lib.rs
+++ b/datafusion/pruning/src/lib.rs
@@ -20,6 +20,7 @@ mod pruning_predicate;
 
 pub use file_pruner::FilePruner;
 pub use pruning_predicate::{
-    build_pruning_predicate, PredicateRewriter, PruningPredicate, PruningStatistics,
-    RequiredColumns, UnhandledPredicateHook,
+    build_pruning_predicate, build_statistics_record_batch, BoolVecBuilder,
+    PredicateRewriter, PruningPredicate, PruningStatistics, RequiredColumns,
+    StatisticsType, UnhandledPredicateHook,
 };
diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs
index 5e92dbe227fdd..3decf07d05980 100644
--- a/datafusion/pruning/src/pruning_predicate.rs
+++ b/datafusion/pruning/src/pruning_predicate.rs
@@ -612,7 +612,7 @@ impl PruningPredicate {
 
 /// Builds the return `Vec<bool>` for [`PruningPredicate::prune`].
 #[derive(Debug)]
-struct BoolVecBuilder {
+pub struct BoolVecBuilder {
     /// One element per container. Each element is
     /// * `true`: if the container has rows that may pass the predicate
     /// * `false`: if the container has rows that DEFINITELY DO NOT pass the predicate
     inner: Vec<bool>,
 }
 
 impl BoolVecBuilder {
@@ -621,7 +621,7 @@ impl BoolVecBuilder {
     /// Create a new `BoolVecBuilder` with `num_containers` elements
-    fn new(num_containers: usize) -> Self {
+    pub fn new(num_containers: usize) -> Self {
         Self {
             // assume by default all containers may pass the predicate
            inner: vec![true; num_containers],
@@ -652,7 +652,7 @@ impl BoolVecBuilder {
     ///
     /// # Panics
     /// If `value` is not boolean
-    fn combine_value(&mut self, value: ColumnarValue) {
+    pub fn combine_value(&mut self, value: ColumnarValue) {
         match value {
             ColumnarValue::Array(array) => {
                 self.combine_array(array.as_boolean());
@@ -669,12 +669,12 @@ impl BoolVecBuilder {
     }
 
     /// Convert this builder into a Vec of bools
-    fn build(self) -> Vec<bool> {
+    pub fn build(self) -> Vec<bool> {
         self.inner
     }
 
     /// Check all containers has rows that DEFINITELY DO NOT pass the predicate
-    fn check_all_pruned(&self) -> bool {
+    pub fn check_all_pruned(&self) -> bool {
         self.inner.iter().all(|&x| !x)
     }
 }
@@ -899,7 +899,7 @@ impl From<Vec<(phys_expr::Column, StatisticsType, Field)>> for RequiredColumns {
 /// -------+--------
 /// 5      | 1000
 /// ```
-fn build_statistics_record_batch<S: PruningStatistics>(
+pub fn build_statistics_record_batch<S: PruningStatistics>(
     statistics: &S,
     required_columns: &RequiredColumns,
 ) -> Result<RecordBatch> {
@@ -1854,7 +1854,7 @@ fn wrap_null_count_check_expr(
 }
 
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
-pub(crate) enum StatisticsType {
+pub enum StatisticsType {
     Min,
     Max,
     NullCount,
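
The `CollectLeftAccumulator` and `ColumnBounds` traits exported above make the build-side collection strategy pluggable. As an illustration of the extension point (not part of this patch), here is a minimal sketch of a custom strategy that tracks only a lower bound and pushes `probe_key >= min(build_key)` to the probe side. The names `LowerBound` and `LowerBoundAccumulator` are hypothetical; the sketch reuses only the `MinAccumulator`/`Accumulator` APIs and expression constructors that appear in the diff, and omits the dictionary-type unwrapping that `MinMaxLeftAccumulator` performs.

use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{Accumulator, Operator};
use datafusion_functions_aggregate::min_max::MinAccumulator;
use datafusion_physical_expr::expressions::{lit, BinaryExpr};
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_plan::joins::{CollectLeftAccumulator, ColumnBounds};

/// Hypothetical bounds that constrain the probe side from below only.
#[derive(Debug)]
struct LowerBound {
    min: ScalarValue,
}

impl ColumnBounds for LowerBound {
    fn physical_expr(
        &self,
        expr: Arc<dyn PhysicalExpr>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        // Build the predicate `expr >= min`
        Ok(Arc::new(BinaryExpr::new(
            expr,
            Operator::GtEq,
            lit(self.min.clone()),
        )))
    }
}

/// Hypothetical build-side accumulator that produces a `LowerBound`.
struct LowerBoundAccumulator {
    expr: Arc<dyn PhysicalExpr>,
    min: MinAccumulator,
}

impl CollectLeftAccumulator for LowerBoundAccumulator {
    fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> Result<Self> {
        // Dictionary unwrapping (see `MinMaxLeftAccumulator`) is omitted for brevity
        let data_type = expr.data_type(schema)?;
        Ok(Self {
            min: MinAccumulator::try_new(&data_type)?,
            expr,
        })
    }

    fn update_batch(&mut self, batch: &RecordBatch) -> Result<()> {
        // Evaluate the join key on this batch and fold it into the running minimum
        let array = self.expr.evaluate(batch)?.into_array(batch.num_rows())?;
        self.min.update_batch(std::slice::from_ref(&array))
    }

    fn evaluate(mut self) -> Result<Arc<dyn ColumnBounds>> {
        Ok(Arc::new(LowerBound {
            min: self.min.evaluate()?,
        }))
    }
}

A plan would opt in with `HashJoinExec::<LowerBoundAccumulator>::try_new_with_accumulator(...)`, while the existing `HashJoinExec::try_new` keeps the `MinMaxLeftAccumulator` default, so current callers are unchanged.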