Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ibis-server/tests/routers/v3/connector/postgres/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def postgres(request) -> PostgresContainer:
"INSERT INTO null_test (id, letter) VALUES (1, 'one'), (2, 'two'), (NULL, 'three')"
)
)
conn.execute(sqlalchemy.text("CREATE TABLE 中文表 (欄位1 int, 欄位2 int)"))
conn.execute(
sqlalchemy.text("INSERT INTO 中文表 (欄位1, 欄位2) VALUES (1, 2), (3, 4)")
)

request.addfinalizer(pg.stop)
return pg
Expand Down
38 changes: 38 additions & 0 deletions ibis-server/tests/routers/v3/connector/postgres/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,3 +1181,41 @@ async def test_cache_with_mixed_relevant_irrelevant_headers(
assert (
response3.headers["X-Cache-Hit"] == "false"
) # Should miss cache due to changed relevant header


async def test_query_unicode_table(client, connection_info):
    # Model exposing a table whose table name and column names are CJK
    # identifiers; the query must round-trip them without case-folding
    # or mangling (fallback is disabled so the engine path is exercised).
    columns = [
        {"name": "欄位1", "type": "int"},
        {"name": "欄位2", "type": "int"},
    ]
    model = {
        "name": "中文表",
        "tableReference": {"schema": "public", "table": "中文表"},
        "columns": columns,
    }
    manifest = {
        "catalog": "wrenai",
        "schema": "public",
        "models": [model],
        "dataSource": "postgres",
    }

    encoded_manifest = base64.b64encode(orjson.dumps(manifest)).decode("utf-8")

    response = await client.post(
        url=f"{base_url}/query?cacheEnable=true",
        headers={X_WREN_FALLBACK_DISABLE: "true"},
        json={
            "connectionInfo": connection_info,
            "manifestStr": encoded_manifest,
            "sql": "SELECT 欄位1, 欄位2 FROM 中文表 LIMIT 1",
        },
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["data"][0] == [1, 2]
    assert payload["dtypes"] == {"欄位1": "int32", "欄位2": "int32"}
4 changes: 4 additions & 0 deletions wren-core/core/src/mdl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ pub async fn create_ctx_with_mdl(
"datafusion.sql_parser.default_null_ordering",
&ScalarValue::Utf8(Some("nulls_last".to_string())),
)
.set(
"datafusion.sql_parser.enable_ident_normalization",
&ScalarValue::Utf8(Some("false".to_string())),
)
.with_create_default_catalog_and_schema(false)
.with_default_catalog_and_schema(
analyzed_mdl.wren_mdl.catalog(),
Expand Down
93 changes: 77 additions & 16 deletions wren-core/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,6 @@ pub fn transform_sql(
}

/// Transform the SQL based on the MDL with the SessionContext
/// Wren engine will normalize the SQL to the lower case to solve the case-sensitive
/// issue for the Wren view
pub async fn transform_sql_with_ctx(
ctx: &SessionContext,
analyzed_mdl: Arc<AnalyzedWrenMDL>,
Expand Down Expand Up @@ -736,7 +734,7 @@ mod test {
Arc::new(HashMap::default()),
Mode::Unparse,
)?);
let sql = r#"select * from "CTest"."STest"."Customer""#;
let sql = r#"select * from CTest.STest.Customer"#;
let actual = mdl::transform_sql_with_ctx(
&SessionContext::new(),
Arc::clone(&analyzed_mdl),
Expand Down Expand Up @@ -786,7 +784,7 @@ mod test {
Arc::clone(&analyzed_mdl),
&functions,
Arc::new(HashMap::new()),
r#"select add_two("Custkey") from "Customer""#,
r#"select add_two(Custkey) from Customer"#,
)
.await?;
assert_snapshot!(actual, @"SELECT add_two(\"Customer\".\"Custkey\") FROM (SELECT \"Customer\".\"Custkey\" \
Expand Down Expand Up @@ -829,22 +827,22 @@ mod test {
.column(ColumnBuilder::new("名字", "string").build())
.column(
ColumnBuilder::new("name_append", "string")
.expression(r#""名字" || "名字""#)
.expression(r#"名字 || 名字"#)
.build(),
)
.column(
ColumnBuilder::new("group", "string")
.expression(r#""組別""#)
.expression(r#"組別"#)
.build(),
)
.column(
ColumnBuilder::new("subscribe", "int")
.expression(r#""訂閱數""#)
.expression(r#"訂閱數"#)
.build(),
)
.column(
ColumnBuilder::new("subscribe_plus", "int")
.expression(r#""訂閱數" + 1"#)
.expression(r#"訂閱數 + 1"#)
.build(),
)
.build(),
Expand Down Expand Up @@ -910,12 +908,12 @@ mod test {
.table_reference("artist")
.column(
ColumnBuilder::new("name_append", "string")
.expression(r#""名字" || "名字""#)
.expression(r#"名字 || 名字"#)
.build(),
)
.column(
ColumnBuilder::new("lower_name", "string")
.expression(r#"lower("名字")"#)
.expression(r#"lower(名字)"#)
.build(),
)
.build(),
Expand Down Expand Up @@ -974,7 +972,7 @@ mod test {
.column(ColumnBuilder::new("名字", "string").hidden(true).build())
.column(
ColumnBuilder::new("串接名字", "string")
.expression(r#""名字" || "名字""#)
.expression(r#"名字 || 名字"#)
.build(),
)
.build(),
Expand All @@ -986,7 +984,7 @@ mod test {
Arc::new(HashMap::default()),
Mode::Unparse,
)?);
let sql = r#"select "串接名字" from wren.test.artist"#;
let sql = r#"select 串接名字 from wren.test.artist"#;
let actual = transform_sql_with_ctx(
&SessionContext::new(),
Arc::clone(&analyzed_mdl),
Expand Down Expand Up @@ -1059,7 +1057,7 @@ mod test {
Arc::new(HashMap::default()),
Mode::Unparse,
)?);
let sql = r#"select * from wren.test.artist where "名字" in (SELECT "名字" FROM wren.test.artist)"#;
let sql = r#"select * from wren.test.artist where 名字 in (SELECT 名字 FROM wren.test.artist)"#;
let actual = transform_sql_with_ctx(
&SessionContext::new(),
Arc::clone(&analyzed_mdl),
Expand Down Expand Up @@ -1360,7 +1358,7 @@ mod test {
)
.column(
ColumnBuilder::new("cast_timestamptz", "timestamptz")
.expression(r#"cast("出道時間" as timestamp with time zone)"#)
.expression(r#"cast(出道時間 as timestamp with time zone)"#)
.build(),
)
.build(),
Expand Down Expand Up @@ -2622,7 +2620,7 @@ mod test {
.add_row_level_access_control(
"rule",
vec![SessionProperty::new_required("預定組別A")],
"\"組別\" = @預定組別A",
"組別 = @預定組別A",
)
.build(),
)
Expand All @@ -2638,7 +2636,7 @@ mod test {
Mode::Unparse,
)?);

let sql = r#"SELECT "名字", "組別", "訂閱數" FROM "VTU藝人""#;
let sql = r#"SELECT 名字, 組別, 訂閱數 FROM VTU藝人"#;
assert_snapshot!(
transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql)
.await?,
Expand Down Expand Up @@ -3464,6 +3462,69 @@ mod test {
Ok(())
}

#[tokio::test]
async fn test_ambiguous_table_name() -> Result<()> {
    // Two models ("customer" / "Customer") point at the same physical table,
    // and each declares columns differing only by case ("c_name" / "C_name").
    // With identifier normalization disabled, resolution must be
    // case-sensitive: each spelling binds to its own model, and an
    // unknown casing ("CUSTOMER") must fail planning.
    let ctx = SessionContext::new();

    // Both models share the same shape; only the model name differs.
    let case_variant_model = |name: &str| {
        ModelBuilder::new(name)
            .table_reference("customer")
            .column(ColumnBuilder::new("c_name", "int").build())
            .column(ColumnBuilder::new("C_name", "string").build())
            .build()
    };

    let manifest = ManifestBuilder::new()
        .catalog("wren")
        .schema("test")
        .model(case_variant_model("customer"))
        .model(case_variant_model("Customer"))
        .build();

    let headers = Arc::new(HashMap::default());
    let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(
        manifest,
        Arc::clone(&headers),
        Mode::Unparse,
    )?);

    // Lower-case table name resolves to the "customer" model.
    let lower_case = transform_sql_with_ctx(
        &ctx,
        Arc::clone(&analyzed_mdl),
        &[],
        Arc::clone(&headers),
        "select c_name, C_name from customer",
    )
    .await?;
    assert_snapshot!(
        lower_case,
        @r#"SELECT customer.c_name, customer."C_name" FROM (SELECT customer."C_name", customer.c_name FROM (SELECT __source."C_name" AS "C_name", __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer"#
    );

    // Capitalized table name resolves to the distinct "Customer" model.
    let capitalized = transform_sql_with_ctx(
        &ctx,
        Arc::clone(&analyzed_mdl),
        &[],
        Arc::clone(&headers),
        "select c_name, C_name from Customer",
    )
    .await?;
    assert_snapshot!(
        capitalized,
        @r#"SELECT "Customer".c_name, "Customer"."C_name" FROM (SELECT "Customer"."C_name", "Customer".c_name FROM (SELECT __source."C_name" AS "C_name", __source.c_name AS c_name FROM customer AS __source) AS "Customer") AS "Customer""#
    );

    // A casing that matches no model is rejected at planning time.
    let failure = transform_sql_with_ctx(
        &ctx,
        Arc::clone(&analyzed_mdl),
        &[],
        Arc::clone(&headers),
        "select * from CUSTOMER",
    )
    .await;
    let Err(e) = failure else {
        panic!("Expected error, but got SQL");
    };
    assert_snapshot!(
        e.to_string(),
        @"Error during planning: table 'wren.test.CUSTOMER' not found"
    );

    Ok(())
}

/// Return a RecordBatch with made up data about customer
fn customer() -> RecordBatch {
let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3]));
Expand Down
Loading