diff --git a/datafusion/sql/src/unparser/ast.rs b/datafusion/sql/src/unparser/ast.rs index a92b1f1e35476..188eec2f8a139 100644 --- a/datafusion/sql/src/unparser/ast.rs +++ b/datafusion/sql/src/unparser/ast.rs @@ -398,16 +398,18 @@ impl Default for TableWithJoinsBuilder { #[derive(Clone)] pub struct RelationBuilder { - relation: Option, + pub relation: Option, } #[allow(dead_code)] #[derive(Clone)] #[allow(clippy::large_enum_variant)] -enum TableFactorBuilder { +pub enum TableFactorBuilder { Table(TableRelationBuilder), Derived(DerivedRelationBuilder), Unnest(UnnestRelationBuilder), + Function(FunctionRelationBuilder), + TableFunction(TableFunctionRelationBuilder), Empty, } @@ -430,6 +432,16 @@ impl RelationBuilder { self } + pub fn function(&mut self, value: FunctionRelationBuilder) -> &mut Self { + self.relation = Some(TableFactorBuilder::Function(value)); + self + } + + pub fn table_function(&mut self, value: TableFunctionRelationBuilder) -> &mut Self { + self.relation = Some(TableFactorBuilder::TableFunction(value)); + self + } + pub fn empty(&mut self) -> &mut Self { self.relation = Some(TableFactorBuilder::Empty); self @@ -446,6 +458,12 @@ impl RelationBuilder { Some(TableFactorBuilder::Unnest(ref mut rel_builder)) => { rel_builder.alias = value; } + Some(TableFactorBuilder::Function(ref mut rel_builder)) => { + rel_builder.alias = value; + } + Some(TableFactorBuilder::TableFunction(ref mut rel_builder)) => { + rel_builder.alias = value; + } Some(TableFactorBuilder::Empty) => (), None => (), } @@ -456,6 +474,8 @@ impl RelationBuilder { Some(TableFactorBuilder::Table(ref value)) => Some(value.build()?), Some(TableFactorBuilder::Derived(ref value)) => Some(value.build()?), Some(TableFactorBuilder::Unnest(ref value)) => Some(value.build()?), + Some(TableFactorBuilder::Function(ref value)) => Some(value.build()?), + Some(TableFactorBuilder::TableFunction(ref value)) => Some(value.build()?), Some(TableFactorBuilder::Empty) => None, None => return Err(Into::into(UninitializedFieldError::from("relation"))), }) @@ -662,6 +682,96 @@ impl Default for UnnestRelationBuilder { } } +#[derive(Clone)] +pub struct FunctionRelationBuilder { + lateral: bool, + name: ast::ObjectName, + args: Vec, + alias: Option, +} + +#[allow(dead_code)] +impl FunctionRelationBuilder { + pub fn lateral(&mut self, value: bool) -> &mut Self { + self.lateral = value; + self + } + + pub fn name(&mut self, value: ast::ObjectName) -> &mut Self { + self.name = value; + self + } + + pub fn args(&mut self, value: Vec) -> &mut Self { + self.args = value; + self + } + + pub fn alias(&mut self, value: Option) -> &mut Self { + self.alias = value; + self + } + + pub fn build(&self) -> Result { + Ok(ast::TableFactor::Function { + lateral: self.lateral, + name: self.name.clone(), + args: self.args.clone(), + alias: self.alias.clone(), + }) + } + + fn create_empty() -> Self { + Self { + lateral: Default::default(), + name: ast::ObjectName(vec![]), + args: Default::default(), + alias: Default::default(), + } + } +} + +#[derive(Clone)] +pub struct TableFunctionRelationBuilder { + pub expr: Option, + pub alias: Option, +} + +impl TableFunctionRelationBuilder { + pub fn expr(&mut self, value: ast::Expr) -> &mut Self { + self.expr = Some(value); + self + } + + pub fn alias(&mut self, value: Option) -> &mut Self { + self.alias = value; + self + } + + pub fn build(&self) -> Result { + Ok(ast::TableFactor::TableFunction { + expr: match self.expr { + Some(ref value) => value.clone(), + None => return Err(Into::into(UninitializedFieldError::from("expr"))), + }, + alias: self.alias.clone(), + }) + } + + fn create_empty() -> Self { + Self { + expr: Default::default(), + alias: Default::default(), + } + } +} + +impl Default for TableFunctionRelationBuilder { + fn default() -> Self { + Self::create_empty() + } +} + /// Runtime error when a `build()` method is called and one or more required fields /// do not have a value. #[derive(Debug, Clone)] diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index d47077376beea..dd7de11595609 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -17,14 +17,19 @@ use std::{collections::HashMap, sync::Arc}; +use crate::unparser::ast::{ + RelationBuilder, TableFactorBuilder, TableFunctionRelationBuilder, +}; + use super::{ utils::character_length_to_sql, utils::date_part_to_sql, utils::sqlite_date_trunc_to_sql, utils::sqlite_from_unixtime_to_sql, Unparser, }; use arrow::datatypes::TimeUnit; -use datafusion_common::Result; -use datafusion_expr::Expr; +use datafusion_common::{plan_err, Result}; +use datafusion_expr::{Expr, LogicalPlan, Unnest}; use regex::Regex; +use sqlparser::ast::ValueWithSpan; use sqlparser::tokenizer::Span; use sqlparser::{ ast::{ @@ -198,6 +203,35 @@ pub trait Dialect: Send + Sync { false } + /// Allow the dialect implement to unparse the unnest plan as the dialect-specific flattened + /// array table factor. + /// + /// Some dialects like Snowflake require FLATTEN function to unnest arrays in the FROM clause. + /// + fn unparse_unnest_table_factor( + &self, + _unnest: &Unnest, + _columns: &[Ident], + _unparser: &Unparser, + ) -> Result> { + Ok(None) + } + + /// Allows the dialect to override relation alias unparsing if the dialect has specific rules. + /// Returns true if the dialect has overridden the alias unparsing, false to use default unparsing. + /// + /// This is useful for dialects that need to modify the alias based on specific conditions. For example, + /// in Snowflake, when using the FLATTEN function, the alias of the derived table needs to be adjusted + /// to match the output columns of the FLATTEN function. It can be used with [`unparse_unnest_table_factor`] to achieve this. + /// See [`SnowflakeDialect`] implementation for an example. + fn relation_alias_overrides( + &self, + _relation_builder: &mut RelationBuilder, + _alias: Option<&ast::TableAlias>, + ) -> bool { + false + } + /// Allows the dialect to override column alias unparsing if the dialect has specific rules. /// Returns None if the default unparsing should be used, or Some(String) if there is /// a custom implementation for the alias. @@ -576,6 +610,161 @@ impl Dialect for MsSqlDialect { } } +pub static UNNAMED_SNOWFLAKE_FLATTEN_SUBQUERY_PREFIX: &str = "__unnamed_flatten_subquery"; + +#[derive(Default)] +pub struct SnowflakeDialect {} + +impl Dialect for SnowflakeDialect { + fn identifier_quote_style(&self, _: &str) -> Option { + Some('"') + } + + fn unnest_as_table_factor(&self) -> bool { + true + } + + fn unparse_unnest_table_factor( + &self, + unnest: &Unnest, + columns: &[Ident], + unparser: &Unparser, + ) -> Result> { + let LogicalPlan::Projection(projection) = unnest.input.as_ref() else { + return Ok(None); + }; + + if !matches!(projection.input.as_ref(), LogicalPlan::EmptyRelation(_)) { + // It may be possible that UNNEST is used as a source for the query. + // However, at this point, we don't yet know if it is just a single expression + // from another source or if it's from UNNEST. + // + // Unnest(Projection(EmptyRelation)) denotes a case with `UNNEST([...])`, + // which is normally safe to unnest as a table factor. + // However, in the future, more comprehensive checks can be added here. + return Ok(None); + }; + + let mut table_function_relation = TableFunctionRelationBuilder::default(); + let mut exprs = projection + .expr + .iter() + .map(|e| unparser.expr_to_sql(e)) + .collect::>>()?; + + if exprs.len() != 1 { + // Snowflake FLATTEN function only supports a single argument. + return plan_err!( + "Only support one argument for Snowflake FLATTEN, found {}", + exprs.len() + ); + } + + if columns.len() != 1 { + // Snowflake FLATTEN function only supports a single output column. + return plan_err!( + "Only support one output column for Snowflake FLATTEN, found {}", + columns.len() + ); + } + + exprs.extend(vec![ + ast::Expr::Value(ValueWithSpan { + value: ast::Value::SingleQuotedString("".to_string()), + span: Span::empty(), + }), + ast::Expr::Value(ValueWithSpan { + value: ast::Value::Boolean(false), + span: Span::empty(), + }), + ast::Expr::Value(ValueWithSpan { + value: ast::Value::Boolean(false), + span: Span::empty(), + }), + ast::Expr::Value(ValueWithSpan { + value: ast::Value::SingleQuotedString("ARRAY".to_string()), + span: Span::empty(), + }), + ]); + + // To get the flattened result, we need to override the output columns of the FLATTEN function. + // The 4th column corresponds to the flattened value, which we will alias to the desired output column name. + // https://docs.snowflake.com/en/sql-reference/functions/flatten#output + let column_alias = vec![ + unparser.new_ident_quoted_if_needs("SEQ".to_string()), + unparser.new_ident_quoted_if_needs("KEY".to_string()), + unparser.new_ident_quoted_if_needs("PATH".to_string()), + unparser.new_ident_quoted_if_needs("INDEX".to_string()), + columns[0].clone(), + unparser.new_ident_quoted_if_needs("THIS".to_string()), + ]; + + let func_expr = ast::Expr::Function(Function { + name: vec![Ident::new("FLATTEN")].into(), + uses_odbc_syntax: false, + parameters: ast::FunctionArguments::None, + args: ast::FunctionArguments::List(ast::FunctionArgumentList { + args: exprs + .into_iter() + .map(|e| ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))) + .collect(), + duplicate_treatment: None, + clauses: vec![], + }), + filter: None, + null_treatment: None, + over: None, + within_group: vec![], + }); + table_function_relation.expr(func_expr); + table_function_relation.alias(Some( + unparser.new_table_alias( + unparser + .alias_generator + .next(UNNAMED_SNOWFLAKE_FLATTEN_SUBQUERY_PREFIX), + column_alias, + ), + )); + Ok(Some(TableFactorBuilder::TableFunction( + table_function_relation, + ))) + } + + fn relation_alias_overrides( + &self, + relation_builder: &mut RelationBuilder, + alias: Option<&ast::TableAlias>, + ) -> bool { + // When using FLATTEN function, we need to adjust the alias of the derived table + // to match the output columns of the FLATTEN function. The 4th column corresponds + // to the flattened value, which we will alias to the desired output column name. + if let Some(TableFactorBuilder::TableFunction(rel_builder)) = + relation_builder.relation.as_mut() + { + if let Some(value) = &alias { + if let Some(alias) = rel_builder.alias.as_mut() { + if alias + .name + .value + .starts_with(UNNAMED_SNOWFLAKE_FLATTEN_SUBQUERY_PREFIX) + && value.columns.len() == 1 + { + let mut new_columns = alias.columns.clone(); + new_columns[4] = value.columns[0].clone(); + let new_alias = ast::TableAlias { + name: value.name.clone(), + columns: new_columns, + }; + rel_builder.alias = Some(new_alias); + return true; + } + } + } + } + false + } +} + pub struct CustomDialect { identifier_quote_style: Option, supports_nulls_first_in_sort: bool, @@ -783,6 +972,7 @@ pub struct CustomDialectBuilder { window_func_support_window_frame: bool, full_qualified_col: bool, unnest_as_table_factor: bool, + unnest_to_flattened_table_factor: bool, } impl Default for CustomDialectBuilder { @@ -817,6 +1007,7 @@ impl CustomDialectBuilder { window_func_support_window_frame: true, full_qualified_col: false, unnest_as_table_factor: false, + unnest_to_flattened_table_factor: false, } } @@ -983,4 +1174,12 @@ impl CustomDialectBuilder { self.unnest_as_table_factor = unnest_as_table_factor; self } + + pub fn with_unnest_to_flattened_table_factor( + mut self, + unnest_to_flattened_table_factor: bool, + ) -> Self { + self.unnest_to_flattened_table_factor = unnest_to_flattened_table_factor; + self + } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index f6ea0243982d2..f7011ec65be39 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -813,8 +813,7 @@ impl Unparser<'_> { .collect::>>() } - /// This function can create an identifier with or without quotes based on the dialect rules - pub(super) fn new_ident_quoted_if_needs(&self, ident: String) -> Ident { + pub fn new_ident_quoted_if_needs(&self, ident: String) -> Ident { let quote_style = self.dialect.identifier_quote_style(&ident); Ident { value: ident, diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index 05b472dc92a93..ebcae0a1a868e 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -25,6 +25,7 @@ mod utils; use self::dialect::{DefaultDialect, Dialect}; use crate::unparser::extension_unparser::UserDefinedLogicalNodeUnparser; +use datafusion_common::alias::AliasGenerator; pub use expr::expr_to_sql; pub use plan::plan_to_sql; use std::sync::Arc; @@ -58,6 +59,7 @@ pub struct Unparser<'a> { dialect: &'a dyn Dialect, pretty: bool, extension_unparsers: Vec>, + pub alias_generator: AliasGenerator, } impl<'a> Unparser<'a> { @@ -66,6 +68,7 @@ impl<'a> Unparser<'a> { dialect, pretty: false, extension_unparsers: vec![], + alias_generator: AliasGenerator::new(), } } @@ -136,6 +139,7 @@ impl Default for Unparser<'_> { dialect: &DefaultDialect {}, pretty: false, extension_unparsers: vec![], + alias_generator: AliasGenerator::new(), } } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index d86681d281e0f..350b0efb63d5c 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -22,8 +22,7 @@ use super::{ }, rewrite::{ inject_column_aliases_into_subquery, normalize_union_schema, - rewrite_plan_for_sort_on_non_projected_fields, - subquery_alias_inner_query_and_columns, TableAliasRewriter, + rewrite_plan_for_sort_on_non_projected_fields, TableAliasRewriter, }, utils::{ find_agg_node_within_select, find_unnest_node_within_select, @@ -32,7 +31,7 @@ use super::{ }, Unparser, }; -use crate::unparser::ast::UnnestRelationBuilder; +use crate::unparser::ast::{TableFactorBuilder, UnnestRelationBuilder}; use crate::unparser::extension_unparser::{ UnparseToStatementResult, UnparseWithinStatementResult, }; @@ -377,21 +376,6 @@ impl Unparser<'_> { } else { None }; - if self.dialect.unnest_as_table_factor() && unnest_input_type.is_some() { - if let LogicalPlan::Unnest(unnest) = &p.input.as_ref() { - if let Some(unnest_relation) = - self.try_unnest_to_table_factor_sql(unnest)? - { - relation.unnest(unnest_relation); - return self.select_to_sql_recursively( - p.input.as_ref(), - query, - select, - relation, - ); - } - } - } // If it's a unnest projection, we should provide the table column alias // to provide a column name for the unnest relation. @@ -405,6 +389,36 @@ impl Unparser<'_> { } else { vec![] }; + + if self.dialect.unnest_as_table_factor() && unnest_input_type.is_some() { + if let LogicalPlan::Unnest(unnest) = &p.input.as_ref() { + if let Some(table_factor) = + self.unparse_unnest_table_factor(unnest, &columns)? + { + match table_factor { + TableFactorBuilder::Unnest(unnest) => { + relation.unnest(unnest) + } + TableFactorBuilder::TableFunction(table_function) => { + relation.table_function(table_function) + } + _ => { + return internal_err!( + "Unexpected table factor type for unnest" + ); + } + }; + + return self.select_to_sql_recursively( + p.input.as_ref(), + query, + select, + relation, + ); + } + } + } + // Projection can be top-level plan for derived table if select.already_projected() { return self.derive_with_dialect_alias( @@ -814,7 +828,7 @@ impl Unparser<'_> { } LogicalPlan::SubqueryAlias(plan_alias) => { let (plan, mut columns) = - subquery_alias_inner_query_and_columns(plan_alias); + self.subquery_alias_inner_query_and_columns(plan_alias); let unparsed_table_scan = Self::unparse_table_scan_pushdown( plan, Some(plan_alias.alias.clone()), @@ -854,10 +868,16 @@ impl Unparser<'_> { self.select_to_sql_recursively(&plan, query, select, relation)?; } - relation.alias(Some( - self.new_table_alias(plan_alias.alias.table().to_string(), columns), - )); + let new_alias = + self.new_table_alias(plan_alias.alias.table().to_string(), columns); + if self + .dialect + .relation_alias_overrides(relation, Some(&new_alias)) + { + return Ok(()); + } + relation.alias(Some(new_alias)); Ok(()) } LogicalPlan::Union(union) => { @@ -1019,6 +1039,24 @@ impl Unparser<'_> { None } + fn unparse_unnest_table_factor( + &self, + unnest: &Unnest, + columns: &[Ident], + ) -> Result> { + let dialect_flatten_relation = self + .dialect + .unparse_unnest_table_factor(unnest, columns, self)?; + if dialect_flatten_relation.is_some() { + return Ok(dialect_flatten_relation); + } + + if let Some(unnest_relation) = self.try_unnest_to_table_factor_sql(unnest)? { + return Ok(Some(TableFactorBuilder::Unnest(unnest_relation))); + } + Ok(None) + } + fn try_unnest_to_table_factor_sql( &self, unnest: &Unnest, @@ -1391,7 +1429,7 @@ impl Unparser<'_> { self.binary_op_to_sql(lhs, rhs, ast::BinaryOperator::And) } - fn new_table_alias(&self, alias: String, columns: Vec) -> ast::TableAlias { + pub fn new_table_alias(&self, alias: String, columns: Vec) -> ast::TableAlias { let columns = columns .into_iter() .map(|ident| TableAliasColumnDef { diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index aa480cf4fff92..0160173beba1e 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -27,6 +27,8 @@ use datafusion_expr::expr::{Alias, UNNEST_COLUMN_PREFIX}; use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr}; use sqlparser::ast::Ident; +use crate::unparser::Unparser; + /// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions. /// /// DataFusion will return an error if two columns in the schema have the same name with no table qualifiers. @@ -190,83 +192,6 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields( } } -/// This logic is to work out the columns and inner query for SubqueryAlias plan for some types of -/// subquery or unnest -/// - `(SELECT column_a as a from table) AS A` -/// - `(SELECT column_a from table) AS A (a)` -/// - `SELECT * FROM t1 CROSS JOIN UNNEST(t1.c1) AS u(c1)` (see [find_unnest_column_alias]) -/// -/// A roundtrip example for table alias with columns -/// -/// query: SELECT id FROM (SELECT j1_id from j1) AS c (id) -/// -/// LogicPlan: -/// Projection: c.id -/// SubqueryAlias: c -/// Projection: j1.j1_id AS id -/// Projection: j1.j1_id -/// TableScan: j1 -/// -/// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS -/// id FROM (SELECT j1.j1_id FROM j1)) AS c`. -/// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table -/// `(SELECT j1.j1_id FROM j1)` -/// -/// With this logic, the unparsed query will be: -/// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` -/// -/// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)` -/// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and -/// Column in the Projections. Once the parser side is fixed, this logic should work -pub(super) fn subquery_alias_inner_query_and_columns( - subquery_alias: &datafusion_expr::SubqueryAlias, -) -> (&LogicalPlan, Vec) { - let plan: &LogicalPlan = subquery_alias.input.as_ref(); - - if let LogicalPlan::Subquery(subquery) = plan { - let (inner_projection, Some(column)) = - find_unnest_column_alias(subquery.subquery.as_ref()) - else { - return (plan, vec![]); - }; - return (inner_projection, vec![Ident::new(column)]); - } - - let LogicalPlan::Projection(outer_projections) = plan else { - return (plan, vec![]); - }; - - // Check if it's projection inside projection - let Some(inner_projection) = find_projection(outer_projections.input.as_ref()) else { - return (plan, vec![]); - }; - - let mut columns: Vec = vec![]; - // Check if the inner projection and outer projection have a matching pattern like - // Projection: j1.j1_id AS id - // Projection: j1.j1_id - for (i, inner_expr) in inner_projection.expr.iter().enumerate() { - let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else { - return (plan, vec![]); - }; - - // Inner projection schema fields store the projection name which is used in outer - // projection expr - let inner_expr_string = match inner_expr { - Expr::Column(_) => inner_expr.to_string(), - _ => inner_projection.schema.field(i).name().clone(), - }; - - if outer_alias.expr.to_string() != inner_expr_string { - return (plan, vec![]); - }; - - columns.push(outer_alias.name.as_str().into()); - } - - (outer_projections.input.as_ref(), columns) -} - /// Try to find the column alias for UNNEST in the inner projection. /// For example: /// ```sql @@ -383,6 +308,90 @@ fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> { } } +impl<'a> Unparser<'a> { + /// This logic is to work out the columns and inner query for SubqueryAlias plan for some types of + /// subquery or unnest + /// - `(SELECT column_a as a from table) AS A` + /// - `(SELECT column_a from table) AS A (a)` + /// - `SELECT * FROM t1 CROSS JOIN UNNEST(t1.c1) AS u(c1)` (see [find_unnest_column_alias]) + /// + /// A roundtrip example for table alias with columns + /// + /// query: SELECT id FROM (SELECT j1_id from j1) AS c (id) + /// + /// LogicPlan: + /// Projection: c.id + /// SubqueryAlias: c + /// Projection: j1.j1_id AS id + /// Projection: j1.j1_id + /// TableScan: j1 + /// + /// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS + /// id FROM (SELECT j1.j1_id FROM j1)) AS c`. + /// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table + /// `(SELECT j1.j1_id FROM j1)` + /// + /// With this logic, the unparsed query will be: + /// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)` + /// + /// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)` + /// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and + /// Column in the Projections. Once the parser side is fixed, this logic should work + pub(super) fn subquery_alias_inner_query_and_columns( + &'a self, + subquery_alias: &'a datafusion_expr::SubqueryAlias, + ) -> (&'a LogicalPlan, Vec) { + let plan: &LogicalPlan = subquery_alias.input.as_ref(); + + if let LogicalPlan::Subquery(subquery) = plan { + let (inner_projection, Some(column)) = + find_unnest_column_alias(subquery.subquery.as_ref()) + else { + return (plan, vec![]); + }; + return ( + inner_projection, + vec![self.new_ident_quoted_if_needs(column)], + ); + } + + let LogicalPlan::Projection(outer_projections) = plan else { + return (plan, vec![]); + }; + + // Check if it's projection inside projection + let Some(inner_projection) = find_projection(outer_projections.input.as_ref()) + else { + return (plan, vec![]); + }; + + let mut columns: Vec = vec![]; + // Check if the inner projection and outer projection have a matching pattern like + // Projection: j1.j1_id AS id + // Projection: j1.j1_id + for (i, inner_expr) in inner_projection.expr.iter().enumerate() { + let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else { + return (plan, vec![]); + }; + + // Inner projection schema fields store the projection name which is used in outer + // projection expr + let inner_expr_string = match inner_expr { + Expr::Column(_) => inner_expr.to_string(), + _ => inner_projection.schema.field(i).name().clone(), + }; + + if outer_alias.expr.to_string() != inner_expr_string { + return (plan, vec![]); + }; + + columns.push(self.new_ident_quoted_if_needs(outer_alias.name.clone())); + } + + (outer_projections.input.as_ref(), columns) + } +} + /// A `TreeNodeRewriter` implementation that rewrites `Expr::Column` expressions by /// replacing the column's name with an alias if the column exists in the provided schema. /// diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 09a50a27d7fd7..807c9412f11d5 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -36,7 +36,7 @@ use datafusion_sql::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_sql::unparser::dialect::{ BigQueryDialect, CustomDialectBuilder, DefaultDialect as UnparserDefaultDialect, DefaultDialect, Dialect as UnparserDialect, MySqlDialect as UnparserMySqlDialect, - PostgreSqlDialect as UnparserPostgreSqlDialect, SqliteDialect, + PostgreSqlDialect as UnparserPostgreSqlDialect, SnowflakeDialect, SqliteDialect, }; use datafusion_sql::unparser::{expr_to_sql, plan_to_sql, Unparser}; use insta::assert_snapshot; @@ -2517,6 +2517,57 @@ fn test_unparse_left_semi_join_with_table_scan_projection() -> Result<()> { Ok(()) } +#[test] +fn test_unparse_unnest_to_table_flatten() -> Result<()> { + let unparser_dialect = SnowflakeDialect {}; + let unparser = Unparser::new(&unparser_dialect); + + let plan = sql_to_plan("SELECT * FROM UNNEST([1,2,3])")?; + assert_snapshot!( + unparser.plan_to_sql(&plan).unwrap(), + @r#"SELECT "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))" FROM TABLE(FLATTEN([1, 2, 3], '', false, false, 'ARRAY')) AS "__unnamed_flatten_subquery_1" ("SEQ", "KEY", "PATH", "INDEX", "UNNEST(make_array(Int64(1),Int64(2),Int64(3)))", "THIS")"# + ); + + let plan = sql_to_plan("SELECT * FROM UNNEST([1,2,3]) t(a)")?; + assert_snapshot!( + unparser.plan_to_sql(&plan).unwrap(), + @r#"SELECT "t"."a" FROM TABLE(FLATTEN([1, 2, 3], '', false, false, 'ARRAY')) AS "t" ("SEQ", "KEY", "PATH", "INDEX", "a", "THIS")"# + ); + + let plan = sql_to_plan("SELECT * FROM unnest_table, UNNEST(unnest_table.array_col)")?; + assert_snapshot!( + unparser.plan_to_sql(&plan).unwrap(), + @r#"SELECT "unnest_table"."array_col", "unnest_table"."struct_col", "UNNEST(outer_ref(unnest_table.array_col))" FROM "unnest_table" CROSS JOIN TABLE(FLATTEN("unnest_table"."array_col", '', false, false, 'ARRAY')) AS "__unnamed_flatten_subquery_4" ("SEQ", "KEY", "PATH", "INDEX", "UNNEST(outer_ref(unnest_table.array_col))", "THIS")"# + ); + + let plan = + sql_to_plan("SELECT t.a FROM unnest_table, UNNEST(unnest_table.array_col) t(a)")?; + assert_snapshot!( + unparser.plan_to_sql(&plan).unwrap(), + @r#"SELECT "t"."a" FROM "unnest_table" CROSS JOIN TABLE(FLATTEN("unnest_table"."array_col", '', false, false, 'ARRAY')) AS "t" ("SEQ", "KEY", "PATH", "INDEX", "a", "THIS")"# + ); + + Ok(()) +} + +fn sql_to_plan(sql: &str) -> Result { + let dialect = GenericDialect {}; + let statement = Parser::new(&dialect).try_with_sql(sql)?.parse_statement()?; + let state = MockSessionState::default() + .with_aggregate_function(sum_udaf()) + .with_aggregate_function(max_udaf()) + .with_aggregate_function(grouping_udaf()) + .with_window_function(rank_udwf()) + .with_scalar_function(Arc::new(unicode::substr().as_ref().clone())) + .with_scalar_function(make_array_udf()) + .with_expr_planner(Arc::new(CoreFunctionPlanner::default())) + .with_expr_planner(Arc::new(NestedFunctionPlanner)) + .with_expr_planner(Arc::new(FieldAccessPlanner)); + let context = MockContextProvider { state }; + let sql_to_rel = SqlToRel::new(&context); + sql_to_rel.sql_statement_to_plan(statement) +} + #[test] fn test_like_filter() { let statement = generate_round_trip_statement(