diff --git a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs
index dd2c1960a6580..b91b0c11bd448 100644
--- a/datafusion/core/tests/physical_optimizer/limit_pushdown.rs
+++ b/datafusion/core/tests/physical_optimizer/limit_pushdown.rs
@@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::error::Result;
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
-use datafusion_expr::Operator;
+use datafusion_expr::{JoinType, Operator};
 use datafusion_physical_expr::expressions::BinaryExpr;
 use datafusion_physical_expr::expressions::{col, lit};
 use datafusion_physical_expr::{Partitioning, PhysicalSortExpr};
@@ -32,6 +32,7 @@ use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion_physical_plan::empty::EmptyExec;
 use datafusion_physical_plan::filter::FilterExec;
+use datafusion_physical_plan::joins::NestedLoopJoinExec;
 use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use datafusion_physical_plan::projection::ProjectionExec;
 use datafusion_physical_plan::repartition::RepartitionExec;
@@ -154,6 +155,16 @@ impl PartitionStream for DummyStreamPartition {
     }
 }
 
+fn nested_loop_join_exec(
+    left: Arc<dyn ExecutionPlan>,
+    right: Arc<dyn ExecutionPlan>,
+    join_type: JoinType,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    Ok(Arc::new(NestedLoopJoinExec::try_new(
+        left, right, None, &join_type, None,
+    )?))
+}
+
 #[test]
 fn transforms_streaming_table_exec_into_fetching_version_when_skip_is_zero(
 ) -> Result<()> {
@@ -486,3 +497,102 @@ fn merges_local_limit_with_global_limit() -> Result<()> {
 
     Ok(())
 }
+
+#[test]
+fn preserves_nested_global_limit() -> Result<()> {
+    // If there are multiple limits in an execution plan, they all need to be
+    // preserved in the optimized plan.
+    //
+    // Plan structure:
+    //   GlobalLimitExec: skip=1, fetch=1
+    //     NestedLoopJoinExec (Left)
+    //       EmptyExec (left side)
+    //       GlobalLimitExec: skip=2, fetch=1
+    //         NestedLoopJoinExec (Right)
+    //           EmptyExec (left side)
+    //           EmptyExec (right side)
+    let schema = create_schema();
+
+    // Build inner join: NestedLoopJoin(Empty, Empty)
+    let inner_left = empty_exec(Arc::clone(&schema));
+    let inner_right = empty_exec(Arc::clone(&schema));
+    let inner_join = nested_loop_join_exec(inner_left, inner_right, JoinType::Right)?;
+
+    // Add inner limit: GlobalLimitExec: skip=2, fetch=1
+    let inner_limit = global_limit_exec(inner_join, 2, Some(1));
+
+    // Build outer join: NestedLoopJoin(Empty, GlobalLimit)
+    let outer_left = empty_exec(Arc::clone(&schema));
+    let outer_join = nested_loop_join_exec(outer_left, inner_limit, JoinType::Left)?;
+
+    // Add outer limit: GlobalLimitExec: skip=1, fetch=1
+    let outer_limit = global_limit_exec(outer_join, 1, Some(1));
+
+    let initial = get_plan_string(&outer_limit);
+    let expected_initial = [
+        "GlobalLimitExec: skip=1, fetch=1",
+        "  NestedLoopJoinExec: join_type=Left",
+        "    EmptyExec",
+        "    GlobalLimitExec: skip=2, fetch=1",
+        "      NestedLoopJoinExec: join_type=Right",
+        "        EmptyExec",
+        "        EmptyExec",
+    ];
+    assert_eq!(initial, expected_initial);
+
+    let after_optimize =
+        LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?;
+    let expected = [
+        "GlobalLimitExec: skip=1, fetch=1",
+        "  NestedLoopJoinExec: join_type=Left",
+        "    EmptyExec",
+        "    GlobalLimitExec: skip=2, fetch=1",
+        "      NestedLoopJoinExec: join_type=Right",
+        "        EmptyExec",
+        "        EmptyExec",
+    ];
+    assert_eq!(get_plan_string(&after_optimize), expected);
+
+    Ok(())
+}
+
+#[test]
+fn preserves_skip_before_sort() -> Result<()> {
+    // If there's a limit with skip before a node that (1) supports fetch but
+    // (2) does not support limit pushdown, that limit should not be removed.
+    //
+    // Plan structure:
+    //   GlobalLimitExec: skip=1, fetch=None
+    //     SortExec: TopK(fetch=4)
+    //       EmptyExec
+    let schema = create_schema();
+
+    let empty = empty_exec(Arc::clone(&schema));
+
+    let ordering = [PhysicalSortExpr {
+        expr: col("c1", &schema)?,
+        options: SortOptions::default(),
+    }];
+    let sort = sort_exec(ordering, empty).with_fetch(Some(4)).unwrap();
+
+    let outer_limit = global_limit_exec(sort, 1, None);
+
+    let initial = get_plan_string(&outer_limit);
+    let expected_initial = [
+        "GlobalLimitExec: skip=1, fetch=None",
+        "  SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false]",
+        "    EmptyExec",
+    ];
+    assert_eq!(initial, expected_initial);
+
+    let after_optimize =
+        LimitPushdown::new().optimize(outer_limit, &ConfigOptions::new())?;
+    let expected = [
+        "GlobalLimitExec: skip=1, fetch=3",
+        "  SortExec: TopK(fetch=4), expr=[c1@0 ASC], preserve_partitioning=[false]",
+        "    EmptyExec",
+    ];
+    assert_eq!(get_plan_string(&after_optimize), expected);
+
+    Ok(())
+}
diff --git a/datafusion/physical-optimizer/src/limit_pushdown.rs b/datafusion/physical-optimizer/src/limit_pushdown.rs
index 5887cb51a727b..4795414a52789 100644
--- a/datafusion/physical-optimizer/src/limit_pushdown.rs
+++ b/datafusion/physical-optimizer/src/limit_pushdown.rs
@@ -145,6 +145,7 @@ pub fn pushdown_limit_helper(
         );
         global_state.skip = skip;
         global_state.fetch = fetch;
+        global_state.satisfied = false;
 
         // Now the global state has the most recent information, we can remove
         // the `LimitExec` plan. We will decide later if we should add it again
@@ -162,7 +163,7 @@
     // If we have a non-limit operator with fetch capability, update global
     // state as necessary:
     if pushdown_plan.fetch().is_some() {
-        if global_state.fetch.is_none() {
+        if global_state.skip == 0 {
             global_state.satisfied = true;
         }
         (global_state.skip, global_state.fetch) = combine_limit(
diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt
index 067b23ac2fb01..d64a23481ce6d 100644
--- a/datafusion/sqllogictest/test_files/limit.slt
+++ b/datafusion/sqllogictest/test_files/limit.slt
@@ -711,8 +711,8 @@ ON t1.b = t2.b
 ORDER BY t1.b desc, c desc, c2 desc
 OFFSET 3 LIMIT 2;
 ----
-3 99 82
-3 99 79
+3 98 79
+3 97 96
 
 statement ok
 drop table ordered_table;
diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt
index 918c6e2811737..20b87a3f9f169 100644
--- a/datafusion/sqllogictest/test_files/union.slt
+++ b/datafusion/sqllogictest/test_files/union.slt
@@ -513,24 +513,27 @@ physical_plan
 01)CoalescePartitionsExec: fetch=3
 02)--UnionExec
 03)----ProjectionExec: expr=[count(Int64(1))@0 as cnt]
-04)------AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
-05)--------CoalescePartitionsExec
-06)----------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
-07)------------ProjectionExec: expr=[]
-08)--------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[]
-09)----------------CoalesceBatchesExec: target_batch_size=2
-10)------------------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4
-11)--------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[]
-12)----------------------CoalesceBatchesExec: target_batch_size=2
-13)------------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434, projection=[c1@0]
-14)--------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
-15)----------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], file_type=csv, has_header=true
-16)----ProjectionExec: expr=[1 as cnt]
-17)------PlaceholderRowExec
-18)----ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt]
-19)------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted]
-20)--------ProjectionExec: expr=[1 as c1]
-21)----------PlaceholderRowExec
+04)------GlobalLimitExec: skip=0, fetch=3
+05)--------AggregateExec: mode=Final, gby=[], aggr=[count(Int64(1))]
+06)----------CoalescePartitionsExec
+07)------------AggregateExec: mode=Partial, gby=[], aggr=[count(Int64(1))]
+08)--------------ProjectionExec: expr=[]
+09)----------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[]
+10)------------------CoalesceBatchesExec: target_batch_size=2
+11)--------------------RepartitionExec: partitioning=Hash([c1@0], 4), input_partitions=4
+12)----------------------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[]
+13)------------------------CoalesceBatchesExec: target_batch_size=2
+14)--------------------------FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434, projection=[c1@0]
+15)----------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+16)------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c13], file_type=csv, has_header=true
+17)----ProjectionExec: expr=[1 as cnt]
+18)------GlobalLimitExec: skip=0, fetch=3
+19)--------PlaceholderRowExec
+20)----ProjectionExec: expr=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING@1 as cnt]
+21)------GlobalLimitExec: skip=0, fetch=3
+22)--------BoundedWindowAggExec: wdw=[lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING: Ok(Field { name: "lead(b.c1,Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Rows, start_bound: Preceding(UInt64(NULL)), end_bound: Following(UInt64(NULL)), is_causal: false }], mode=[Sorted]
+23)----------ProjectionExec: expr=[1 as c1]
+24)------------PlaceholderRowExec
 
 ########
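A note on the `global_state.satisfied = false;` line in the optimizer hunk: a limit operator that has just been absorbed into the traversal state (and removed from the plan) is by definition not yet reflected in the plan, so it must not inherit the `satisfied` flag from an outer limit that already was. The sketch below is a minimal, self-contained model of that reset; the field names mirror the diff, but the struct and method are illustrative, not DataFusion's actual types.

```rust
/// Toy model of the limit-pushdown traversal state (illustrative only).
#[derive(Debug, Default)]
struct GlobalRequirements {
    skip: usize,
    fetch: Option<usize>,
    /// True once the pending requirement is already reflected in the plan,
    /// i.e. no limit node needs to be re-inserted for it.
    satisfied: bool,
}

impl GlobalRequirements {
    /// Absorb a limit node met during the top-down walk: record its
    /// skip/fetch and drop the node, deciding later whether to re-add it.
    fn absorb_limit(&mut self, skip: usize, fetch: Option<usize>) {
        self.skip = skip;
        self.fetch = fetch;
        // The fix: a freshly absorbed limit is not yet reflected in the
        // plan. Without this reset, the inner `GlobalLimitExec: skip=2,
        // fetch=1` under the join in `preserves_nested_global_limit` was
        // removed but never re-inserted, changing query results.
        self.satisfied = false;
    }
}

fn main() {
    // State after the outer `GlobalLimitExec: skip=1, fetch=1` has been
    // handled and marked satisfied higher up in the plan:
    let mut state = GlobalRequirements {
        skip: 1,
        fetch: Some(1),
        satisfied: true,
    };
    // The walk then reaches the nested `GlobalLimitExec: skip=2, fetch=1`:
    state.absorb_limit(2, Some(1));
    // The optimizer now knows it still owes the plan a limit node, which is
    // exactly what the test asserts by expecting an unchanged plan.
    assert!(!state.satisfied);
    assert_eq!((state.skip, state.fetch), (2, Some(1)));
}
```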
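The second hunk tightens when a fetch-capable operator (such as `SortExec` running in TopK mode) may mark the requirement as satisfied: only when `skip == 0`, because such operators can absorb a fetch but cannot skip rows, so a `GlobalLimitExec` with a nonzero skip must stay above them. The arithmetic that turns `GlobalLimitExec: skip=1, fetch=None` over `SortExec: TopK(fetch=4)` into `skip=1, fetch=3` in `preserves_skip_before_sort` can be modeled as below; this `combine_limit` is a standalone sketch consistent with the expected plans, not necessarily the exact body of DataFusion's helper of the same name.

```rust
use std::cmp::min;

/// Fold a parent limit's skip/fetch into a child that already has its own
/// limit (standalone sketch; names are illustrative).
fn combine_limit(
    parent_skip: usize,
    parent_fetch: Option<usize>,
    child_skip: usize,
    child_fetch: Option<usize>,
) -> (usize, Option<usize>) {
    // The parent skips rows after the child has applied its own skip,
    // so the two skips accumulate.
    let combined_skip = child_skip.saturating_add(parent_skip);
    let combined_fetch = match (parent_fetch, child_fetch) {
        // The child yields at most `cf` rows; the parent drops
        // `parent_skip` of them, then returns at most `pf`.
        (Some(pf), Some(cf)) => Some(min(pf, cf.saturating_sub(parent_skip))),
        (Some(pf), None) => Some(pf),
        (None, Some(cf)) => Some(cf.saturating_sub(parent_skip)),
        (None, None) => None,
    };
    (combined_skip, combined_fetch)
}

fn main() {
    // `preserves_skip_before_sort`: GlobalLimitExec(skip=1, fetch=None)
    // over a TopK with fetch=4 leaves at most 4 - 1 = 3 rows to return,
    // hence the optimized plan shows `GlobalLimitExec: skip=1, fetch=3`.
    assert_eq!(combine_limit(1, None, 0, Some(4)), (1, Some(3)));
}
```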