diff --git a/ibis-server/tests/routers/v3/connector/mysql/conftest.py b/ibis-server/tests/routers/v3/connector/mysql/conftest.py
index 54efe958f..e5de77bb5 100644
--- a/ibis-server/tests/routers/v3/connector/mysql/conftest.py
+++ b/ibis-server/tests/routers/v3/connector/mysql/conftest.py
@@ -1,7 +1,11 @@
 import pathlib
 
+import pandas as pd
 import pytest
+import sqlalchemy
 from testcontainers.mysql import MySqlContainer
 
+from tests.conftest import file_path
+
 pytestmark = pytest.mark.mysql
 
@@ -18,6 +22,44 @@ def pytest_collection_modifyitems(items):
 @pytest.fixture(scope="session")
 def mysql(request) -> MySqlContainer:
+    """Start a MySQL 8.0.40 container seeded with `orders` and `json_test` tables."""
     mysql = MySqlContainer(image="mysql:8.0.40", dialect="pymysql").start()
+    # Register teardown before seeding so a failed setup cannot leak the container.
+    request.addfinalizer(mysql.stop)
+
+    engine = sqlalchemy.create_engine(mysql.get_connection_url())
+    try:
+        # Load orders.parquet into an `orders` table.
+        pd.read_parquet(file_path("resource/tpch/data/orders.parquet")).to_sql(
+            "orders", engine, index=False
+        )
+        with engine.connect() as conn:
+            conn.execute(
+                sqlalchemy.text(
+                    """
+                    CREATE TABLE json_test (
+                        id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
+                        object_col JSON NOT NULL,
+                        array_col JSON NOT NULL,
+                        CHECK (JSON_TYPE(object_col) = 'OBJECT'),
+                        CHECK (JSON_TYPE(array_col) = 'ARRAY')
+                    ) ENGINE=InnoDB;
+                    """
+                )
+            )
+            conn.execute(
+                sqlalchemy.text(
+                    """
+                    INSERT INTO json_test (object_col, array_col) VALUES
+                    ('{"name": "Alice", "age": 30, "city": "New York"}', '["apple", "banana", "cherry"]'),
+                    ('{"name": "Bob", "age": 25, "city": "Los Angeles"}', '["dog", "cat", "mouse"]'),
+                    ('{"name": "Charlie", "age": 35, "city": "Chicago"}', '["red", "green", "blue"]');
+                    """
+                )
+            )
+            conn.commit()
+    finally:
+        # The engine is only used for seeding here; release its pooled connections.
+        engine.dispose()
+
-    request.addfinalizer(mysql.stop)
     return mysql
 
diff --git a/ibis-server/tests/routers/v3/connector/mysql/test_query.py b/ibis-server/tests/routers/v3/connector/mysql/test_query.py
new file mode 100644
index 000000000..9817e8787
--- /dev/null
+++ b/ibis-server/tests/routers/v3/connector/mysql/test_query.py
@@ -0,0 +1,64 @@
+"""Query tests for date-field extraction through the MySQL connector."""
+
+import base64
+
+import orjson
+import pytest
+
+from app.dependencies import X_WREN_FALLBACK_DISABLE
+from tests.routers.v3.connector.mysql.conftest import base_url
+
+manifest = {
+    "dataSource": "mysql",
+    "catalog": "my_catalog",
+    "schema": "my_schema",
+    "models": [
+        {
+            "name": "orders",
+            "tableReference": {
+                "table": "orders",
+            },
+            "columns": [
+                {"name": "o_orderkey", "type": "integer"},
+                {"name": "o_orderdate", "type": "date"},
+            ],
+        },
+    ],
+}
+
+
+@pytest.fixture(scope="module")
+async def manifest_str():
+    """Return the manifest as base64-encoded JSON for request payloads."""
+    return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")
+
+
+@pytest.mark.parametrize(
+    ("field", "expected"),
+    [
+        ("MONTH", 1),
+        ("WEEK", 0),
+    ],
+)
+async def test_extract(client, manifest_str, connection_info, field, expected):
+    """Check EXTRACT(<field> FROM ...) with the X_WREN_FALLBACK_DISABLE header set."""
+    response = await client.post(
+        url=f"{base_url}/query",
+        json={
+            "connectionInfo": connection_info,
+            "manifestStr": manifest_str,
+            # ORDER BY pins the row LIMIT 1 returns; it is unspecified otherwise.
+            # NOTE(review): assumes the lowest o_orderkey row yields these values - confirm.
+            "sql": (
+                f"SELECT EXTRACT({field} FROM o_orderdate) AS col "
+                "FROM orders ORDER BY o_orderkey LIMIT 1"
+            ),
+        },
+        headers={X_WREN_FALLBACK_DISABLE: "true"},
+    )
+    assert response.status_code == 200
+    assert response.json() == {
+        "columns": ["col"],
+        "data": [[expected]],
+        "dtypes": {"col": "int32"},
+    }
diff --git a/wren-core/core/src/mdl/dialect/inner_dialect.rs b/wren-core/core/src/mdl/dialect/inner_dialect.rs
index 9c77c4c21..44e16fbe9 100644
--- a/wren-core/core/src/mdl/dialect/inner_dialect.rs
+++ b/wren-core/core/src/mdl/dialect/inner_dialect.rs
@@ -36,6 +36,7 @@ use datafusion::sql::sqlparser::ast::{
 use datafusion::sql::unparser::ast::{
     RelationBuilder, TableFactorBuilder, TableFunctionRelationBuilder,
 };
+use datafusion::sql::unparser::dialect::DateFieldExtractStyle;
 use datafusion::sql::unparser::Unparser;
 use
regex::Regex;
 
@@ -95,6 +96,12 @@ pub trait InnerDialect: Send + Sync {
         false
     }
 
+    /// Style used to unparse `date_part` for this dialect.
+    /// `None` leaves the choice to the caller.
+    fn date_field_extract_style(&self) -> Option<DateFieldExtractStyle> {
+        None
+    }
+
     /// Define the supported UDFs for the dialect which will be registered in the execution context.
     fn supported_udfs(&self) -> Vec> {
         scalar_functions()
@@ -144,6 +151,10 @@ impl InnerDialect for MySQLDialect {
             _ => Ok(None),
         }
     }
+
+    fn date_field_extract_style(&self) -> Option<DateFieldExtractStyle> {
+        Some(DateFieldExtractStyle::Extract)
+    }
 }
 
 pub struct BigQueryDialect {}
