Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions ibis-server/tests/routers/v3/connector/mysql/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pathlib

import pytest
import sqlalchemy
import pandas as pd
from testcontainers.mysql import MySqlContainer
from tests.conftest import file_path

pytestmark = pytest.mark.mysql

Expand All @@ -18,6 +21,37 @@ def pytest_collection_modifyitems(items):
@pytest.fixture(scope="session")
def mysql(request) -> MySqlContainer:
    """Start a session-scoped MySQL 8.0.40 container seeded with test data.

    Seeding performed before the container is handed to tests:

    * the TPC-H ``orders`` parquet resource is loaded into an ``orders`` table;
    * a ``json_test`` table is created and filled with three rows holding one
      JSON object column and one JSON array column.

    Args:
        request: pytest's fixture request, used to register the teardown.

    Returns:
        The running ``MySqlContainer``.
    """
    mysql = MySqlContainer(image="mysql:8.0.40", dialect="pymysql").start()
    # Register the teardown immediately: if any seeding step below raises, the
    # container is still stopped instead of being left running.
    request.addfinalizer(mysql.stop)

    connection_url = mysql.get_connection_url()
    engine = sqlalchemy.create_engine(connection_url)
    try:
        pd.read_parquet(file_path("resource/tpch/data/orders.parquet")).to_sql(
            "orders", engine, index=False
        )
        with engine.connect() as conn:
            conn.execute(
                sqlalchemy.text(
                    """
                    CREATE TABLE json_test (
                    id BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY,
                    object_col JSON NOT NULL,
                    array_col JSON NOT NULL,
                    CHECK (JSON_TYPE(object_col) = 'OBJECT'),
                    CHECK (JSON_TYPE(array_col) = 'ARRAY')
                    ) ENGINE=InnoDB;
                    """
                )
            )
            conn.execute(
                sqlalchemy.text(
                    """
                    INSERT INTO json_test (object_col, array_col) VALUES
                    ('{"name": "Alice", "age": 30, "city": "New York"}', '["apple", "banana", "cherry"]'),
                    ('{"name": "Bob", "age": 25, "city": "Los Angeles"}', '["dog", "cat", "mouse"]'),
                    ('{"name": "Charlie", "age": 35, "city": "Chicago"}', '["red", "green", "blue"]');
                    """
                )
            )
            conn.commit()
    finally:
        # The engine is only used for seeding and is not returned, so release
        # its pooled connections here rather than keeping them for the session.
        engine.dispose()

    return mysql

Expand Down
67 changes: 67 additions & 0 deletions ibis-server/tests/routers/v3/connector/mysql/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import base64

import orjson
import pytest


from app.dependencies import X_WREN_FALLBACK_DISABLE
from tests.routers.v3.connector.mysql.conftest import base_url

# Manifest shared by the tests in this module: a MySQL data source with a
# single `orders` model mapped onto the `orders` table, exposing the
# `o_orderkey` (integer) and `o_orderdate` (date) columns.
manifest = {
"dataSource": "mysql",
"catalog": "my_catalog",
"schema": "my_schema",
"models": [
{
"name": "orders",
"tableReference": {
"table": "orders",
},
"columns": [
{"name": "o_orderkey", "type": "integer"},
{"name": "o_orderdate", "type": "date"},
],
},
],
}


@pytest.fixture(scope="module")
async def manifest_str():
    """Return the module-level ``manifest`` as base64-encoded JSON text."""
    manifest_json = orjson.dumps(manifest)
    encoded = base64.b64encode(manifest_json)
    return encoded.decode("utf-8")


async def test_extract(client, manifest_str, connection_info):
    """Query ``EXTRACT(<field> FROM o_orderdate)`` for MONTH and then WEEK.

    Each request is sent with the ``X_WREN_FALLBACK_DISABLE`` header set to
    ``"true"`` and must answer 200 with a single ``int32`` column named ``col``.
    """
    # (date field, value expected in the single returned row)
    cases = [("MONTH", 1), ("WEEK", 0)]

    for field, expected in cases:
        payload = {
            "connectionInfo": connection_info,
            "manifestStr": manifest_str,
            "sql": f"SELECT EXTRACT({field} FROM o_orderdate) AS col FROM orders LIMIT 1",
        }
        response = await client.post(
            url=f"{base_url}/query",
            json=payload,
            headers={X_WREN_FALLBACK_DISABLE: "true"},
        )
        assert response.status_code == 200
        result = response.json()
        assert result == {
            "columns": ["col"],
            "data": [[expected]],
            "dtypes": {"col": "int32"},
        }
