From 47b5650104e7e9d6e12391b7c9ea7188fda538ef Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Fri, 19 Jul 2024 13:23:32 -0700 Subject: [PATCH] Add dialect param to use double precision for float64 in Postgres (#11495) * Add dialect param to use double precision for float64 in Postgres * return ast data type instead of bool * Fix errors in merging * fix --- datafusion/sql/src/unparser/dialect.rs | 28 ++++++++++++++++++++++++ datafusion/sql/src/unparser/expr.rs | 30 +++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 1e82fc2b3c1b..ed0cfddc3827 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -46,11 +46,18 @@ pub trait Dialect: Send + Sync { IntervalStyle::PostgresVerbose } + // Does the dialect use DOUBLE PRECISION to represent Float64 rather than DOUBLE? + // E.g. Postgres uses DOUBLE PRECISION instead of DOUBLE + fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + sqlparser::ast::DataType::Double + } + // The SQL type to use for Arrow Utf8 unparsing // Most dialects use VARCHAR, but some, like MySQL, require CHAR fn utf8_cast_dtype(&self) -> ast::DataType { ast::DataType::Varchar(None) } + // The SQL type to use for Arrow LargeUtf8 unparsing // Most dialects use TEXT, but some, like MySQL, require CHAR fn large_utf8_cast_dtype(&self) -> ast::DataType { @@ -98,6 +105,10 @@ impl Dialect for PostgreSqlDialect { fn interval_style(&self) -> IntervalStyle { IntervalStyle::PostgresVerbose } + + fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + sqlparser::ast::DataType::DoublePrecision + } } pub struct MySqlDialect {} @@ -137,6 +148,7 @@ pub struct CustomDialect { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, + float64_ast_dtype: sqlparser::ast::DataType, utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, } @@ -148,6 +160,7 @@ impl Default for CustomDialect { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::SQLStandard, + float64_ast_dtype: sqlparser::ast::DataType::Double, utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, } @@ -182,6 +195,10 @@ impl Dialect for CustomDialect { self.interval_style } + fn float64_ast_dtype(&self) -> sqlparser::ast::DataType { + self.float64_ast_dtype.clone() + } + fn utf8_cast_dtype(&self) -> ast::DataType { self.utf8_cast_dtype.clone() } @@ -210,6 +227,7 @@ pub struct CustomDialectBuilder { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, + float64_ast_dtype: sqlparser::ast::DataType, utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, } @@ -227,6 +245,7 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::PostgresVerbose, + float64_ast_dtype: sqlparser::ast::DataType::Double, utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, } @@ -238,6 +257,7 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: self.supports_nulls_first_in_sort, use_timestamp_for_date64: self.use_timestamp_for_date64, interval_style: self.interval_style, + float64_ast_dtype: self.float64_ast_dtype, utf8_cast_dtype: self.utf8_cast_dtype, large_utf8_cast_dtype: self.large_utf8_cast_dtype, } @@ -273,6 +293,14 @@ impl CustomDialectBuilder { self } + pub fn with_float64_ast_dtype( + mut self, + float64_ast_dtype: sqlparser::ast::DataType, + ) -> Self { + self.float64_ast_dtype = float64_ast_dtype; + self + } + pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self { self.utf8_cast_dtype = utf8_cast_dtype; self diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 950e7e11288a..2f7854c1a183 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1240,7 +1240,7 @@ impl Unparser<'_> { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } DataType::Float32 => Ok(ast::DataType::Float(None)), - DataType::Float64 => Ok(ast::DataType::Double), + DataType::Float64 => Ok(self.dialect.float64_ast_dtype()), DataType::Timestamp(_, tz) => { let tz_info = match tz { Some(_) => TimezoneInfo::WithTimeZone, @@ -1822,6 +1822,34 @@ mod tests { Ok(()) } + #[test] + fn custom_dialect_float64_ast_dtype() -> Result<()> { + for (float64_ast_dtype, identifier) in [ + (sqlparser::ast::DataType::Double, "DOUBLE"), + ( + sqlparser::ast::DataType::DoublePrecision, + "DOUBLE PRECISION", + ), + ] { + let dialect = CustomDialectBuilder::new() + .with_float64_ast_dtype(float64_ast_dtype) + .build(); + let unparser = Unparser::new(&dialect); + + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type: DataType::Float64, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + + let expected = format!(r#"CAST(a AS {identifier})"#); + assert_eq!(actual, expected); + } + Ok(()) + } + #[test] fn customer_dialect_support_nulls_first_in_ort() -> Result<()> { let tests: Vec<(Expr, &str, bool)> = vec![