diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index eca2eb4fd0ec..87453f81ee3d 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -16,7 +16,7 @@ // under the License. use regex::Regex; -use sqlparser::keywords::ALL_KEYWORDS; +use sqlparser::{ast, keywords::ALL_KEYWORDS}; /// `Dialect` to use for Unparsing /// @@ -45,6 +45,17 @@ pub trait Dialect { fn interval_style(&self) -> IntervalStyle { IntervalStyle::PostgresVerbose } + + // The SQL type to use for Arrow Utf8 unparsing + // Most dialects use VARCHAR, but some, like MySQL, require CHAR + fn utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Varchar(None) + } + // The SQL type to use for Arrow LargeUtf8 unparsing + // Most dialects use TEXT, but some, like MySQL, require CHAR + fn large_utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Text + } } /// `IntervalStyle` to use for unparsing @@ -103,6 +114,14 @@ impl Dialect for MySqlDialect { fn interval_style(&self) -> IntervalStyle { IntervalStyle::MySQL } + + fn utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Char(None) + } + + fn large_utf8_cast_dtype(&self) -> ast::DataType { + ast::DataType::Char(None) + } } pub struct SqliteDialect {} @@ -118,6 +137,8 @@ pub struct CustomDialect { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, + utf8_cast_dtype: ast::DataType, + large_utf8_cast_dtype: ast::DataType, } impl Default for CustomDialect { @@ -127,6 +148,8 @@ impl Default for CustomDialect { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::SQLStandard, + utf8_cast_dtype: ast::DataType::Varchar(None), + large_utf8_cast_dtype: ast::DataType::Text, } } } @@ -158,6 +181,14 @@ impl Dialect for CustomDialect { fn interval_style(&self) -> IntervalStyle { self.interval_style } + + fn utf8_cast_dtype(&self) -> ast::DataType { + self.utf8_cast_dtype.clone() + } + + fn large_utf8_cast_dtype(&self) -> ast::DataType { + self.large_utf8_cast_dtype.clone() + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -179,6 +210,8 @@ pub struct CustomDialectBuilder { supports_nulls_first_in_sort: bool, use_timestamp_for_date64: bool, interval_style: IntervalStyle, + utf8_cast_dtype: ast::DataType, + large_utf8_cast_dtype: ast::DataType, } impl Default for CustomDialectBuilder { @@ -194,6 +227,8 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: true, use_timestamp_for_date64: false, interval_style: IntervalStyle::PostgresVerbose, + utf8_cast_dtype: ast::DataType::Varchar(None), + large_utf8_cast_dtype: ast::DataType::Text, } } @@ -203,6 +238,8 @@ impl CustomDialectBuilder { supports_nulls_first_in_sort: self.supports_nulls_first_in_sort, use_timestamp_for_date64: self.use_timestamp_for_date64, interval_style: self.interval_style, + utf8_cast_dtype: self.utf8_cast_dtype, + large_utf8_cast_dtype: self.large_utf8_cast_dtype, } } @@ -235,4 +272,17 @@ impl CustomDialectBuilder { self.interval_style = interval_style; self } + + pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self { + self.utf8_cast_dtype = utf8_cast_dtype; + self + } + + pub fn with_large_utf8_cast_dtype( + mut self, + large_utf8_cast_dtype: ast::DataType, + ) -> Self { + self.large_utf8_cast_dtype = large_utf8_cast_dtype; + self + } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index e6b67b5d9fb2..950e7e11288a 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1275,8 +1275,8 @@ impl Unparser<'_> { DataType::BinaryView => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } - DataType::Utf8 => Ok(ast::DataType::Varchar(None)), - DataType::LargeUtf8 => Ok(ast::DataType::Text), + DataType::Utf8 => Ok(self.dialect.utf8_cast_dtype()), + DataType::LargeUtf8 => Ok(self.dialect.large_utf8_cast_dtype()), DataType::Utf8View => { not_impl_err!("Unsupported DataType: conversion: {data_type:?}") } @@ -1936,4 +1936,34 @@ mod tests { assert_eq!(actual, expected); } } + + #[test] + fn custom_dialect_use_char_for_utf8_cast() -> Result<()> { + let default_dialect = CustomDialectBuilder::default().build(); + let mysql_custom_dialect = CustomDialectBuilder::new() + .with_utf8_cast_dtype(ast::DataType::Char(None)) + .with_large_utf8_cast_dtype(ast::DataType::Char(None)) + .build(); + + for (dialect, data_type, identifier) in [ + (&default_dialect, DataType::Utf8, "VARCHAR"), + (&default_dialect, DataType::LargeUtf8, "TEXT"), + (&mysql_custom_dialect, DataType::Utf8, "CHAR"), + (&mysql_custom_dialect, DataType::LargeUtf8, "CHAR"), + ] { + let unparser = Unparser::new(dialect); + + let expr = Expr::Cast(Cast { + expr: Box::new(col("a")), + data_type, + }); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{}", ast); + let expected = format!(r#"CAST(a AS {identifier})"#); + + assert_eq!(actual, expected); + } + Ok(()) + } }