diff --git a/wren-core/core/src/mdl/dialect/wren_dialect.rs b/wren-core/core/src/mdl/dialect/wren_dialect.rs
index 4f7e2ab68..605aff0de 100644
--- a/wren-core/core/src/mdl/dialect/wren_dialect.rs
+++ b/wren-core/core/src/mdl/dialect/wren_dialect.rs
@@ -21,8 +21,12 @@ use crate::mdl::manifest::DataSource;
 use datafusion::common::Result;
 use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS;
 use datafusion::logical_expr::Expr;
+use datafusion::scalar::ScalarValue;
 use datafusion::sql::sqlparser::ast::{self, WindowFrameBound};
-use datafusion::sql::unparser::dialect::{Dialect, IntervalStyle};
+use datafusion::sql::sqlparser::tokenizer::Span;
+use datafusion::sql::unparser::dialect::{
+    CharacterLengthStyle, DateFieldExtractStyle, Dialect, IntervalStyle,
+};
 use datafusion::sql::unparser::Unparser;
 use regex::Regex;
 
@@ -67,7 +71,15 @@ impl Dialect for WrenDialect {
             return Ok(Some(function));
         }
 
-        Ok(None)
+        match func_name {
+            "date_part" => {
+                date_part_to_sql(unparser, self.date_field_extract_style(), args)
+            }
+            "character_length" => {
+                character_length_to_sql(unparser, self.character_length_style(), args)
+            }
+            _ => Ok(None),
+        }
     }
 
     fn unnest_as_table_factor(&self) -> bool {
@@ -119,6 +131,13 @@ impl Dialect for WrenDialect {
         self.inner_dialect
             .relation_alias_overrides(_relation_builder, _alias)
     }
+
+    /// Uses the inner dialect's style when it provides one, otherwise `DatePart`.
+    fn date_field_extract_style(&self) -> DateFieldExtractStyle {
+        self.inner_dialect
+            .date_field_extract_style()
+            .unwrap_or(DateFieldExtractStyle::DatePart)
+    }
 }
 
 impl Default for WrenDialect {
@@ -139,3 +158,101 @@ fn non_lowercase(sql: &str) -> bool {
     let lowercase = sql.to_lowercase();
     lowercase != sql
 }
+
+/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
+pub(crate) fn date_part_to_sql(
+    unparser: &Unparser,
+    style: DateFieldExtractStyle,
+    date_part_args: &[Expr],
+) -> Result<Option<ast::Expr>> {
+    match (style, date_part_args.len()) {
+        (DateFieldExtractStyle::Extract, 2) => {
+            let date_expr = unparser.expr_to_sql(&date_part_args[1])?;
+            if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] {
+                let field = match field.to_lowercase().as_str() {
+                    "year" => ast::DateTimeField::Year,
+                    "month" => ast::DateTimeField::Month,
+                    "day" => ast::DateTimeField::Day,
+                    "hour" => ast::DateTimeField::Hour,
+                    "minute" => ast::DateTimeField::Minute,
+                    "second" => ast::DateTimeField::Second,
+                    "week" => ast::DateTimeField::Week(None),
+                    _ => return Ok(None),
+                };
+
+                return Ok(Some(ast::Expr::Extract {
+                    field,
+                    expr: Box::new(date_expr),
+                    syntax: ast::ExtractSyntax::From,
+                }));
+            }
+        }
+        (DateFieldExtractStyle::Strftime, 2) => {
+            let column = unparser.expr_to_sql(&date_part_args[1])?;
+
+            if let Expr::Literal(ScalarValue::Utf8(Some(field)), _) = &date_part_args[0] {
+                let field = match field.to_lowercase().as_str() {
+                    "year" => "%Y",
+                    "month" => "%m",
+                    "day" => "%d",
+                    "hour" => "%H",
+                    "minute" => "%M",
+                    "second" => "%S",
+                    "week" => "%U",
+                    _ => return Ok(None),
+                };
+
+                return Ok(Some(ast::Expr::Function(ast::Function {
+                    name: ast::ObjectName::from(vec![ast::Ident {
+                        value: "strftime".to_string(),
+                        quote_style: None,
+                        span: Span::empty(),
+                    }]),
+                    args: ast::FunctionArguments::List(ast::FunctionArgumentList {
+                        duplicate_treatment: None,
+                        args: vec![
+                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
+                                ast::Expr::value(ast::Value::SingleQuotedString(
+                                    field.to_string(),
+                                )),
+                            )),
+                            ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(column)),
+                        ],
+                        clauses: vec![],
+                    }),
+                    filter: None,
+                    null_treatment: None,
+                    over: None,
+                    within_group: vec![],
+                    parameters: ast::FunctionArguments::None,
+                    uses_odbc_syntax: false,
+                })));
+            }
+        }
+        (DateFieldExtractStyle::DatePart, _) => {
+            return Ok(Some(
+                unparser.scalar_function_to_sql("date_part", date_part_args)?,
+            ));
+        }
+        _ => {}
+    };
+
+    Ok(None)
+}
+
+/// Converts a character_length function to SQL using the function name the style selects.
+pub(crate) fn character_length_to_sql(
+    unparser: &Unparser,
+    style: CharacterLengthStyle,
+    character_length_args: &[Expr],
+) -> Result<Option<ast::Expr>> {
+    let func_name = match style {
+        CharacterLengthStyle::CharacterLength => "character_length",
+        CharacterLengthStyle::Length => "length",
+    };
+
+    Ok(Some(unparser.scalar_function_to_sql(
+        func_name,
+        character_length_args,
+    )?))
+}
diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs
index a6140949f..02ab8fb2b 100644
--- a/wren-core/core/src/mdl/mod.rs
+++ b/wren-core/core/src/mdl/mod.rs
@@ -3852,6 +3852,40 @@ mod test {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_extract_roundtrip_mysql() -> Result<()> {
+        let ctx = create_wren_ctx(None, Some(&DataSource::MySQL));
+        let manifest = ManifestBuilder::new()
+            .catalog("wren")
+            .schema("test")
+            .model(
+                ModelBuilder::new("orders")
+                    .table_reference("orders")
+                    .column(ColumnBuilder::new("o_orderdate", "date").build())
+                    .build(),
+            )
+            .data_source(DataSource::MySQL)
+            .build();
+        let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(
+            manifest,
+            Arc::new(HashMap::default()),
+            Mode::Unparse,
+        )?);
+        let headers = Arc::new(HashMap::default());
+        let sql = "SELECT EXTRACT(YEAR FROM o_orderdate) FROM orders";
+        assert_snapshot!(
+            transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?,
+            @"SELECT EXTRACT(YEAR FROM orders.o_orderdate) FROM (SELECT orders.o_orderdate FROM (SELECT __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders"
+        );
+
+        let sql = "SELECT EXTRACT(WEEK FROM o_orderdate) FROM orders";
+        assert_snapshot!(
+            transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?,
+            @"SELECT EXTRACT(WEEK FROM orders.o_orderdate) FROM (SELECT orders.o_orderdate FROM (SELECT __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders"
+        );
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_bigquery_json() -> Result<()> {
         let ctx = create_wren_ctx(None, Some(&DataSource::BigQuery));