9 changes: 9 additions & 0 deletions wren-core/core/src/mdl/dialect/inner_dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use datafusion::sql::sqlparser::ast::{
use datafusion::sql::unparser::ast::{
RelationBuilder, TableFactorBuilder, TableFunctionRelationBuilder,
};
use datafusion::sql::unparser::dialect::DateFieldExtractStyle;
use datafusion::sql::unparser::Unparser;
use regex::Regex;

Expand Down Expand Up @@ -95,6 +96,10 @@ pub trait InnerDialect: Send + Sync {
false
}

/// Date-field extraction style this dialect wants the unparser to use.
///
/// The default implementation returns `None`, i.e. the dialect states no preference.
fn date_field_extract_style(&self) -> Option<DateFieldExtractStyle> {
None
}

/// Define the supported UDFs for the dialect which will be registered in the execution context.
fn supported_udfs(&self) -> Vec<Arc<datafusion::logical_expr::ScalarUDF>> {
scalar_functions()
Expand Down Expand Up @@ -144,6 +149,10 @@ impl InnerDialect for MySQLDialect {
_ => Ok(None),
}
}

/// MySQL selects the `Extract` style for date-field extraction.
fn date_field_extract_style(&self) -> Option<DateFieldExtractStyle> {
Some(DateFieldExtractStyle::Extract)
}
}

pub struct BigQueryDialect {}
Expand Down
121 changes: 119 additions & 2 deletions wren-core/core/src/mdl/dialect/wren_dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@ use crate::mdl::manifest::DataSource;
use datafusion::common::Result;
use datafusion::logical_expr::sqlparser::keywords::ALL_KEYWORDS;
use datafusion::logical_expr::Expr;
use datafusion::scalar::ScalarValue;
use datafusion::sql::sqlparser::ast::{self, WindowFrameBound};
use datafusion::sql::unparser::dialect::{Dialect, IntervalStyle};
use datafusion::sql::sqlparser::tokenizer::Span;
use datafusion::sql::unparser::dialect::{
CharacterLengthStyle, DateFieldExtractStyle, Dialect, IntervalStyle,
};
use datafusion::sql::unparser::Unparser;
use regex::Regex;

Expand Down Expand Up @@ -67,7 +71,15 @@ impl Dialect for WrenDialect {
return Ok(Some(function));
}

Ok(None)
match func_name {
"date_part" => {
date_part_to_sql(unparser, self.date_field_extract_style(), args)
}
"character_length" => {
character_length_to_sql(unparser, self.character_length_style(), args)
}
_ => Ok(None),
}
}

fn unnest_as_table_factor(&self) -> bool {
Expand Down Expand Up @@ -119,6 +131,14 @@ impl Dialect for WrenDialect {
self.inner_dialect
.relation_alias_overrides(_relation_builder, _alias)
}

/// Uses the style requested by the inner dialect, defaulting to
/// `DateFieldExtractStyle::DatePart` when the inner dialect returns `None`.
fn date_field_extract_style(&self) -> DateFieldExtractStyle {
    self.inner_dialect
        .date_field_extract_style()
        .unwrap_or(DateFieldExtractStyle::DatePart)
}
}

