Commit bb3e0d8

Use Expr::qualified_name() and Column::new() to extract partition keys from window and aggregate operators (#17757)
* Use `Expr::qualified_name()` and `Column::new()` to extract partition keys

  Using `Expr::schema_name()` and `Column::from_qualified_name()` could incorrectly parse the column name.

* Use `Expr::qualified_name()` to extract group by keys

* Retrain dataframe tests with filters and aggregates
1 parent 05426bc commit bb3e0d8
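
To make the failure mode in the commit message concrete, here is a minimal sketch (not code from this commit; it assumes only the public `datafusion_common` API and its documented fallback of treating an unparsable flat name as a bare column name). A column literally named `$a` under table `test` cannot be recovered from the flat string `test.$a`, because `$a` is not a valid unquoted SQL identifier; building the `Column` from the already-split parts avoids parsing entirely.

use datafusion_common::Column;

fn main() {
    // Parsing the flat string cannot split "test.$a" back into
    // (relation, name); the whole string ends up as the column name.
    let parsed = Column::from_qualified_name("test.$a");

    // Building from pre-split parts preserves the qualifier exactly.
    let built = Column::new(Some("test"), "$a");

    // The mismatch is why group-by / partition keys with special
    // identifiers failed to line up with filter columns.
    assert_ne!(parsed, built);
}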

File tree

2 files changed: +83 −18 lines


datafusion/core/tests/dataframe/mod.rs

Lines changed: 15 additions & 15 deletions
@@ -667,12 +667,12 @@ async fn test_aggregate_with_pk2() -> Result<()> {
     let df = df.filter(predicate)?;
     assert_snapshot!(
         physical_plan_to_string(&df).await,
-        @r###"
-    CoalesceBatchesExec: target_batch_size=8192
-      FilterExec: id@0 = 1 AND name@1 = a
-        AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]
+        @r"
+    AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[], ordering_mode=Sorted
+      CoalesceBatchesExec: target_batch_size=8192
+        FilterExec: id@0 = 1 AND name@1 = a
           DataSourceExec: partitions=1, partition_sizes=[1]
-    "###
+    "
     );

     // Since id and name are functionally dependant, we can use name among expression
@@ -716,12 +716,12 @@ async fn test_aggregate_with_pk3() -> Result<()> {
     let df = df.select(vec![col("id"), col("name")])?;
     assert_snapshot!(
         physical_plan_to_string(&df).await,
-        @r###"
-    CoalesceBatchesExec: target_batch_size=8192
-      FilterExec: id@0 = 1
-        AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[]
+        @r"
+    AggregateExec: mode=Single, gby=[id@0 as id, name@1 as name], aggr=[], ordering_mode=PartiallySorted([0])
+      CoalesceBatchesExec: target_batch_size=8192
+        FilterExec: id@0 = 1
           DataSourceExec: partitions=1, partition_sizes=[1]
-    "###
+    "
     );

     // Since id and name are functionally dependant, we can use name among expression
@@ -767,12 +767,12 @@ async fn test_aggregate_with_pk4() -> Result<()> {
     // columns are not used.
     assert_snapshot!(
         physical_plan_to_string(&df).await,
-        @r###"
-    CoalesceBatchesExec: target_batch_size=8192
-      FilterExec: id@0 = 1
-        AggregateExec: mode=Single, gby=[id@0 as id], aggr=[]
+        @r"
+    AggregateExec: mode=Single, gby=[id@0 as id], aggr=[], ordering_mode=Sorted
+      CoalesceBatchesExec: target_batch_size=8192
+        FilterExec: id@0 = 1
           DataSourceExec: partitions=1, partition_sizes=[1]
-    "###
+    "
     );

     let df_results = df.collect().await?;

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 68 additions & 3 deletions
@@ -978,8 +978,11 @@ impl OptimizerRule for PushDownFilter {
         let group_expr_columns = agg
             .group_expr
             .iter()
-            .map(|e| Ok(Column::from_qualified_name(e.schema_name().to_string())))
-            .collect::<Result<HashSet<_>>>()?;
+            .map(|e| {
+                let (relation, name) = e.qualified_name();
+                Column::new(relation, name)
+            })
+            .collect::<HashSet<_>>();

         let predicates = split_conjunction_owned(filter.predicate);

@@ -1047,7 +1050,10 @@
                 func.params
                     .partition_by
                     .iter()
-                    .map(|c| Column::from_qualified_name(c.schema_name().to_string()))
+                    .map(|c| {
+                        let (relation, name) = c.qualified_name();
+                        Column::new(relation, name)
+                    })
                     .collect::<HashSet<_>>()
             };
             let potential_partition_keys = window
@@ -1567,6 +1573,30 @@ mod tests {
         )
     }

+    /// verifies that filters with unusual column names are pushed down through aggregate operators
+    #[test]
+    fn filter_move_agg_special() -> Result<()> {
+        let schema = Schema::new(vec![
+            Field::new("$a", DataType::UInt32, false),
+            Field::new("$b", DataType::UInt32, false),
+            Field::new("$c", DataType::UInt32, false),
+        ]);
+        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![col("$a")], vec![sum(col("$b")).alias("total_salary")])?
+            .filter(col("$a").gt(lit(10i64)))?
+            .build()?;
+        // filter of key aggregation is commutative
+        assert_optimized_plan_equal!(
+            plan,
+            @r"
+            Aggregate: groupBy=[[test.$a]], aggr=[[sum(test.$b) AS total_salary]]
+              TableScan: test, full_filters=[test.$a > Int64(10)]
+            "
+        )
+    }
+
     #[test]
     fn filter_complex_group_by() -> Result<()> {
         let table_scan = test_table_scan()?;
@@ -1647,6 +1677,41 @@
         )
     }

+    /// verifies that filters with unusual identifier names are pushed down through window functions
+    #[test]
+    fn filter_window_special_identifier() -> Result<()> {
+        let schema = Schema::new(vec![
+            Field::new("$a", DataType::UInt32, false),
+            Field::new("$b", DataType::UInt32, false),
+            Field::new("$c", DataType::UInt32, false),
+        ]);
+        let table_scan = table_scan(Some("test"), &schema, None)?.build()?;
+
+        let window = Expr::from(WindowFunction::new(
+            WindowFunctionDefinition::WindowUDF(
+                datafusion_functions_window::rank::rank_udwf(),
+            ),
+            vec![],
+        ))
+        .partition_by(vec![col("$a"), col("$b")])
+        .order_by(vec![col("$c").sort(true, true)])
+        .build()
+        .unwrap();
+
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .window(vec![window])?
+            .filter(col("$b").gt(lit(10i64)))?
+            .build()?;
+
+        assert_optimized_plan_equal!(
+            plan,
+            @r"
+            WindowAggr: windowExpr=[[rank() PARTITION BY [test.$a, test.$b] ORDER BY [test.$c ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
+              TableScan: test, full_filters=[test.$b > Int64(10)]
+            "
+        )
+    }
+
     /// verifies that when partitioning by 'a' and 'b', and filtering by 'a' and 'b', both 'a' and
     /// 'b' are pushed
     #[test]

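As a usage note, here is a minimal sketch of the extraction pattern the optimizer hunks above switch to (the `partition_keys` helper name is illustrative, not part of DataFusion): split each expression into its `(relation, name)` parts with `Expr::qualified_name()` and rebuild the `Column` structurally, so identifiers such as `$a` pass through without any string parsing.

use std::collections::HashSet;

use datafusion_common::Column;
use datafusion_expr::Expr;

// Hypothetical helper mirroring the commit's approach: a structural
// split-and-rebuild instead of round-tripping through a flat string.
fn partition_keys(exprs: &[Expr]) -> HashSet<Column> {
    exprs
        .iter()
        .map(|e| {
            let (relation, name) = e.qualified_name();
            Column::new(relation, name)
        })
        .collect()
}

fn main() {
    // A group-by key over a column literally named "$a" in table "test".
    let key = Expr::Column(Column::new(Some("test"), "$a"));
    let keys = partition_keys(std::slice::from_ref(&key));
    // The rebuilt key matches the original column exactly, qualifier and all.
    assert!(keys.contains(&Column::new(Some("test"), "$a")));
}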