diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
index 647ad680674b0..834b0a97a47b0 100644
--- a/datafusion/sql/src/unparser/dialect.rs
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -207,6 +207,13 @@ pub trait Dialect: Send + Sync {
         Ok(None)
     }
 
+    /// Whether the dialect supports the QUALIFY clause (defaults to `true`).
+    ///
+    /// Dialects without QUALIFY (e.g. Postgres) get a filtered derived table instead.
+    fn supports_qualify(&self) -> bool {
+        true
+    }
+
     /// Allows the dialect to override logic of formatting datetime with tz into string.
     fn timestamp_with_tz_to_string(&self, dt: DateTime<Tz>, _unit: TimeUnit) -> String {
         dt.to_string()
@@ -274,6 +281,14 @@ impl Dialect for DefaultDialect {
 pub struct PostgreSqlDialect {}
 
 impl Dialect for PostgreSqlDialect {
+    fn supports_qualify(&self) -> bool {
+        false
+    }
+
+    fn requires_derived_table_alias(&self) -> bool {
+        true
+    }
+
     fn identifier_quote_style(&self, _: &str) -> Option<char> {
         Some('"')
     }
@@ -424,6 +439,10 @@ impl Dialect for DuckDBDialect {
 pub struct MySqlDialect {}
 
 impl Dialect for MySqlDialect {
+    fn supports_qualify(&self) -> bool {
+        false
+    }
+
     fn identifier_quote_style(&self, _: &str) -> Option<char> {
         Some('`')
     }
@@ -485,6 +504,10 @@ impl Dialect for MySqlDialect {
 pub struct SqliteDialect {}
 
 impl Dialect for SqliteDialect {
+    fn supports_qualify(&self) -> bool {
+        false
+    }
+
     fn identifier_quote_style(&self, _: &str) -> Option<char> {
         Some('`')
     }
diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs
index b6c65614995a9..e7535338b7677 100644
--- a/datafusion/sql/src/unparser/plan.rs
+++ b/datafusion/sql/src/unparser/plan.rs
@@ -32,11 +32,11 @@ use super::{
     },
     Unparser,
 };
-use crate::unparser::ast::UnnestRelationBuilder;
 use crate::unparser::extension_unparser::{
     UnparseToStatementResult, UnparseWithinStatementResult,
 };
 use crate::unparser::utils::{find_unnest_node_until_relation, unproject_agg_exprs};
+use crate::unparser::{ast::UnnestRelationBuilder, rewrite::rewrite_qualify};
 use crate::utils::UNNEST_PLACEHOLDER;
 use datafusion_common::{
     internal_err, not_impl_err,
@@ -95,7 +95,10 @@ pub fn plan_to_sql(plan: &LogicalPlan) -> Result<ast::Statement> {
 
 impl Unparser<'_> {
     pub fn plan_to_sql(&self, plan: &LogicalPlan) -> Result<ast::Statement> {
-        let plan = normalize_union_schema(plan)?;
+        let mut plan = normalize_union_schema(plan)?;
+        if !self.dialect.supports_qualify() {
+            plan = rewrite_qualify(plan)?;
+        }
 
         match plan {
             LogicalPlan::Projection(_)
@@ -428,6 +431,18 @@ impl Unparser<'_> {
                         unproject_agg_exprs(filter.predicate.clone(), agg, None)?;
                     let filter_expr = self.expr_to_sql(&unprojected)?;
                     select.having(Some(filter_expr));
+                } else if let (Some(window), true) = (
+                    find_window_nodes_within_select(
+                        plan,
+                        None,
+                        select.already_projected(),
+                    ),
+                    self.dialect.supports_qualify(),
+                ) {
+                    let unprojected =
+                        unproject_window_exprs(filter.predicate.clone(), &window)?;
+                    let filter_expr = self.expr_to_sql(&unprojected)?;
+                    select.qualify(Some(filter_expr));
                 } else {
                     let filter_expr = self.expr_to_sql(&filter.predicate)?;
                     select.selection(Some(filter_expr));
diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs
index aa480cf4fff92..c961f1d6f1f0c 100644
--- a/datafusion/sql/src/unparser/rewrite.rs
+++ b/datafusion/sql/src/unparser/rewrite.rs
@@ -100,6 +100,72 @@ fn rewrite_sort_expr_for_union(exprs: Vec<SortExpr>) -> Result<Vec<SortExpr>> {
     Ok(sort_exprs)
 }
 
+/// Rewrite Filter plans that have a Window as their input by inserting a SubqueryAlias.
+///
+/// When a Filter directly operates on a Window plan, it can cause issues during SQL unparsing
+/// because window functions in a WHERE clause are not valid SQL. The solution is to wrap
+/// the Window plan in a SubqueryAlias, effectively creating a derived table.
+///
+/// Example transformation:
+///
+/// Filter: condition
+///   Window: window_function
+///     TableScan: table
+///
+/// becomes:
+///
+/// Filter: condition
+///   SubqueryAlias: table (or `__qualify_subquery` if no column has a qualifier)
+///     Projection: table.column1, table.column2
+///       Window: window_function
+///         TableScan: table
+///
+pub(super) fn rewrite_qualify(plan: LogicalPlan) -> Result<LogicalPlan> {
+    let transformed_plan = plan.transform_up(|plan| match plan {
+        // Only a Filter whose direct input is a Window node needs the rewrite
+        LogicalPlan::Filter(mut filter) => {
+            if matches!(&*filter.input, LogicalPlan::Window(_)) {
+                // Alias: first column qualifier found in the schema, else a fixed fallback
+                let qualifier = filter
+                    .input
+                    .schema()
+                    .iter()
+                    .find_map(|(q, _)| q)
+                    .map(|q| q.to_string())
+                    .unwrap_or_else(|| "__qualify_subquery".to_string());
+
+                // Postgres names the output column of `rank() over (...)` just `rank`,
+                // while DataFusion names it `rank() over (...)`. Without a projection that
+                // aliases every column to its DataFusion name, the SQL is invalid in Postgres.
+
+                let project_exprs = filter
+                    .input
+                    .schema()
+                    .iter()
+                    .map(|(_, f)| datafusion_expr::col(f.name()).alias(f.name()))
+                    .collect::<Vec<_>>();
+
+                let input =
+                    datafusion_expr::LogicalPlanBuilder::from(Arc::clone(&filter.input))
+                        .project(project_exprs)?
+                        .build()?;
+
+                let subquery_alias =
+                    datafusion_expr::SubqueryAlias::try_new(Arc::new(input), qualifier)?;
+
+                filter.input = Arc::new(LogicalPlan::SubqueryAlias(subquery_alias));
+                Ok(Transformed::yes(LogicalPlan::Filter(filter)))
+            } else {
+                Ok(Transformed::no(LogicalPlan::Filter(filter)))
+            }
+        }
+
+        _ => Ok(Transformed::no(plan)),
+    });
+
+    transformed_plan.data()
+}
+
 /// Rewrite logic plan for query that order by columns are not in projections
 /// Plan before rewrite:
 ///
diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs
index 7aa982dcf3dd9..5f76afb763cff 100644
--- a/datafusion/sql/tests/cases/plan_to_sql.rs
+++ b/datafusion/sql/tests/cases/plan_to_sql.rs
@@ -21,12 +21,14 @@ use datafusion_common::{
     assert_contains, Column, DFSchema, DFSchemaRef, DataFusionError, Result,
     TableReference,
 };
+use datafusion_expr::expr::{WindowFunction, WindowFunctionParams};
 use datafusion_expr::test::function_stub::{
     count_udaf, max_udaf, min_udaf, sum, sum_udaf,
 };
 use datafusion_expr::{
     cast, col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan,
     LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
+    WindowFrame, WindowFunctionDefinition,
 };
 use datafusion_functions::unicode;
 use datafusion_functions_aggregate::grouping::grouping_udaf;
@@ -2521,6 +2523,90 @@ fn test_unparse_left_semi_join_with_table_scan_projection() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn test_unparse_window() -> Result<()> {
+    // Plan built below, sketched roughly (window column name abbreviated as <rank>):
+    // Projection: test.k, test.v, <rank>
+    //   Filter: <rank> = 1
+    //     WindowAggr: rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+    //       TableScan: test projection=[k, v]
+
+    let schema = Schema::new(vec![
+        Field::new("k", DataType::Int32, false),
+        Field::new("v", DataType::Int32, false),
+    ]);
+    let window_expr = Expr::WindowFunction(Box::new(WindowFunction {
+        fun: WindowFunctionDefinition::WindowUDF(rank_udwf()),
+        params: WindowFunctionParams {
+            args: vec![],
+            partition_by: vec![col("k")],
+            order_by: vec![col("v").sort(true, true)],
+            window_frame: WindowFrame::new(None),
+            null_treatment: None,
+            distinct: false,
+            filter: None,
+        },
+    }));
+    let table = table_scan(Some("test"), &schema, Some(vec![0, 1]))?.build()?;
+    let plan = LogicalPlanBuilder::window_plan(table, vec![window_expr.clone()])?;
+
+    let name = plan.schema().fields().last().unwrap().name().clone();
+    let plan = LogicalPlanBuilder::from(plan)
+        .filter(col(name.clone()).eq(lit(1i64)))?
+        .project(vec![col("k"), col("v"), col(name)])?
+        .build()?;
+
+    let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+    let sql = unparser.plan_to_sql(&plan)?;
+    assert_snapshot!(
+        sql,
+        @r#"SELECT "test"."k", "test"."v", "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM (SELECT "test"."k" AS "k", "test"."v" AS "v", rank() OVER (PARTITION BY "test"."k" ORDER BY "test"."v" ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM "test") AS "test" WHERE ("rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" = 1)"#
+    );
+
+    let unparser = Unparser::new(&UnparserMySqlDialect {});
+    let sql = unparser.plan_to_sql(&plan)?;
+    assert_snapshot!(
+        sql,
+        @r#"SELECT `test`.`k`, `test`.`v`, `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM (SELECT `test`.`k` AS `k`, `test`.`v` AS `v`, rank() OVER (PARTITION BY `test`.`k` ORDER BY `test`.`v` ASC ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM `test`) AS `test` WHERE (`rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` = 1)"#
+    );
+
+    let unparser = Unparser::new(&SqliteDialect {});
+    let sql = unparser.plan_to_sql(&plan)?;
+    assert_snapshot!(
+        sql,
+        @r#"SELECT `test`.`k`, `test`.`v`, `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM (SELECT `test`.`k` AS `k`, `test`.`v` AS `v`, rank() OVER (PARTITION BY `test`.`k` ORDER BY `test`.`v` ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS `rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` FROM `test`) AS `test` WHERE (`rank() PARTITION BY [test.k] ORDER BY [test.v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING` = 1)"#
+    );
+
+    let unparser = Unparser::new(&DefaultDialect {});
+    let sql = unparser.plan_to_sql(&plan)?;
+    assert_snapshot!(
+        sql,
+        @r#"SELECT test.k, test.v, rank() OVER (PARTITION BY test.k ORDER BY test.v ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test QUALIFY (rank() OVER (PARTITION BY test.k ORDER BY test.v ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) = 1)"#
+    );
+
+    // No column is table-qualified here (aliasing projection), so the expected alias is `__qualify_subquery`
+    let table = table_scan(Some("test"), &schema, Some(vec![0, 1]))?.build()?;
+    let table = LogicalPlanBuilder::from(table)
+        .project(vec![col("k").alias("k"), col("v").alias("v")])?
+        .build()?;
+    let plan = LogicalPlanBuilder::window_plan(table, vec![window_expr])?;
+
+    let name = plan.schema().fields().last().unwrap().name().clone();
+    let plan = LogicalPlanBuilder::from(plan)
+        .filter(col(name.clone()).eq(lit(1i64)))?
+        .project(vec![col("k"), col("v"), col(name)])?
+        .build()?;
+
+    let unparser = Unparser::new(&UnparserPostgreSqlDialect {});
+    let sql = unparser.plan_to_sql(&plan)?;
+    assert_snapshot!(
+        sql,
+        @r#"SELECT "k", "v", "rank() PARTITION BY [k] ORDER BY [v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM (SELECT "k" AS "k", "v" AS "v", rank() OVER (PARTITION BY "k" ORDER BY "v" ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS "rank() PARTITION BY [k] ORDER BY [v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" FROM (SELECT "test"."k" AS "k", "test"."v" AS "v" FROM "test") AS "derived_projection") AS "__qualify_subquery" WHERE ("rank() PARTITION BY [k] ORDER BY [v ASC NULLS FIRST] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" = 1)"#
+    );
+
+    Ok(())
+}
+
 #[test]
 fn test_like_filter() {
     let statement = generate_round_trip_statement(