diff --git a/wren-core-base/Cargo.toml b/wren-core-base/Cargo.toml index 8b22c1d51..d55cc168e 100644 --- a/wren-core-base/Cargo.toml +++ b/wren-core-base/Cargo.toml @@ -13,6 +13,8 @@ serde = { version = "1.0.201", features = ["derive", "rc"] } wren-manifest-macro = { path = "manifest-macro" } serde_json = { version = "1.0.117" } serde_with = { version = "3.11.0" } +sqlparser = { version = "0.55.0", features = ["visitor"] } + [lib] name = "wren_core_base" diff --git a/wren-core-base/src/mdl/builder.rs b/wren-core-base/src/mdl/builder.rs index 4be1ff49c..8099af2e0 100644 --- a/wren-core-base/src/mdl/builder.rs +++ b/wren-core-base/src/mdl/builder.rs @@ -481,6 +481,15 @@ mod test { let json_str = serde_json::to_string(&model).unwrap(); let actual: Arc = serde_json::from_str(&json_str).unwrap(); assert_eq!(actual, model); + + let model = ModelBuilder::new("test") + .table_reference(r#""Wren"."Public"."Source""#) + .column(ColumnBuilder::new("id", "integer").build()) + .build(); + + let json_str = serde_json::to_string(&model).unwrap(); + let actual: Arc = serde_json::from_str(&json_str).unwrap(); + assert_eq!(actual, model); } #[test] diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index 9264ded1c..2d4896ca4 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -113,6 +113,8 @@ impl Display for DataSource { mod table_reference { use serde::{self, Deserialize, Deserializer, Serialize, Serializer}; + use crate::mdl::utils::{parse_identifiers_normalized, quote_identifier}; + #[derive(Deserialize, Serialize, Default)] struct TableReference { catalog: Option, @@ -133,7 +135,10 @@ mod table_reference { }| { [catalog, schema, table] .into_iter() - .filter_map(|s| s.filter(|x| !x.is_empty())) + .filter_map(|s| { + s.filter(|x| !x.is_empty()) + .map(|x| quote_identifier(&x).to_string()) + }) .collect::>() .join(".") }, @@ -146,7 +151,12 @@ mod table_reference { S: Serializer, { if let Some(table_ref) = table_ref { - let parts: Vec<&str> = table_ref.split('.').filter(|p| !p.is_empty()).collect(); + let parts: Vec = + parse_identifiers_normalized(table_ref, false).map_err(|e| { + serde::ser::Error::custom(format!( + "Failed to parse table reference: {table_ref}, error: {e}" + )) + })?; if parts.len() > 3 { return Err(serde::ser::Error::custom(format!( "Invalid table reference: {table_ref}" @@ -314,4 +324,16 @@ mod tests { assert_eq!(String::from_utf8(buf).unwrap(), *expected); }); } + + #[test] + fn test_case_sensitive() { + let table_ref = Some(r#""Catalog"."Schema"."Table""#.to_string()); + let mut buf = Vec::new(); + table_reference::serialize(&table_ref, &mut Serializer::new(&mut buf)).unwrap(); + let serialized = String::from_utf8(buf).unwrap(); + assert_eq!( + serialized, + r#"{"catalog":"Catalog","schema":"Schema","table":"Table"}"# + ); + } } diff --git a/wren-core-base/src/mdl/mod.rs b/wren-core-base/src/mdl/mod.rs index 2003099b7..e25ccb216 100644 --- a/wren-core-base/src/mdl/mod.rs +++ b/wren-core-base/src/mdl/mod.rs @@ -21,6 +21,7 @@ pub mod builder; pub mod cls; pub mod manifest; mod py_method; +mod utils; pub use builder::*; pub use manifest::*; diff --git a/wren-core-base/src/mdl/utils.rs b/wren-core-base/src/mdl/utils.rs new file mode 100644 index 000000000..af8df9eed --- /dev/null +++ b/wren-core-base/src/mdl/utils.rs @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::borrow::Cow; + +use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; + +pub(crate) fn parse_identifiers(s: &str) -> Result, sqlparser::parser::ParserError> { + let dialect = GenericDialect; + let mut parser = Parser::new(&dialect).try_with_sql(s)?; + let idents = parser.parse_multipart_identifier()?; + Ok(idents) +} + +pub(crate) fn parse_identifiers_normalized( + s: &str, + ignore_case: bool, +) -> Result, sqlparser::parser::ParserError> { + parse_identifiers(s).map(|v| { + v.into_iter() + .map(|id| match id.quote_style { + Some(_) => id.value, + None if ignore_case => id.value, + _ => id.value.to_ascii_lowercase(), + }) + .collect::>() + }) +} + +pub fn quote_identifier(s: &str) -> Cow { + if needs_quotes(s) { + Cow::Owned(format!("\"{}\"", s.replace('"', "\"\""))) + } else { + Cow::Borrowed(s) + } +} + +/// returns true if this identifier needs quotes +fn needs_quotes(s: &str) -> bool { + let mut chars = s.chars(); + + // first char can not be a number unless escaped + if let Some(first_char) = chars.next() { + if !(first_char.is_ascii_lowercase() || first_char == '_') { + return true; + } + } + + !chars.all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_') +} diff --git a/wren-core-py/Cargo.lock b/wren-core-py/Cargo.lock index e7ef42176..e94697b30 100644 --- a/wren-core-py/Cargo.lock +++ b/wren-core-py/Cargo.lock @@ -777,7 +777,7 @@ dependencies = [ "parquet", "rand", "regex", - "sqlparser", + "sqlparser 0.54.0", "tempfile", "tokio", "url", @@ -844,7 +844,7 @@ dependencies = [ "parquet", "paste", "recursive", - "sqlparser", + "sqlparser 0.54.0", "tokio", "web-time", ] @@ -931,7 +931,7 @@ dependencies = [ "paste", "recursive", "serde_json", - "sqlparser", + "sqlparser 0.54.0", ] [[package]] @@ -1188,7 +1188,7 @@ dependencies = [ "log", "recursive", "regex", - "sqlparser", + "sqlparser 0.54.0", ] [[package]] @@ -2731,6 +2731,17 @@ dependencies = [ "sqlparser_derive", ] +[[package]] +name = "sqlparser" +version = "0.55.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4521174166bac1ff04fe16ef4524c70144cd29682a45978978ca3d7f4e0be11" +dependencies = [ + "log", + "recursive", + "sqlparser_derive", +] + [[package]] name = "sqlparser_derive" version = "0.3.0" @@ -3357,6 +3368,7 @@ dependencies = [ "serde", "serde_json", "serde_with", + "sqlparser 0.55.0", "wren-manifest-macro", ] diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 69244e1c8..0fa8c3e3b 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -1439,6 +1439,92 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_uppercase_table_reference() -> Result<()> { + let mdl_json = r#" + { + "catalog": "wren", + "schema": "test", + "models": [ + { + "name": "customer", + "tableReference": { + "table": "CUSTOMER", + "schema": "test", + "catalog": "remote" + }, + "columns": [ + { + "name": "c_custkey", + "type": "int" + }, + { + "name": "c_name", + "type": "string" + } + ] + } + ] + } + "#; + let manifest: Manifest = serde_json::from_str(mdl_json).unwrap(); + let ctx = SessionContext::new(); + let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let result = + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + assert_eq!( + result, + "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM \ + (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM \"remote\".test.\"CUSTOMER\" AS __source) AS customer) AS customer \ + WHERE customer.c_custkey = 1" + ); + Ok(()) + } + + #[tokio::test] + async fn test_unicode_table_reference() -> Result<()> { + let mdl_json = r#" + { + "catalog": "wren", + "schema": "test", + "models": [ + { + "name": "customer", + "tableReference": { + "table": "客戶", + "schema": "test", + "catalog": "遠端" + }, + "columns": [ + { + "name": "c_custkey", + "type": "int" + }, + { + "name": "c_name", + "type": "string" + } + ] + } + ] + } + "#; + let manifest: Manifest = serde_json::from_str(mdl_json).unwrap(); + let ctx = SessionContext::new(); + let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let result = + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + assert_eq!( + result, + "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM \ + (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM \"遠端\".test.\"客戶\" AS __source) AS customer) AS customer \ + WHERE customer.c_custkey = 1" + ); + Ok(()) + } + /// Return a RecordBatch with made up data about customer fn customer() -> RecordBatch { let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));