diff --git a/datafusion-examples/examples/parse_sql_expr.rs b/datafusion-examples/examples/parse_sql_expr.rs index 6444eb68b6b2..e6dfaf6c8a82 100644 --- a/datafusion-examples/examples/parse_sql_expr.rs +++ b/datafusion-examples/examples/parse_sql_expr.rs @@ -153,5 +153,14 @@ async fn round_trip_parse_sql_expr_demo() -> Result<()> { assert_eq!(sql, round_trip_sql); + // enable pretty-unparsing. This makes the output more human-readable + // but can be problematic when passed to other SQL engines due to + // differences in precedence rules between DataFusion and target engines. + let unparser = Unparser::default().with_pretty(true); + + let pretty = "int_col < 5 OR double_col = 8"; + let pretty_round_trip_sql = unparser.expr_to_sql(&parsed_expr)?.to_string(); + assert_eq!(pretty, pretty_round_trip_sql); + Ok(()) } diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs index bd708fe52bc1..f719a33fb624 100644 --- a/datafusion-examples/examples/plan_to_sql.rs +++ b/datafusion-examples/examples/plan_to_sql.rs @@ -31,9 +31,9 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser}; /// 1. [`simple_expr_to_sql_demo`]: Create a simple expression [`Exprs`] with /// fluent API and convert to sql suitable for passing to another database /// -/// 2. [`simple_expr_to_sql_demo_no_escape`] Create a simple expression -/// [`Exprs`] with fluent API and convert to sql without escaping column names -/// more suitable for displaying to humans. +/// 2. [`simple_expr_to_pretty_sql_demo`] Create a simple expression +/// [`Exprs`] with fluent API and convert to sql without extra parentheses, +/// suitable for displaying to humans /// /// 3. 
[`simple_expr_to_sql_demo_escape_mysql_style`]" Create a simple /// expression [`Exprs`] with fluent API and convert to sql escaping column @@ -49,6 +49,7 @@ use datafusion_sql::unparser::{plan_to_sql, Unparser}; async fn main() -> Result<()> { // See how to evaluate expressions simple_expr_to_sql_demo()?; + simple_expr_to_pretty_sql_demo()?; simple_expr_to_sql_demo_escape_mysql_style()?; simple_plan_to_sql_demo().await?; round_trip_plan_to_sql_demo().await?; @@ -64,6 +65,17 @@ fn simple_expr_to_sql_demo() -> Result<()> { Ok(()) } +/// DataFusion can remove parentheses when converting an expression to SQL. +/// Note that the output is intended for humans, not for other SQL engines, +/// as differences in precedence rules can cause expressions to be parsed differently. +fn simple_expr_to_pretty_sql_demo() -> Result<()> { + let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8))); + let unparser = Unparser::default().with_pretty(true); + let sql = unparser.expr_to_sql(&expr)?.to_string(); + assert_eq!(sql, r#"a < 5 OR a = 8"#); + Ok(()) +} + /// DataFusion can convert expressions to SQL without escaping column names using /// using a custom dialect and an explicit unparser fn simple_expr_to_sql_demo_escape_mysql_style() -> Result<()> { diff --git a/datafusion/expr/src/operator.rs b/datafusion/expr/src/operator.rs index 742511822a0f..0cbf9f00821a 100644 --- a/datafusion/expr/src/operator.rs +++ b/datafusion/expr/src/operator.rs @@ -218,29 +218,23 @@ impl Operator { } /// Get the operator precedence - /// use as a reference + /// use as a reference pub fn precedence(&self) -> u8 { match self { Operator::Or => 5, Operator::And => 10, - Operator::NotEq - | Operator::Eq - | Operator::Lt - | Operator::LtEq - | Operator::Gt - | Operator::GtEq => 20, - Operator::Plus | Operator::Minus => 30, - Operator::Multiply | Operator::Divide | Operator::Modulo => 40, + Operator::Eq | Operator::NotEq | Operator::LtEq | Operator::GtEq => 15, + Operator::Lt | Operator::Gt => 20, + 
Operator::LikeMatch + | Operator::NotLikeMatch + | Operator::ILikeMatch + | Operator::NotILikeMatch => 25, Operator::IsDistinctFrom | Operator::IsNotDistinctFrom | Operator::RegexMatch | Operator::RegexNotMatch | Operator::RegexIMatch | Operator::RegexNotIMatch - | Operator::LikeMatch - | Operator::ILikeMatch - | Operator::NotLikeMatch - | Operator::NotILikeMatch | Operator::BitwiseAnd | Operator::BitwiseOr | Operator::BitwiseShiftLeft @@ -248,7 +242,9 @@ impl Operator { | Operator::BitwiseXor | Operator::StringConcat | Operator::AtArrow - | Operator::ArrowAt => 0, + | Operator::ArrowAt => 30, + Operator::Plus | Operator::Minus => 40, + Operator::Multiply | Operator::Divide | Operator::Modulo => 45, } } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index ad898de5987a..f67cd5928c79 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -30,8 +30,8 @@ use arrow_array::{Date32Array, Date64Array, PrimitiveArray}; use arrow_schema::DataType; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Expr as AstExpr, Function, FunctionArg, Ident, Interval, TimezoneInfo, - UnaryOperator, + self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval, + TimezoneInfo, UnaryOperator, }; use datafusion_common::{ @@ -101,8 +101,21 @@ pub fn expr_to_unparsed(expr: &Expr) -> Result { unparser.expr_to_unparsed(expr) } +const LOWEST: &BinaryOperator = &BinaryOperator::Or; +// closest precedence we have to IS operator is BitwiseAnd (any other) in PG docs +// (https://www.postgresql.org/docs/7.2/sql-precedence.html) +const IS: &BinaryOperator = &BinaryOperator::BitwiseAnd; + impl Unparser<'_> { pub fn expr_to_sql(&self, expr: &Expr) -> Result { + let mut root_expr = self.expr_to_sql_inner(expr)?; + if self.pretty { + root_expr = self.remove_unnecessary_nesting(root_expr, LOWEST, LOWEST); + } + Ok(root_expr) + } + + fn expr_to_sql_inner(&self, expr: &Expr) -> Result { 
match expr { Expr::InList(InList { expr, @@ -111,10 +124,10 @@ impl Unparser<'_> { }) => { let list_expr = list .iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_inner(e)) .collect::>>()?; Ok(ast::Expr::InList { - expr: Box::new(self.expr_to_sql(expr)?), + expr: Box::new(self.expr_to_sql_inner(expr)?), list: list_expr, negated: *negated, }) @@ -128,7 +141,7 @@ impl Unparser<'_> { if matches!(e, Expr::Wildcard { qualifier: None }) { Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard)) } else { - self.expr_to_sql(e).map(|e| { + self.expr_to_sql_inner(e).map(|e| { FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) }) } @@ -157,9 +170,9 @@ impl Unparser<'_> { low, high, }) => { - let sql_parser_expr = self.expr_to_sql(expr)?; - let sql_low = self.expr_to_sql(low)?; - let sql_high = self.expr_to_sql(high)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; + let sql_low = self.expr_to_sql_inner(low)?; + let sql_high = self.expr_to_sql_inner(high)?; Ok(ast::Expr::Nested(Box::new(self.between_op_to_sql( sql_parser_expr, *negated, @@ -169,8 +182,8 @@ impl Unparser<'_> { } Expr::Column(col) => self.col_to_sql(col), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let l = self.expr_to_sql(left.as_ref())?; - let r = self.expr_to_sql(right.as_ref())?; + let l = self.expr_to_sql_inner(left.as_ref())?; + let r = self.expr_to_sql_inner(right.as_ref())?; let op = self.op_to_sql(op)?; Ok(ast::Expr::Nested(Box::new(self.binary_op_to_sql(l, r, op)))) @@ -182,21 +195,21 @@ impl Unparser<'_> { }) => { let conditions = when_then_expr .iter() - .map(|(w, _)| self.expr_to_sql(w)) + .map(|(w, _)| self.expr_to_sql_inner(w)) .collect::>>()?; let results = when_then_expr .iter() - .map(|(_, t)| self.expr_to_sql(t)) + .map(|(_, t)| self.expr_to_sql_inner(t)) .collect::>>()?; let operand = match expr.as_ref() { - Some(e) => match self.expr_to_sql(e) { + Some(e) => match self.expr_to_sql_inner(e) { Ok(sql_expr) => Some(Box::new(sql_expr)), Err(_) => None, }, 
None => None, }; let else_result = match else_expr.as_ref() { - Some(e) => match self.expr_to_sql(e) { + Some(e) => match self.expr_to_sql_inner(e) { Ok(sql_expr) => Some(Box::new(sql_expr)), Err(_) => None, }, @@ -211,7 +224,7 @@ impl Unparser<'_> { }) } Expr::Cast(Cast { expr, data_type }) => { - let inner_expr = self.expr_to_sql(expr)?; + let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::Cast, expr: Box::new(inner_expr), @@ -220,7 +233,7 @@ impl Unparser<'_> { }) } Expr::Literal(value) => Ok(self.scalar_to_sql(value)?), - Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql(expr), + Expr::Alias(Alias { expr, name: _, .. }) => self.expr_to_sql_inner(expr), Expr::WindowFunction(WindowFunction { fun, args, @@ -255,7 +268,7 @@ impl Unparser<'_> { window_name: None, partition_by: partition_by .iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_inner(e)) .collect::>>()?, order_by, window_frame: Some(ast::WindowFrame { @@ -296,8 +309,8 @@ impl Unparser<'_> { case_insensitive: _, }) => Ok(ast::Expr::Like { negated: *negated, - expr: Box::new(self.expr_to_sql(expr)?), - pattern: Box::new(self.expr_to_sql(pattern)?), + expr: Box::new(self.expr_to_sql_inner(expr)?), + pattern: Box::new(self.expr_to_sql_inner(pattern)?), escape_char: escape_char.map(|c| c.to_string()), }), Expr::AggregateFunction(agg) => { @@ -305,7 +318,7 @@ impl Unparser<'_> { let args = self.function_args_to_sql(&agg.args)?; let filter = match &agg.filter { - Some(filter) => Some(Box::new(self.expr_to_sql(filter)?)), + Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)), None => None, }; Ok(ast::Expr::Function(Function { @@ -339,7 +352,7 @@ impl Unparser<'_> { Ok(ast::Expr::Subquery(sub_query)) } Expr::InSubquery(insubq) => { - let inexpr = Box::new(self.expr_to_sql(insubq.expr.as_ref())?); + let inexpr = Box::new(self.expr_to_sql_inner(insubq.expr.as_ref())?); let sub_statement = 
self.plan_to_sql(insubq.subquery.subquery.as_ref())?; let sub_query = if let ast::Statement::Query(inner_query) = sub_statement @@ -377,38 +390,38 @@ impl Unparser<'_> { nulls_first: _, }) => plan_err!("Sort expression should be handled by expr_to_unparsed"), Expr::IsNull(expr) => { - Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotNull(expr) => { - Ok(ast::Expr::IsNotNull(Box::new(self.expr_to_sql(expr)?))) + Ok(ast::Expr::IsNull(Box::new(self.expr_to_sql_inner(expr)?))) } + Expr::IsNotNull(expr) => Ok(ast::Expr::IsNotNull(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::IsTrue(expr) => { - Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotTrue(expr) => { - Ok(ast::Expr::IsNotTrue(Box::new(self.expr_to_sql(expr)?))) + Ok(ast::Expr::IsTrue(Box::new(self.expr_to_sql_inner(expr)?))) } + Expr::IsNotTrue(expr) => Ok(ast::Expr::IsNotTrue(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::IsFalse(expr) => { - Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotFalse(expr) => { - Ok(ast::Expr::IsNotFalse(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsUnknown(expr) => { - Ok(ast::Expr::IsUnknown(Box::new(self.expr_to_sql(expr)?))) - } - Expr::IsNotUnknown(expr) => { - Ok(ast::Expr::IsNotUnknown(Box::new(self.expr_to_sql(expr)?))) - } + Ok(ast::Expr::IsFalse(Box::new(self.expr_to_sql_inner(expr)?))) + } + Expr::IsNotFalse(expr) => Ok(ast::Expr::IsNotFalse(Box::new( + self.expr_to_sql_inner(expr)?, + ))), + Expr::IsUnknown(expr) => Ok(ast::Expr::IsUnknown(Box::new( + self.expr_to_sql_inner(expr)?, + ))), + Expr::IsNotUnknown(expr) => Ok(ast::Expr::IsNotUnknown(Box::new( + self.expr_to_sql_inner(expr)?, + ))), Expr::Not(expr) => { - let sql_parser_expr = self.expr_to_sql(expr)?; + let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Not, expr: Box::new(sql_parser_expr), }) } Expr::Negative(expr) => { - let sql_parser_expr = self.expr_to_sql(expr)?; 
+ let sql_parser_expr = self.expr_to_sql_inner(expr)?; Ok(AstExpr::UnaryOp { op: UnaryOperator::Minus, expr: Box::new(sql_parser_expr), @@ -432,7 +445,7 @@ impl Unparser<'_> { }) } Expr::TryCast(TryCast { expr, data_type }) => { - let inner_expr = self.expr_to_sql(expr)?; + let inner_expr = self.expr_to_sql_inner(expr)?; Ok(ast::Expr::Cast { kind: ast::CastKind::TryCast, expr: Box::new(inner_expr), @@ -449,7 +462,7 @@ impl Unparser<'_> { .iter() .map(|set| { set.iter() - .map(|e| self.expr_to_sql(e)) + .map(|e| self.expr_to_sql_inner(e)) .collect::>>() }) .collect::>>()?; @@ -460,7 +473,7 @@ impl Unparser<'_> { let expr_ast_sets = cube .iter() .map(|e| { - let sql = self.expr_to_sql(e)?; + let sql = self.expr_to_sql_inner(e)?; Ok(vec![sql]) }) .collect::>>()?; @@ -470,7 +483,7 @@ impl Unparser<'_> { let expr_ast_sets: Vec> = rollup .iter() .map(|e| { - let sql = self.expr_to_sql(e)?; + let sql = self.expr_to_sql_inner(e)?; Ok(vec![sql]) }) .collect::>>()?; @@ -603,6 +616,88 @@ impl Unparser<'_> { } } + /// Given an expression of the form `((a + b) * (c * d))`, + /// the parenthesization is redundant if the precedence of the nested expression is already higher + /// than the surrounding operators' precedence. The above expression would become + /// `(a + b) * c * d`. + /// + /// Also note that when fetching the precedence of a nested expression, we ignore other nested + /// expressions, so precedence of expr `(a * (b + c))` equals `*` and not `+`. 
+ fn remove_unnecessary_nesting( + &self, + expr: ast::Expr, + left_op: &BinaryOperator, + right_op: &BinaryOperator, + ) -> ast::Expr { + match expr { + ast::Expr::Nested(nested) => { + let surrounding_precedence = self + .sql_op_precedence(left_op) + .max(self.sql_op_precedence(right_op)); + + let inner_precedence = self.inner_precedence(&nested); + + let not_associative = + matches!(left_op, BinaryOperator::Minus | BinaryOperator::Divide); + + if inner_precedence == surrounding_precedence && not_associative { + ast::Expr::Nested(Box::new( + self.remove_unnecessary_nesting(*nested, LOWEST, LOWEST), + )) + } else if inner_precedence >= surrounding_precedence { + self.remove_unnecessary_nesting(*nested, left_op, right_op) + } else { + ast::Expr::Nested(Box::new( + self.remove_unnecessary_nesting(*nested, LOWEST, LOWEST), + )) + } + } + ast::Expr::BinaryOp { left, op, right } => ast::Expr::BinaryOp { + left: Box::new(self.remove_unnecessary_nesting(*left, left_op, &op)), + right: Box::new(self.remove_unnecessary_nesting(*right, &op, right_op)), + op, + }, + ast::Expr::IsTrue(expr) => ast::Expr::IsTrue(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotTrue(expr) => ast::Expr::IsNotTrue(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsFalse(expr) => ast::Expr::IsFalse(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotFalse(expr) => ast::Expr::IsNotFalse(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNull(expr) => ast::Expr::IsNull(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotNull(expr) => ast::Expr::IsNotNull(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsUnknown(expr) => ast::Expr::IsUnknown(Box::new( + self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + ast::Expr::IsNotUnknown(expr) => ast::Expr::IsNotUnknown(Box::new( + 
self.remove_unnecessary_nesting(*expr, left_op, IS), + )), + _ => expr, + } + } + + fn inner_precedence(&self, expr: &ast::Expr) -> u8 { + match expr { + ast::Expr::Nested(_) | ast::Expr::Identifier(_) | ast::Expr::Value(_) => 100, + ast::Expr::BinaryOp { op, .. } => self.sql_op_precedence(op), + // closest precedence we currently have to Between is PGLikeMatch + // (https://www.postgresql.org/docs/7.2/sql-precedence.html) + ast::Expr::Between { .. } => { + self.sql_op_precedence(&ast::BinaryOperator::PGLikeMatch) + } + _ => 0, + } + } + pub(super) fn between_op_to_sql( &self, expr: ast::Expr, @@ -618,6 +713,48 @@ impl Unparser<'_> { } } + fn sql_op_precedence(&self, op: &BinaryOperator) -> u8 { + match self.sql_to_op(op) { + Ok(op) => op.precedence(), + Err(_) => 0, + } + } + + fn sql_to_op(&self, op: &BinaryOperator) -> Result { + match op { + ast::BinaryOperator::Eq => Ok(Operator::Eq), + ast::BinaryOperator::NotEq => Ok(Operator::NotEq), + ast::BinaryOperator::Lt => Ok(Operator::Lt), + ast::BinaryOperator::LtEq => Ok(Operator::LtEq), + ast::BinaryOperator::Gt => Ok(Operator::Gt), + ast::BinaryOperator::GtEq => Ok(Operator::GtEq), + ast::BinaryOperator::Plus => Ok(Operator::Plus), + ast::BinaryOperator::Minus => Ok(Operator::Minus), + ast::BinaryOperator::Multiply => Ok(Operator::Multiply), + ast::BinaryOperator::Divide => Ok(Operator::Divide), + ast::BinaryOperator::Modulo => Ok(Operator::Modulo), + ast::BinaryOperator::And => Ok(Operator::And), + ast::BinaryOperator::Or => Ok(Operator::Or), + ast::BinaryOperator::PGRegexMatch => Ok(Operator::RegexMatch), + ast::BinaryOperator::PGRegexIMatch => Ok(Operator::RegexIMatch), + ast::BinaryOperator::PGRegexNotMatch => Ok(Operator::RegexNotMatch), + ast::BinaryOperator::PGRegexNotIMatch => Ok(Operator::RegexNotIMatch), + ast::BinaryOperator::PGILikeMatch => Ok(Operator::ILikeMatch), + ast::BinaryOperator::PGNotLikeMatch => Ok(Operator::NotLikeMatch), + ast::BinaryOperator::PGLikeMatch => Ok(Operator::LikeMatch), + 
ast::BinaryOperator::PGNotILikeMatch => Ok(Operator::NotILikeMatch), + ast::BinaryOperator::BitwiseAnd => Ok(Operator::BitwiseAnd), + ast::BinaryOperator::BitwiseOr => Ok(Operator::BitwiseOr), + ast::BinaryOperator::BitwiseXor => Ok(Operator::BitwiseXor), + ast::BinaryOperator::PGBitwiseShiftRight => Ok(Operator::BitwiseShiftRight), + ast::BinaryOperator::PGBitwiseShiftLeft => Ok(Operator::BitwiseShiftLeft), + ast::BinaryOperator::StringConcat => Ok(Operator::StringConcat), + ast::BinaryOperator::AtArrow => Ok(Operator::AtArrow), + ast::BinaryOperator::ArrowAt => Ok(Operator::ArrowAt), + _ => not_impl_err!("unsupported operation: {op:?}"), + } + } + fn op_to_sql(&self, op: &Operator) -> Result { match op { Operator::Eq => Ok(ast::BinaryOperator::Eq), @@ -1515,6 +1652,7 @@ mod tests { Ok(()) } + #[test] fn custom_dialect() -> Result<()> { let dialect = CustomDialect::new(Some('\'')); diff --git a/datafusion/sql/src/unparser/mod.rs b/datafusion/sql/src/unparser/mod.rs index fbbed4972b17..e5ffbc8a212a 100644 --- a/datafusion/sql/src/unparser/mod.rs +++ b/datafusion/sql/src/unparser/mod.rs @@ -29,11 +29,23 @@ pub mod dialect; pub struct Unparser<'a> { dialect: &'a dyn Dialect, + pretty: bool, } impl<'a> Unparser<'a> { pub fn new(dialect: &'a dyn Dialect) -> Self { - Self { dialect } + Self { + dialect, + pretty: false, + } + } + + /// Allow unparser to remove parentheses according to the precedence rules of DataFusion. + /// This might make it invalid SQL for other SQL query engines with different precedence + /// rules, even if it's valid for DataFusion. 
+ pub fn with_pretty(mut self, pretty: bool) -> Self { + self.pretty = pretty; + self } } @@ -41,6 +53,7 @@ impl<'a> Default for Unparser<'a> { fn default() -> Self { Self { dialect: &DefaultDialect {}, + pretty: false, } } } diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 374403d853f9..91295b2e8aae 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -104,26 +104,26 @@ fn roundtrip_statement() -> Result<()> { "select id, count(*) as cnt from (select p1.id as id from person p1 inner join person p2 on p1.id=p2.id) group by id", "select id, count(*), first_name from person group by first_name, id", "select id, sum(age), first_name from person group by first_name, id", - "select id, count(*), first_name - from person + "select id, count(*), first_name + from person where id!=3 and first_name=='test' - group by first_name, id + group by first_name, id having count(*)>5 and count(*)<10 order by count(*)", - r#"select id, count("First Name") as count_first_name, "Last Name" + r#"select id, count("First Name") as count_first_name, "Last Name" from person_quoted_cols where id!=3 and "First Name"=='test' - group by "Last Name", id + group by "Last Name", id having count_first_name>5 and count_first_name<10 order by count_first_name, "Last Name""#, r#"select p.id, count("First Name") as count_first_name, - "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) + "Last Name", sum(qp.id/p.id - (select sum(id) from person_quoted_cols) ) / (select count(*) from person) from (select id, "First Name", "Last Name" from person_quoted_cols) qp inner join (select * from person) p on p.id = qp.id - where p.id!=3 and "First Name"=='test' and qp.id in + where p.id!=3 and "First Name"=='test' and qp.id in (select id from (select id, count(*) from person group by id having count(*) > 0)) - group by "Last Name", p.id + group by "Last 
Name", p.id having count_first_name>5 and count_first_name<10 order by count_first_name, "Last Name""#, r#"SELECT j1_string as string FROM j1 @@ -134,12 +134,12 @@ fn roundtrip_statement() -> Result<()> { SELECT j2_string as string FROM j2 ORDER BY string DESC LIMIT 10"#, - "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), - last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + "SELECT id, count(*) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + last_name, sum(id) over (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), first_name from person", - r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#, - "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", + "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", ]; // For each test sql string, we transform as follows: @@ -314,3 +314,78 @@ fn test_table_references_in_plan_to_sql() { "SELECT \"table\".id, \"table\".\"value\" FROM \"table\"", ); } + +#[test] +fn test_pretty_roundtrip() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("id", DataType::Utf8, false), + Field::new("age", DataType::Utf8, false), + ]); + + let df_schema = DFSchema::try_from(schema)?; + + let context = MockContextProvider::default(); + let sql_to_rel = SqlToRel::new(&context); + + let unparser = Unparser::default().with_pretty(true); + + let sql_to_pretty_unparse = vec![ + ("((id < 5) OR (age = 8))", "id < 5 OR age = 8"), + ("((id + 5) * (age * 8))", "(id + 5) * age * 8"), + ("(3 + (5 * 6) * 
3)", "3 + 5 * 6 * 3"), + ("((3 * (5 + 6)) * 3)", "3 * (5 + 6) * 3"), + ("((3 AND (5 OR 6)) * 3)", "(3 AND (5 OR 6)) * 3"), + ("((3 + (5 + 6)) * 3)", "(3 + 5 + 6) * 3"), + ("((3 + (5 + 6)) + 3)", "3 + 5 + 6 + 3"), + ("3 + 5 + 6 + 3", "3 + 5 + 6 + 3"), + ("3 + (5 + (6 + 3))", "3 + 5 + 6 + 3"), + ("3 + ((5 + 6) + 3)", "3 + 5 + 6 + 3"), + ("(3 + 5) + (6 + 3)", "3 + 5 + 6 + 3"), + ("((3 + 5) + (6 + 3))", "3 + 5 + 6 + 3"), + ( + "((id > 10) OR (age BETWEEN 10 AND 20))", + "id > 10 OR age BETWEEN 10 AND 20", + ), + ( + "((id > 10) * (age BETWEEN 10 AND 20))", + "(id > 10) * (age BETWEEN 10 AND 20)", + ), + ("id - (age - 8)", "id - (age - 8)"), + ("((id - age) - 8)", "id - age - 8"), + ("(id OR (age - 8))", "id OR age - 8"), + ("(id / (age - 8))", "id / (age - 8)"), + ("((id / age) * 8)", "id / age * 8"), + ("((age + 10) < 20) IS TRUE", "(age + 10 < 20) IS TRUE"), + ( + "(20 > (age + 5)) IS NOT FALSE", + "(20 > age + 5) IS NOT FALSE", + ), + ("(true AND false) IS FALSE", "(true AND false) IS FALSE"), + ("true AND (false IS FALSE)", "true AND false IS FALSE"), + ]; + + for (sql, pretty) in sql_to_pretty_unparse.iter() { + let sql_expr = Parser::new(&GenericDialect {}) + .try_with_sql(sql)? + .parse_expr()?; + let expr = + sql_to_rel.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())?; + let round_trip_sql = unparser.expr_to_sql(&expr)?.to_string(); + assert_eq!(pretty.to_string(), round_trip_sql); + + // verify that the pretty string parses to the same underlying Expr + let pretty_sql_expr = Parser::new(&GenericDialect {}) + .try_with_sql(pretty)? + .parse_expr()?; + + let pretty_expr = sql_to_rel.sql_to_expr( + pretty_sql_expr, + &df_schema, + &mut PlannerContext::new(), + )?; + + assert_eq!(expr.to_string(), pretty_expr.to_string()); + } + + Ok(()) +}