impl Default for WrenDialect {
Expand All @@ -139,3 +159,100 @@ fn non_lowercase(sql: &str) -> bool {
let lowercase = sql.to_lowercase();
lowercase != sql
}

/// Renders a `date_part` call in the form selected by `style`.
///
/// * `DatePart`: the call is delegated to `Unparser::scalar_function_to_sql` under the
///   name `date_part`, whatever the number of arguments.
/// * `Extract`: a two-argument call becomes `EXTRACT(<field> FROM <source>)`.
/// * `Strftime`: a two-argument call becomes `strftime('<format>', <source>)`.
///
/// For `Extract` and `Strftime` the first argument must be a UTF-8 string literal naming
/// one of `year`, `month`, `day`, `hour`, `minute`, `second` or `week` (compared
/// case-insensitively). Otherwise, or when the call does not have exactly two
/// arguments, `Ok(None)` is returned.
///
/// # Errors
///
/// Propagates any error raised while unparsing the arguments.
pub(crate) fn date_part_to_sql(
    unparser: &Unparser,
    style: DateFieldExtractStyle,
    date_part_args: &[Expr],
) -> Result<Option<ast::Expr>> {
    match (style, date_part_args) {
        (DateFieldExtractStyle::DatePart, _) => unparser
            .scalar_function_to_sql("date_part", date_part_args)
            .map(Some),
        (DateFieldExtractStyle::Extract, [part, source]) => {
            // The source expression is unparsed before the field is inspected, so an
            // unparse failure is reported even if the field turns out to be unusable.
            let source_sql = unparser.expr_to_sql(source)?;
            let Expr::Literal(ScalarValue::Utf8(Some(part_name)), _) = part else {
                return Ok(None);
            };
            let field = match part_name.to_lowercase().as_str() {
                "year" => ast::DateTimeField::Year,
                "month" => ast::DateTimeField::Month,
                "day" => ast::DateTimeField::Day,
                "hour" => ast::DateTimeField::Hour,
                "minute" => ast::DateTimeField::Minute,
                "second" => ast::DateTimeField::Second,
                "week" => ast::DateTimeField::Week(None),
                _ => return Ok(None),
            };

            Ok(Some(ast::Expr::Extract {
                field,
                expr: Box::new(source_sql),
                syntax: ast::ExtractSyntax::From,
            }))
        }
        (DateFieldExtractStyle::Strftime, [part, source]) => {
            let source_sql = unparser.expr_to_sql(source)?;
            let Expr::Literal(ScalarValue::Utf8(Some(part_name)), _) = part else {
                return Ok(None);
            };
            let format = match part_name.to_lowercase().as_str() {
                "year" => "%Y",
                "month" => "%m",
                "day" => "%d",
                "hour" => "%H",
                "minute" => "%M",
                "second" => "%S",
                "week" => "%U",
                _ => return Ok(None),
            };

            let function_name = ast::ObjectName::from(vec![ast::Ident {
                value: "strftime".to_string(),
                quote_style: None,
                span: Span::empty(),
            }]);
            // strftime takes the format string first, then the value to format.
            let call_args = vec![
                ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(ast::Expr::value(
                    ast::Value::SingleQuotedString(format.to_string()),
                ))),
                ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(source_sql)),
            ];

            Ok(Some(ast::Expr::Function(ast::Function {
                name: function_name,
                args: ast::FunctionArguments::List(ast::FunctionArgumentList {
                    duplicate_treatment: None,
                    args: call_args,
                    clauses: vec![],
                }),
                filter: None,
                null_treatment: None,
                over: None,
                within_group: vec![],
                parameters: ast::FunctionArguments::None,
                uses_odbc_syntax: false,
            })))
        }
        _ => Ok(None),
    }
}

/// Renders a `character_length` call under the function name selected by `style`:
/// `character_length` for `CharacterLengthStyle::CharacterLength`, `length` for
/// `CharacterLengthStyle::Length`. The arguments are passed through as given.
///
/// # Errors
///
/// Propagates any error raised by `Unparser::scalar_function_to_sql`.
pub(crate) fn character_length_to_sql(
    unparser: &Unparser,
    style: CharacterLengthStyle,
    character_length_args: &[Expr],
) -> Result<Option<ast::Expr>> {
    let target_name = match style {
        CharacterLengthStyle::Length => "length",
        CharacterLengthStyle::CharacterLength => "character_length",
    };
    let call = unparser.scalar_function_to_sql(target_name, character_length_args)?;

    Ok(Some(call))
}
34 changes: 34 additions & 0 deletions wren-core/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3852,6 +3852,40 @@ mod test {
Ok(())
}

#[tokio::test]
async fn test_extract_roundtrip_mysql() -> Result<()> {
let ctx = create_wren_ctx(None, Some(&DataSource::MySQL));
let manifest = ManifestBuilder::new()
.catalog("wren")
.schema("test")
.model(
ModelBuilder::new("orders")
.table_reference("orders")
.column(ColumnBuilder::new("o_orderdate", "date").build())
.build(),
)
.data_source(DataSource::MySQL)
.build();
let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(
manifest,
Arc::new(HashMap::default()),
Mode::Unparse,
)?);
let headers = Arc::new(HashMap::default());
let sql = "SELECT EXTRACT(YEAR FROM o_orderdate) FROM orders";
assert_snapshot!(
transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?,
@"SELECT EXTRACT(YEAR FROM orders.o_orderdate) FROM (SELECT orders.o_orderdate FROM (SELECT __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders"
);

let sql = "SELECT EXTRACT(WEEK FROM o_orderdate) FROM orders";
assert_snapshot!(
transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::clone(&headers), sql).await?,
@"SELECT EXTRACT(WEEK FROM orders.o_orderdate) FROM (SELECT orders.o_orderdate FROM (SELECT __source.o_orderdate AS o_orderdate FROM orders AS __source) AS orders) AS orders"
);
Ok(())
}

#[tokio::test]
async fn test_bigquery_json() -> Result<()> {
let ctx = create_wren_ctx(None, Some(&DataSource::BigQuery));
Expand Down