diff --git a/Cargo.lock b/Cargo.lock
index bbf64d5262e29..7b09121595d67 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1936,6 +1936,7 @@ dependencies = [
  "clap 4.5.48",
  "ctor",
  "datafusion",
+ "datafusion-common",
  "dirs",
  "env_logger",
  "futures",
diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml
index d186cd711945d..53744e6c609b8 100644
--- a/datafusion-cli/Cargo.toml
+++ b/datafusion-cli/Cargo.toml
@@ -40,7 +40,7 @@ async-trait = { workspace = true }
 aws-config = "1.8.7"
 aws-credential-types = "1.2.7"
 chrono = { workspace = true }
-clap = { version = "4.5.47", features = ["derive", "cargo"] }
+clap = { version = "4.5.47", features = ["cargo", "derive"] }
 datafusion = { workspace = true, features = [
     "avro",
     "compression",
@@ -55,6 +55,7 @@ datafusion = { workspace = true, features = [
     "sql",
     "unicode_expressions",
 ] }
+datafusion-common = { workspace = true }
 dirs = "6.0.0"
 env_logger = { workspace = true }
 futures = { workspace = true }
@@ -65,7 +66,7 @@ parking_lot = { workspace = true }
 parquet = { workspace = true, default-features = false }
 regex = { workspace = true }
 rustyline = "17.0"
-tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot", "signal"] }
+tokio = { workspace = true, features = ["macros", "parking_lot", "rt", "rt-multi-thread", "signal", "sync"] }
 url = { workspace = true }
 
 [dev-dependencies]
diff --git a/datafusion-cli/src/helper.rs b/datafusion-cli/src/helper.rs
index 64c34c4737369..219637b3460e6 100644
--- a/datafusion-cli/src/helper.rs
+++ b/datafusion-cli/src/helper.rs
@@ -24,6 +24,7 @@
 use crate::highlighter::{NoSyntaxHighlighter, SyntaxHighlighter};
 use datafusion::sql::parser::{DFParser, Statement};
 use datafusion::sql::sqlparser::dialect::dialect_from_str;
+use datafusion_common::config::Dialect;
 
 use rustyline::completion::{Completer, FilenameCompleter, Pair};
 use rustyline::error::ReadlineError;
@@ -34,12 +35,12 @@ use rustyline::{Context, Helper, Result};
 
 pub struct CliHelper {
     completer: FilenameCompleter,
-    dialect: String,
+    dialect: Dialect,
     highlighter: Box<dyn Highlighter>,
 }
 
 impl CliHelper {
-    pub fn new(dialect: &str, color: bool) -> Self {
+    pub fn new(dialect: &Dialect, color: bool) -> Self {
         let highlighter: Box<dyn Highlighter> = if !color {
             Box::new(NoSyntaxHighlighter {})
         } else {
@@ -47,20 +48,20 @@ impl CliHelper {
         };
         Self {
             completer: FilenameCompleter::new(),
-            dialect: dialect.into(),
+            dialect: *dialect,
             highlighter,
         }
     }
 
-    pub fn set_dialect(&mut self, dialect: &str) {
-        if dialect != self.dialect {
-            self.dialect = dialect.to_string();
+    pub fn set_dialect(&mut self, dialect: &Dialect) {
+        if *dialect != self.dialect {
+            self.dialect = *dialect;
         }
     }
 
     fn validate_input(&self, input: &str) -> Result<ValidationResult> {
         if let Some(sql) = input.strip_suffix(';') {
-            let dialect = match dialect_from_str(&self.dialect) {
+            let dialect = match dialect_from_str(self.dialect) {
                 Some(dialect) => dialect,
                 None => {
                     return Ok(ValidationResult::Invalid(Some(format!(
@@ -97,7 +98,7 @@ impl CliHelper {
 
 impl Default for CliHelper {
     fn default() -> Self {
-        Self::new("generic", false)
+        Self::new(&Dialect::Generic, false)
     }
 }
 
@@ -289,7 +290,7 @@ mod tests {
         );
 
         // valid in postgresql dialect
-        validator.set_dialect("postgresql");
+        validator.set_dialect(&Dialect::PostgreSQL);
         let result =
             readline_direct(Cursor::new(r"select 1 # 2;".as_bytes()), &validator)?;
         assert!(matches!(result, ValidationResult::Valid(None)));
diff --git a/datafusion-cli/src/highlighter.rs b/datafusion-cli/src/highlighter.rs
index 7a886b94740bd..f4e57a2e3593a 100644
--- a/datafusion-cli/src/highlighter.rs
+++ b/datafusion-cli/src/highlighter.rs
@@ -27,6 +27,7 @@ use datafusion::sql::sqlparser::{
     keywords::Keyword,
     tokenizer::{Token, Tokenizer},
 };
+use datafusion_common::config;
 use rustyline::highlight::{CmdKind, Highlighter};
 
 /// The syntax highlighter.
@@ -36,7 +37,7 @@ pub struct SyntaxHighlighter {
 }
 
 impl SyntaxHighlighter {
-    pub fn new(dialect: &str) -> Self {
+    pub fn new(dialect: &config::Dialect) -> Self {
         let dialect = dialect_from_str(dialect).unwrap_or(Box::new(GenericDialect {}));
         Self { dialect }
     }
@@ -93,13 +94,14 @@ impl Color {
 
 #[cfg(test)]
 mod tests {
+    use super::config::Dialect;
     use super::SyntaxHighlighter;
     use rustyline::highlight::Highlighter;
 
     #[test]
     fn highlighter_valid() {
         let s = "SElect col_a from tab_1;";
-        let highlighter = SyntaxHighlighter::new("generic");
+        let highlighter = SyntaxHighlighter::new(&Dialect::Generic);
         let out = highlighter.highlight(s, s.len());
         assert_eq!(
             "\u{1b}[91mSElect\u{1b}[0m col_a \u{1b}[91mfrom\u{1b}[0m tab_1;",
@@ -110,7 +112,7 @@ mod tests {
     #[test]
     fn highlighter_valid_with_new_line() {
         let s = "SElect col_a from tab_1\n WHERE col_b = 'なにか';";
-        let highlighter = SyntaxHighlighter::new("generic");
+        let highlighter = SyntaxHighlighter::new(&Dialect::Generic);
         let out = highlighter.highlight(s, s.len());
         assert_eq!(
             "\u{1b}[91mSElect\u{1b}[0m col_a \u{1b}[91mfrom\u{1b}[0m tab_1\n \u{1b}[91mWHERE\u{1b}[0m col_b = \u{1b}[92m'なにか'\u{1b}[0m;",
@@ -121,7 +123,7 @@ mod tests {
     #[test]
     fn highlighter_invalid() {
         let s = "SElect col_a from tab_1 WHERE col_b = ';";
-        let highlighter = SyntaxHighlighter::new("generic");
+        let highlighter = SyntaxHighlighter::new(&Dialect::Generic);
         let out = highlighter.highlight(s, s.len());
         assert_eq!("SElect col_a from tab_1 WHERE col_b = ';", out);
     }
diff --git a/datafusion-examples/examples/remote_catalog.rs b/datafusion-examples/examples/remote_catalog.rs
index 70c0963545e08..74575554ec0af 100644
--- a/datafusion-examples/examples/remote_catalog.rs
+++ b/datafusion-examples/examples/remote_catalog.rs
@@ -75,8 +75,8 @@ async fn main() -> Result<()> {
     let state = ctx.state();
 
     // First, parse the SQL (but don't plan it / resolve any table references)
-    let dialect = state.config().options().sql_parser.dialect.as_str();
-    let statement = state.sql_to_statement(sql, dialect)?;
+    let dialect = state.config().options().sql_parser.dialect;
+    let statement = state.sql_to_statement(sql, &dialect)?;
 
     // Find all `TableReferences` in the parsed queries. These correspond to the
     // tables referred to by the query (in this case
diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs
index 9bde637f43794..126935a1de90b 100644
--- a/datafusion/common/src/config.rs
+++ b/datafusion/common/src/config.rs
@@ -258,7 +258,7 @@ config_namespace! {
 
         /// Configure the SQL dialect used by DataFusion's parser; supported values include: Generic,
         /// MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB and Databricks.
-        pub dialect: String, default = "generic".to_string()
+        pub dialect: Dialect, default = Dialect::Generic
        // no need to lowercase because [`sqlparser::dialect_from_str`] is case-insensitive
 
         /// If true, permit lengths for `VARCHAR` such as `VARCHAR(20)`, but
@@ -292,6 +292,94 @@
     }
 }
 
+/// This is the SQL dialect used by DataFusion's parser.
+/// This mirrors the [sqlparser::dialect::Dialect](https://docs.rs/sqlparser/latest/sqlparser/dialect/trait.Dialect.html)
+/// trait in order to offer an easier API and avoid adding a `sqlparser` dependency.
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
+pub enum Dialect {
+    #[default]
+    Generic,
+    MySQL,
+    PostgreSQL,
+    Hive,
+    SQLite,
+    Snowflake,
+    Redshift,
+    MsSQL,
+    ClickHouse,
+    BigQuery,
+    Ansi,
+    DuckDB,
+    Databricks,
+}
+
+impl AsRef<str> for Dialect {
+    fn as_ref(&self) -> &str {
+        match self {
+            Self::Generic => "generic",
+            Self::MySQL => "mysql",
+            Self::PostgreSQL => "postgresql",
+            Self::Hive => "hive",
+            Self::SQLite => "sqlite",
+            Self::Snowflake => "snowflake",
+            Self::Redshift => "redshift",
+            Self::MsSQL => "mssql",
+            Self::ClickHouse => "clickhouse",
+            Self::BigQuery => "bigquery",
+            Self::Ansi => "ansi",
+            Self::DuckDB => "duckdb",
+            Self::Databricks => "databricks",
+        }
+    }
+}
+
+impl FromStr for Dialect {
+    type Err = DataFusionError;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        let value = match s.to_ascii_lowercase().as_str() {
+            "generic" => Self::Generic,
+            "mysql" => Self::MySQL,
+            "postgresql" | "postgres" => Self::PostgreSQL,
+            "hive" => Self::Hive,
+            "sqlite" => Self::SQLite,
+            "snowflake" => Self::Snowflake,
+            "redshift" => Self::Redshift,
+            "mssql" => Self::MsSQL,
+            "clickhouse" => Self::ClickHouse,
+            "bigquery" => Self::BigQuery,
+            "ansi" => Self::Ansi,
+            "duckdb" => Self::DuckDB,
+            "databricks" => Self::Databricks,
+            other => {
+                let error_message = format!(
+                    "Invalid Dialect: {other}. Expected one of: Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks"
+                );
+                return Err(DataFusionError::Configuration(error_message));
+            }
+        };
+        Ok(value)
+    }
+}
+
+impl ConfigField for Dialect {
+    fn visit<V: Visit>(&self, v: &mut V, key: &str, description: &'static str) {
+        v.some(key, self, description)
+    }
+
+    fn set(&mut self, _: &str, value: &str) -> Result<()> {
+        *self = Self::from_str(value)?;
+        Ok(())
+    }
+}
+
+impl Display for Dialect {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        let str = self.as_ref();
+        write!(f, "{str}")
+    }
+}
+
 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
 pub enum SpillCompression {
     Zstd,
diff --git a/datafusion/core/benches/sql_planner.rs b/datafusion/core/benches/sql_planner.rs
index 3be8668b2b8c4..83563099cad67 100644
--- a/datafusion/core/benches/sql_planner.rs
+++ b/datafusion/core/benches/sql_planner.rs
@@ -30,7 +30,7 @@ use criterion::Bencher;
 use datafusion::datasource::MemTable;
 use datafusion::execution::context::SessionContext;
 use datafusion::prelude::DataFrame;
-use datafusion_common::ScalarValue;
+use datafusion_common::{config::Dialect, ScalarValue};
 use datafusion_expr::Expr::Literal;
 use datafusion_expr::{cast, col, lit, not, try_cast, when};
 use datafusion_functions::expr_fn::{
@@ -288,7 +288,10 @@ fn benchmark_with_param_values_many_columns(
     }
 
     // SELECT max(attr0), ..., max(attrN) FROM t1.
     let query = format!("SELECT {aggregates} FROM t1");
-    let statement = ctx.state().sql_to_statement(&query, "Generic").unwrap();
+    let statement = ctx
+        .state()
+        .sql_to_statement(&query, &Dialect::Generic)
+        .unwrap();
     let plan = rt.block_on(async { ctx.state().statement_to_plan(statement).await.unwrap() });
     b.iter(|| {
diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index b04004dd495c8..6749ddd7ab8d5 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -30,15 +30,14 @@ use crate::datasource::provider_as_source;
 use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner};
 use crate::execution::SessionStateDefaults;
 use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
+use arrow::datatypes::DataType;
 use datafusion_catalog::information_schema::{
     InformationSchemaProvider, INFORMATION_SCHEMA,
 };
-
-use arrow::datatypes::DataType;
 use datafusion_catalog::MemoryCatalogProviderList;
 use datafusion_catalog::{TableFunction, TableFunctionImpl};
 use datafusion_common::alias::AliasGenerator;
-use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions};
+use datafusion_common::config::{ConfigExtension, ConfigOptions, Dialect, TableOptions};
 use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
 use datafusion_common::tree_node::TreeNode;
 use datafusion_common::{
@@ -374,7 +373,7 @@ impl SessionState {
     pub fn sql_to_statement(
         &self,
         sql: &str,
-        dialect: &str,
+        dialect: &Dialect,
     ) -> datafusion_common::Result<Statement> {
         let dialect = dialect_from_str(dialect).ok_or_else(|| {
             plan_datafusion_err!(
@@ -411,7 +410,7 @@
     pub fn sql_to_expr(
         &self,
         sql: &str,
-        dialect: &str,
+        dialect: &Dialect,
     ) -> datafusion_common::Result<Expr> {
         self.sql_to_expr_with_alias(sql, dialect).map(|x| x.expr)
     }
@@ -423,7 +422,7 @@
     pub fn sql_to_expr_with_alias(
         &self,
         sql: &str,
-        dialect: &str,
+        dialect: &Dialect,
     ) -> datafusion_common::Result<SQLExprWithAlias> {
         let dialect = dialect_from_str(dialect).ok_or_else(|| {
             plan_datafusion_err!(
@@ -527,8 +526,8 @@
         &self,
         sql: &str,
     ) -> datafusion_common::Result<LogicalPlan> {
-        let dialect = self.config.options().sql_parser.dialect.as_str();
-        let statement = self.sql_to_statement(sql, dialect)?;
+        let dialect = self.config.options().sql_parser.dialect;
+        let statement = self.sql_to_statement(sql, &dialect)?;
         let plan = self.statement_to_plan(statement).await?;
         Ok(plan)
     }
@@ -542,9 +541,9 @@
         sql: &str,
         df_schema: &DFSchema,
     ) -> datafusion_common::Result<Expr> {
-        let dialect = self.config.options().sql_parser.dialect.as_str();
+        let dialect = self.config.options().sql_parser.dialect;
 
-        let sql_expr = self.sql_to_expr_with_alias(sql, dialect)?;
+        let sql_expr = self.sql_to_expr_with_alias(sql, &dialect)?;
 
         let provider = SessionContextProvider {
             state: self,
@@ -2034,6 +2033,7 @@ mod tests {
     use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
     use arrow::datatypes::{DataType, Field, Schema};
     use datafusion_catalog::MemoryCatalogProviderList;
+    use datafusion_common::config::Dialect;
     use datafusion_common::DFSchema;
     use datafusion_common::Result;
     use datafusion_execution::config::SessionConfig;
@@ -2059,8 +2059,8 @@
         let sql = "[1,2,3]";
         let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
         let df_schema = DFSchema::try_from(schema)?;
-        let dialect = state.config.options().sql_parser.dialect.as_str();
-        let sql_expr = state.sql_to_expr(sql, dialect)?;
+        let dialect = state.config.options().sql_parser.dialect;
+        let sql_expr = state.sql_to_expr(sql, &dialect)?;
 
         let query = SqlToRel::new_with_options(&provider, state.get_parser_options());
         query.sql_to_expr(sql_expr, &df_schema, &mut PlannerContext::new())
@@ -2218,7 +2218,8 @@
         }
 
         let state = &context_provider.state;
-        let statement = state.sql_to_statement("select count(*) from t", "mysql")?;
+        let statement =
+            state.sql_to_statement("select count(*) from t", &Dialect::MySQL)?;
         let plan = SqlToRel::new(&context_provider).statement_to_plan(statement)?;
         state.create_physical_plan(&plan).await
     }
diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs
index c8a4279a42110..e0a3e98604ae4 100644
--- a/datafusion/core/tests/user_defined/insert_operation.rs
+++ b/datafusion/core/tests/user_defined/insert_operation.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{any::Any, sync::Arc};
+use std::{any::Any, str::FromStr, sync::Arc};
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use async_trait::async_trait;
@@ -24,6 +24,7 @@ use datafusion::{
     prelude::{SessionConfig, SessionContext},
 };
 use datafusion_catalog::{Session, TableProvider};
+use datafusion_common::config::Dialect;
 use datafusion_expr::{dml::InsertOp, Expr, TableType};
 use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
 use datafusion_physical_plan::execution_plan::SchedulingType;
@@ -63,7 +64,7 @@ async fn assert_insert_op(ctx: &SessionContext, sql: &str, insert_op: InsertOp)
 
 fn session_ctx_with_dialect(dialect: impl Into<String>) -> SessionContext {
     let mut config = SessionConfig::new();
     let options = config.options_mut();
-    options.sql_parser.dialect = dialect.into();
+    options.sql_parser.dialect = Dialect::from_str(&dialect.into()).unwrap();
     SessionContext::new_with_config(config)
 }
diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md
index e93659872565b..0b9da1b5a86ae 100644
--- a/docs/source/library-user-guide/upgrading.md
+++ b/docs/source/library-user-guide/upgrading.md
@@ -86,6 +86,14 @@ See [issue #17601] for more details.
 
 [issue #17601]: https://github.com/apache/datafusion/issues/17601
 
+### `SessionState`'s `sql_to_statement` method takes `Dialect` rather than a `str`
+
+The `dialect` parameter of the `sql_to_statement` method defined in
+`datafusion::execution::session_state::SessionState` has changed from `&str`
+to `&Dialect`. `Dialect` is an enum defined in the `datafusion-common` crate
+under the `config` module; it provides type safety and better validation for
+SQL dialect selection.
+
 ## DataFusion `50.0.0`
 
 ### ListingTable automatically detects Hive Partitioned tables
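
Below is a small usage sketch of the new `Dialect` enum, based on the `FromStr`, `AsRef<str>`, and `Display` impls added to `datafusion/common/src/config.rs` above; the standalone `main` wrapper is only for illustration.

```rust
use std::str::FromStr;

use datafusion_common::config::Dialect;

fn main() -> datafusion_common::Result<()> {
    // Parsing is case-insensitive and accepts the "postgres" alias.
    assert_eq!(Dialect::from_str("PostgreSQL")?, Dialect::PostgreSQL);
    assert_eq!(Dialect::from_str("postgres")?, Dialect::PostgreSQL);

    // `AsRef<str>` / `Display` return the canonical lowercase name, which is
    // what `sqlparser::dialect_from_str` expects internally.
    assert_eq!(Dialect::PostgreSQL.as_ref(), "postgresql");
    assert_eq!(Dialect::PostgreSQL.to_string(), "postgresql");

    // Unknown names fail with a configuration error instead of silently
    // falling back to the generic dialect.
    assert!(Dialect::from_str("not-a-dialect").is_err());

    Ok(())
}
```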
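
For callers migrating to the new `sql_to_statement` signature, a minimal sketch is shown below. It assumes a `SessionContext` from the `datafusion` crate; the function name `plan_with_dialect` is illustrative only.

```rust
use datafusion::prelude::SessionContext;
use datafusion_common::config::Dialect;

async fn plan_with_dialect() -> datafusion_common::Result<()> {
    let ctx = SessionContext::new();
    let state = ctx.state();

    // Before: `let statement = state.sql_to_statement("SELECT 1", "postgresql")?;`
    // After: read the typed dialect from the session options (or pick a
    // variant directly) and pass it by reference.
    let dialect = state.config().options().sql_parser.dialect;
    let statement = state.sql_to_statement("SELECT 1", &dialect)?;
    let _plan = state.statement_to_plan(statement).await?;

    // Selecting a dialect explicitly works the same way.
    let _statement = state.sql_to_statement("SELECT 1", &Dialect::PostgreSQL)?;

    Ok(())
}
```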