From 09ae1ea442ba5a6362ae5dfe0c923428030ea263 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 18 Apr 2025 13:48:10 +0800 Subject: [PATCH 01/30] add row level access control --- wren-core-base/manifest-macro/src/lib.rs | 57 +++++++++++++++++ wren-core-base/src/mdl/builder.rs | 81 +++++++++++++++++++++++- wren-core-base/src/mdl/manifest.rs | 16 +++-- wren-core-base/tests/data/mdl.json | 33 ++++++++++ 4 files changed, 181 insertions(+), 6 deletions(-) diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index 65f145879..7274fcc62 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -140,6 +140,8 @@ pub fn model(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream pub cached: bool, #[serde(default)] pub refresh_time: Option, + #[serde(default)] + pub row_level_access_controls: Vec, } }; proc_macro::TokenStream::from(expanded) @@ -163,6 +165,7 @@ pub fn column(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStrea #[serde_as] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] #[serde(rename_all = "camelCase")] + #[allow(deprecated)] pub struct Column { pub name: String, pub r#type: String, @@ -177,6 +180,7 @@ pub fn column(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStrea pub expression: Option, #[serde(default, with = "bool_from_int")] pub is_hidden: bool, + #[deprecated] pub rls: Option, pub cls: Option, } @@ -354,6 +358,55 @@ pub fn view(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream } #[proc_macro] +pub fn row_level_access_control(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(python_binding as LitBool); + let python_binding = if input.value { + quote! { + #[pyclass] + } + } else { + quote! {} + }; + let expanded = quote! 
{ + #python_binding + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] + #[serde(rename_all = "camelCase")] + pub struct RowLevelAccessControl { + pub name: String, + #[serde(default)] + pub required_variables: Vec, + /// A string expression that can be evaluated to a boolean value + pub condition: String, + } + }; + proc_macro::TokenStream::from(expanded) +} + +#[proc_macro] +pub fn session_variable(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(python_binding as LitBool); + let python_binding = if input.value { + quote! { + #[pyclass] + } + } else { + quote! {} + }; + let expanded = quote! { + #python_binding + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] + #[serde(rename_all = "camelCase")] + pub struct SessionVariable { + pub name: String, + pub required: bool, + pub default_expr: Option, + } + }; + proc_macro::TokenStream::from(expanded) +} + +#[proc_macro] +#[deprecated] pub fn row_level_security(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(python_binding as LitBool); let python_binding = if input.value { @@ -366,8 +419,10 @@ pub fn row_level_security(python_binding: proc_macro::TokenStream) -> proc_macro let expanded = quote! { #python_binding #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] + #[deprecated] pub struct RowLevelSecurity { pub name: String, + #[allow(deprecated)] pub operator: RowLevelOperator, } }; @@ -375,6 +430,7 @@ pub fn row_level_security(python_binding: proc_macro::TokenStream) -> proc_macro } #[proc_macro] +#[deprecated] pub fn row_level_operator(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(python_binding as LitBool); let python_binding = if input.value { @@ -386,6 +442,7 @@ pub fn row_level_operator(python_binding: proc_macro::TokenStream) -> proc_macro }; let expanded = quote! 
{ #python_binding + #[deprecated] #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum RowLevelOperator { diff --git a/wren-core-base/src/mdl/builder.rs b/wren-core-base/src/mdl/builder.rs index 8099af2e0..c5231a4f9 100644 --- a/wren-core-base/src/mdl/builder.rs +++ b/wren-core-base/src/mdl/builder.rs @@ -22,8 +22,10 @@ use crate::mdl::manifest::{ Column, DataSource, JoinType, Manifest, Metric, Model, Relationship, TimeGrain, TimeUnit, View, }; +#[allow(deprecated)] use crate::mdl::{ - ColumnLevelOperator, ColumnLevelSecurity, NormalizedExpr, RowLevelOperator, RowLevelSecurity, + ColumnLevelOperator, ColumnLevelSecurity, NormalizedExpr, RowLevelAccessControl, + RowLevelOperator, RowLevelSecurity, SessionVariable, }; use std::sync::Arc; @@ -109,6 +111,7 @@ impl ModelBuilder { primary_key: None, cached: false, refresh_time: None, + row_level_access_controls: vec![], }, } } @@ -148,15 +151,47 @@ impl ModelBuilder { self } + pub fn add_row_level_access_control( + mut self, + name: &str, + required_variables: Vec, + condition: &str, + ) -> Self { + let rule = RowLevelAccessControl { + name: name.to_string(), + required_variables, + condition: condition.to_string(), + }; + self.model.row_level_access_controls.push(rule); + self + } + pub fn build(self) -> Arc { Arc::new(self.model) } } +impl SessionVariable { + pub fn new_required(name: &str) -> Self { + SessionVariable { + name: name.to_string(), + required: true, + default_expr: None, + } + } + pub fn new_optional(name: &str, default_expr: Option) -> Self { + SessionVariable { + name: name.to_string(), + required: false, + default_expr, + } + } +} pub struct ColumnBuilder { pub column: Column, } +#[allow(deprecated)] impl ColumnBuilder { pub fn new(name: &str, r#type: &str) -> Self { Self { @@ -207,6 +242,7 @@ impl ColumnBuilder { self } + #[allow(deprecated)] pub fn row_level_security(mut self, name: &str, operator: RowLevelOperator) -> Self 
{ self.column.rls = Some(RowLevelSecurity { name: name.to_string(), @@ -382,12 +418,16 @@ mod test { use crate::mdl::manifest::{ Column, DataSource, JoinType, Manifest, Metric, Model, Relationship, TimeUnit, View, }; - use crate::mdl::{ColumnLevelOperator, RowLevelOperator}; + use crate::mdl::ColumnLevelOperator; + #[allow(deprecated)] + use crate::mdl::RowLevelOperator; + use crate::mdl::SessionVariable; use std::fs; use std::path::PathBuf; use std::sync::Arc; #[test] + #[allow(deprecated)] fn test_column_roundtrip() { let expected = ColumnBuilder::new("id", "integer") .relationship("test") @@ -452,6 +492,24 @@ mod test { .primary_key("id") .cached(true) .refresh_time("1h") + .add_row_level_access_control( + "rule1", + vec![SessionVariable::new_required("session_id")], + "id = @session_id", + ) + .add_row_level_access_control( + "rule2", + vec![SessionVariable::new_optional("session_id_optional", None)], + "id = @session_id_optional", + ) + .add_row_level_access_control( + "rule3", + vec![SessionVariable::new_optional( + "session_id_default", + Some("1".to_string()), + )], + "id = @session_id_default", + ) .build(); let json_str = serde_json::to_string(&model).unwrap(); @@ -621,6 +679,7 @@ mod test { } #[test] + #[allow(deprecated)] fn test_json_serde() { let test_data: PathBuf = [env!("CARGO_MANIFEST_DIR"), "tests", "data", "mdl.json"] .iter() @@ -647,6 +706,24 @@ mod test { .relationship("CustomerOrders") .build(), ) + .add_row_level_access_control( + "rule1", + vec![SessionVariable::new_required("session_id")], + "c_custkey = @session_id", + ) + .add_row_level_access_control( + "rule2", + vec![SessionVariable::new_optional("session_id_optional", None)], + "c_custkey = @session_id_optional", + ) + .add_row_level_access_control( + "rule3", + vec![SessionVariable::new_optional( + "session_id_default", + Some("1".to_string()), + )], + "c_custkey = @session_id_default", + ) .primary_key("c_custkey") .build(), ) diff --git a/wren-core-base/src/mdl/manifest.rs 
b/wren-core-base/src/mdl/manifest.rs index 2d4896ca4..b5ddb460a 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -20,13 +20,15 @@ use std::fmt::Display; use std::sync::Arc; #[cfg(not(feature = "python-binding"))] +#[allow(deprecated)] mod manifest_impl { use crate::mdl::manifest::bool_from_int; use crate::mdl::manifest::table_reference; use manifest_macro::{ column, column_level_operator, column_level_security, data_source, join_type, manifest, - metric, model, normalized_expr, normalized_expr_type, relationship, row_level_operator, - row_level_security, time_grain, time_unit, view, + metric, model, normalized_expr, normalized_expr_type, relationship, + row_level_access_control, row_level_operator, row_level_security, session_variable, + time_grain, time_unit, view, }; use serde::{Deserialize, Serialize}; use serde_with::serde_as; @@ -44,6 +46,8 @@ mod manifest_impl { join_type!(false); time_grain!(false); time_unit!(false); + row_level_access_control!(false); + session_variable!(false); row_level_security!(false); row_level_operator!(false); column_level_security!(false); @@ -53,13 +57,15 @@ mod manifest_impl { } #[cfg(feature = "python-binding")] +#[allow(deprecated)] mod manifest_impl { use crate::mdl::manifest::bool_from_int; use crate::mdl::manifest::table_reference; use manifest_macro::{ column, column_level_operator, column_level_security, data_source, join_type, manifest, - metric, model, normalized_expr, normalized_expr_type, relationship, row_level_operator, - row_level_security, time_grain, time_unit, view, + metric, model, normalized_expr, normalized_expr_type, relationship, + row_level_access_control, row_level_operator, row_level_security, session_variable, + time_grain, time_unit, view, }; use pyo3::pyclass; use serde::{Deserialize, Serialize}; @@ -79,6 +85,8 @@ mod manifest_impl { time_grain!(true); time_unit!(true); manifest!(true); + row_level_access_control!(true); + session_variable!(true); 
row_level_security!(true); row_level_operator!(true); column_level_security!(true); diff --git a/wren-core-base/tests/data/mdl.json b/wren-core-base/tests/data/mdl.json index e588b4045..65e76bc27 100644 --- a/wren-core-base/tests/data/mdl.json +++ b/wren-core-base/tests/data/mdl.json @@ -34,6 +34,39 @@ } } ], + "rowLevelAccessControls": [ + { + "name": "rule1", + "requiredVariables": [ + { + "name": "session_id", + "required": true + } + ], + "condition": "c_custkey = @session_id" + }, + { + "name": "rule2", + "requiredVariables": [ + { + "name": "session_id_optional", + "required": false + } + ], + "condition": "c_custkey = @session_id_optional" + }, + { + "name": "rule3", + "requiredVariables": [ + { + "name": "session_id_default", + "required": false, + "defaultExpr": "1" + } + ], + "condition": "c_custkey = @session_id_default" + } + ], "primaryKey": "c_custkey", "properties": { "description": "This is a customer table", From 5c7e68040a99b62e3deed72b443a08b0cbff72bc Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 18 Apr 2025 13:56:29 +0800 Subject: [PATCH 02/30] rename session variable to session properties --- wren-core-base/manifest-macro/src/lib.rs | 6 +++--- wren-core-base/src/mdl/builder.rs | 26 ++++++++++++------------ wren-core-base/src/mdl/manifest.rs | 8 ++++---- wren-core-base/tests/data/mdl.json | 6 +++--- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index 7274fcc62..24be3c3b1 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -374,7 +374,7 @@ pub fn row_level_access_control(python_binding: proc_macro::TokenStream) -> proc pub struct RowLevelAccessControl { pub name: String, #[serde(default)] - pub required_variables: Vec, + pub required_properties: Vec, /// A string expression that can be evaluated to a boolean value pub condition: String, } @@ -383,7 +383,7 @@ pub fn 
row_level_access_control(python_binding: proc_macro::TokenStream) -> proc } #[proc_macro] -pub fn session_variable(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { +pub fn session_property(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(python_binding as LitBool); let python_binding = if input.value { quote! { @@ -396,7 +396,7 @@ pub fn session_variable(python_binding: proc_macro::TokenStream) -> proc_macro:: #python_binding #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] #[serde(rename_all = "camelCase")] - pub struct SessionVariable { + pub struct SessionProperty { pub name: String, pub required: bool, pub default_expr: Option, diff --git a/wren-core-base/src/mdl/builder.rs b/wren-core-base/src/mdl/builder.rs index c5231a4f9..da8261193 100644 --- a/wren-core-base/src/mdl/builder.rs +++ b/wren-core-base/src/mdl/builder.rs @@ -25,7 +25,7 @@ use crate::mdl::manifest::{ #[allow(deprecated)] use crate::mdl::{ ColumnLevelOperator, ColumnLevelSecurity, NormalizedExpr, RowLevelAccessControl, - RowLevelOperator, RowLevelSecurity, SessionVariable, + RowLevelOperator, RowLevelSecurity, SessionProperty, }; use std::sync::Arc; @@ -154,12 +154,12 @@ impl ModelBuilder { pub fn add_row_level_access_control( mut self, name: &str, - required_variables: Vec, + required_properties: Vec, condition: &str, ) -> Self { let rule = RowLevelAccessControl { name: name.to_string(), - required_variables, + required_properties, condition: condition.to_string(), }; self.model.row_level_access_controls.push(rule); @@ -171,16 +171,16 @@ impl ModelBuilder { } } -impl SessionVariable { +impl SessionProperty { pub fn new_required(name: &str) -> Self { - SessionVariable { + SessionProperty { name: name.to_string(), required: true, default_expr: None, } } pub fn new_optional(name: &str, default_expr: Option) -> Self { - SessionVariable { + SessionProperty { name: name.to_string(), required: false, 
default_expr, @@ -421,7 +421,7 @@ mod test { use crate::mdl::ColumnLevelOperator; #[allow(deprecated)] use crate::mdl::RowLevelOperator; - use crate::mdl::SessionVariable; + use crate::mdl::SessionProperty; use std::fs; use std::path::PathBuf; use std::sync::Arc; @@ -494,17 +494,17 @@ mod test { .refresh_time("1h") .add_row_level_access_control( "rule1", - vec![SessionVariable::new_required("session_id")], + vec![SessionProperty::new_required("session_id")], "id = @session_id", ) .add_row_level_access_control( "rule2", - vec![SessionVariable::new_optional("session_id_optional", None)], + vec![SessionProperty::new_optional("session_id_optional", None)], "id = @session_id_optional", ) .add_row_level_access_control( "rule3", - vec![SessionVariable::new_optional( + vec![SessionProperty::new_optional( "session_id_default", Some("1".to_string()), )], @@ -708,17 +708,17 @@ mod test { ) .add_row_level_access_control( "rule1", - vec![SessionVariable::new_required("session_id")], + vec![SessionProperty::new_required("session_id")], "c_custkey = @session_id", ) .add_row_level_access_control( "rule2", - vec![SessionVariable::new_optional("session_id_optional", None)], + vec![SessionProperty::new_optional("session_id_optional", None)], "c_custkey = @session_id_optional", ) .add_row_level_access_control( "rule3", - vec![SessionVariable::new_optional( + vec![SessionProperty::new_optional( "session_id_default", Some("1".to_string()), )], diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index b5ddb460a..a732bc6b0 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -27,7 +27,7 @@ mod manifest_impl { use manifest_macro::{ column, column_level_operator, column_level_security, data_source, join_type, manifest, metric, model, normalized_expr, normalized_expr_type, relationship, - row_level_access_control, row_level_operator, row_level_security, session_variable, + row_level_access_control, row_level_operator, 
row_level_security, session_property, time_grain, time_unit, view, }; use serde::{Deserialize, Serialize}; @@ -47,7 +47,7 @@ mod manifest_impl { time_grain!(false); time_unit!(false); row_level_access_control!(false); - session_variable!(false); + session_property!(false); row_level_security!(false); row_level_operator!(false); column_level_security!(false); @@ -64,7 +64,7 @@ mod manifest_impl { use manifest_macro::{ column, column_level_operator, column_level_security, data_source, join_type, manifest, metric, model, normalized_expr, normalized_expr_type, relationship, - row_level_access_control, row_level_operator, row_level_security, session_variable, + row_level_access_control, row_level_operator, row_level_security, session_property, time_grain, time_unit, view, }; use pyo3::pyclass; @@ -86,7 +86,7 @@ mod manifest_impl { time_unit!(true); manifest!(true); row_level_access_control!(true); - session_variable!(true); + session_property!(true); row_level_security!(true); row_level_operator!(true); column_level_security!(true); diff --git a/wren-core-base/tests/data/mdl.json b/wren-core-base/tests/data/mdl.json index 65e76bc27..b812add1d 100644 --- a/wren-core-base/tests/data/mdl.json +++ b/wren-core-base/tests/data/mdl.json @@ -37,7 +37,7 @@ "rowLevelAccessControls": [ { "name": "rule1", - "requiredVariables": [ + "requiredProperties": [ { "name": "session_id", "required": true @@ -47,7 +47,7 @@ }, { "name": "rule2", - "requiredVariables": [ + "requiredProperties": [ { "name": "session_id_optional", "required": false @@ -57,7 +57,7 @@ }, { "name": "rule3", - "requiredVariables": [ + "requiredProperties": [ { "name": "session_id_default", "required": false, From 1c380c83727f3bac2cc518b34f55c1f21d51a49b Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 22 Apr 2025 15:38:41 +0800 Subject: [PATCH 03/30] use insta for testing --- ibis-server/tools/query_local_run-v2.py | 90 ++++++++++ wren-core/Cargo.toml | 2 +- wren-core/core/Cargo.toml | 1 + 
wren-core/wren-example/examples/demo_site.rs | 128 +++++++++++++ .../wren-example/examples/plan-sql-json.rs | 169 ++++++++++++++++++ 5 files changed, 389 insertions(+), 1 deletion(-) create mode 100644 ibis-server/tools/query_local_run-v2.py create mode 100644 wren-core/wren-example/examples/demo_site.rs create mode 100644 wren-core/wren-example/examples/plan-sql-json.rs diff --git a/ibis-server/tools/query_local_run-v2.py b/ibis-server/tools/query_local_run-v2.py new file mode 100644 index 000000000..31a8732b8 --- /dev/null +++ b/ibis-server/tools/query_local_run-v2.py @@ -0,0 +1,90 @@ +# +# The script below is a standalone script that can be used to run a SQL query locally. +# +# Arguments: +# - sql: the SQL query, read from stdin +# +# Environment variables: +# - WREN_MANIFEST_JSON_PATH: path to the manifest JSON file +# - REMOTE_FUNCTION_LIST_PATH: path to the function list file +# - CONNECTION_INFO_PATH: path to the connection info file +# - DATA_SOURCE: data source name +# + +import base64 +import json +import os +import sqlglot +import sys + +from dotenv import load_dotenv +from wren_core import SessionContext +from app.mdl.java_engine import JavaEngineConnector +from app.model.data_source import BigQueryConnectionInfo, DataSource +from app.model.data_source import DataSourceExtension +from app.mdl.rewriter import Rewriter + +if sys.stdin.isatty(): + print("please provide the SQL query via stdin, e.g. 
`python query_local_run.py < test.sql`", file=sys.stderr) + sys.exit(1) + +sql = sys.stdin.read() + + +load_dotenv() +manifest_json_path = os.getenv("WREN_MANIFEST_JSON_PATH") +function_list_path = os.getenv("REMOTE_FUNCTION_LIST_PATH") +connection_info_path = os.getenv("CONNECTION_INFO_PATH") +data_source = os.getenv("DATA_SOURCE") + +# Welcome message +print("### Welcome to the Wren Core Query Runner ###") +print("#") +print("# Manifest JSON Path:", manifest_json_path) + +async def main(): + print("# Function List Path:", function_list_path) + print("# Connection Info Path:", connection_info_path) + print("# Data Source:", data_source) + print("# SQL Query:\n", sql) + print("#") + + # Read and encode the JSON data + with open(manifest_json_path) as file: + mdl = json.load(file) + # Convert to JSON string + json_str = json.dumps(mdl) + # Encode to base64 + encoded_str = base64.b64encode(json_str.encode("utf-8")).decode("utf-8") + + with open(connection_info_path) as file: + connection_info = json.load(file) + + print("### Starting the session context ###") + print("#") + rewriter = Rewriter(encoded_str, + data_source=DataSource[data_source], + java_engine_connector=JavaEngineConnector(os.getenv("WREN_ENGINE_ENDPOINT"))) + # session_context = SessionContext(encoded_str, function_list_path) + # planned_sql = session_context.transform_sql(sql) + planned_sql = await rewriter.rewrite(sql) + print("# Planned SQL:\n", planned_sql) + + # Transpile the planned SQL + # dialect_sql = sqlglot.transpile(planned_sql, read="trino", write=data_source)[0] + # print("# Dialect SQL:\n", dialect_sql) + print("#") + + if data_source == "bigquery": + connection_info = BigQueryConnectionInfo.model_validate_json(json.dumps(connection_info)) + connection = DataSourceExtension.get_bigquery_connection(connection_info) + df = connection.sql(planned_sql).limit(10).to_pandas() + print("### Result ###") + print("") + print(df) + else: + print("Unsupported data source:", data_source) + +if 
__name__ == "__main__": + import asyncio + asyncio.run(main()) \ No newline at end of file diff --git a/wren-core/Cargo.toml b/wren-core/Cargo.toml index 444ee0dcc..e1d2fe0b3 100644 --- a/wren-core/Cargo.toml +++ b/wren-core/Cargo.toml @@ -14,10 +14,10 @@ version = "0.1.0" [workspace.dependencies] async-trait = "0.1.88" -# We require the latest sqlparser-rs to support the latest SQL syntax datafusion = { git = "https://github.com/Canner/datafusion.git", branch = "v46.0.1" } env_logger = "0.11.3" hashbrown = "0.15.2" +insta = { version = "1.41.1" } log = { version = "0.4.14" } serde = { version = "1.0.201", features = ["derive", "rc"] } serde_json = { version = "1.0.117" } diff --git a/wren-core/core/Cargo.toml b/wren-core/core/Cargo.toml index 1efcc116a..86fe341ca 100644 --- a/wren-core/core/Cargo.toml +++ b/wren-core/core/Cargo.toml @@ -24,6 +24,7 @@ datafusion = { workspace = true, features = [ "unicode_expressions", ] } env_logger = { workspace = true } +insta = { workspace = true } log = { workspace = true } parking_lot = "0.12.3" petgraph = "0.7.1" diff --git a/wren-core/wren-example/examples/demo_site.rs b/wren-core/wren-example/examples/demo_site.rs new file mode 100644 index 000000000..9ea30d829 --- /dev/null +++ b/wren-core/wren-example/examples/demo_site.rs @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use datafusion::common::Result; +use datafusion::config::ConfigOptions; +use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; +use datafusion::logical_expr::ScalarUDF; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion::sql::sqlparser::ast::Visit; +use datafusion::sql::sqlparser::dialect::GenericDialect; +use std::ops::ControlFlow; +use std::sync::Arc; +use std::{fs, io}; +use wren_core::logical_plan::utils::try_map_data_type; +use wren_core::mdl::function::ByPassScalarUDF; +use wren_core::mdl::manifest::Manifest; +use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + + let mdl_json = "/Users/jax/git/wren-engine/ibis-server/etc.local/local_mdl.json"; + let json_string = fs::read_to_string(mdl_json).unwrap(); + let manifest: Manifest = serde_json::from_str(&json_string).unwrap(); + let mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + // + // let sql = "SELECT orders_key FROM (select * from orders limit 1000) as t"; + // let mut statements = wren_core::parser::Parser::parse_sql(&GenericDialect {}, sql)?; + // let pushdown_limit = 100; + // + // visit_statements_mut(&mut statements, |stmt| { + // if let Statement::Query(q) = stmt { + // if let Some(limit) = &q.limit { + // if let Expr::Value(Value::Number(n, is)) = limit { + // if n.parse::().unwrap() > pushdown_limit { + // q.limit = Some(Expr::Value(Value::Number( + // pushdown_limit.to_string(), + // is.clone(), + // ))); + // } + // } + // } else { + // q.limit = 
Some(Expr::Value(Value::Number( + // pushdown_limit.to_string(), + // false, + // ))); + // } + // } + // ControlFlow::<()>::Continue(()) + // }); + // print!("{}", statements[0]); + // let ctx = SessionContext::new(); + // let unparsed = match transform_sql_with_ctx(&ctx, mdl, &[], sql).await { + // Ok(sql) => println!("{}", sql), + // Err(e) => { + // eprintln!("Error: {}", e); + // return Ok(()); + // } + // }; + // let mut config = ConfigOptions::new(); + // config.execution.time_zone = Some("+03:00".to_string()); + // let session_config = SessionConfig::from(config); + // let state = SessionStateBuilder::new() + // .with_default_features() + // .with_config(session_config) + // .build(); + // let ctx = SessionContext::from(state); + // ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( + // "date_diff", + // map_data_type("bigint")?, + // ))); + // ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( + // "year", + // map_data_type("bigint")?, + // ))); + // ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( + // "month", + // map_data_type("bigint")?, + // ))); + // ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( + // "day", + // map_data_type("bigint")?, + // ))); + // ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( + // "age", + // map_data_type("interval")?, + // ))); + // + // // let sqls = fs::read_to_string("wren-example/data/demo_site.sql").unwrap(); + // let s = r#" + // select timestamp with time zone '2011-01-01' + // "#; + // // for (i, sql) in sqls.lines().enumerate() { + // let sqls = vec![s]; + // for (i, sql) in sqls.into_iter().enumerate() { + // if sql.starts_with("--") || sql.is_empty() { + // continue; + // } + // match transform_sql_with_ctx(&ctx, Arc::clone(&mdl), &[], sql).await { + // Ok(sql) => { + // println!("{}", sql); + // } + // Err(e) => { + // println!("{}: {}", i + 1, sql); + // eprintln!("Error: {}", e); + // return Ok(()); + // } + // }; + // } + Ok(()) +} 
diff --git a/wren-core/wren-example/examples/plan-sql-json.rs b/wren-core/wren-example/examples/plan-sql-json.rs new file mode 100644 index 000000000..92adc0f0d --- /dev/null +++ b/wren-core/wren-example/examples/plan-sql-json.rs @@ -0,0 +1,169 @@ +use datafusion::common::Result; +use datafusion::execution::SessionStateBuilder; +use datafusion::functions::string::lower; +use datafusion::functions_aggregate::array_agg::array_agg_udaf; +use datafusion::prelude::SessionConfig; +use datafusion::prelude::SessionContext; +use serde::{Deserialize, Serialize}; +use std::str::FromStr; +use std::sync::Arc; +use wren_core::array::AsArray; +use wren_core::array::GenericByteArray; +use wren_core::array::GenericListArray; +use wren_core::datatypes::DataType; +use wren_core::datatypes::GenericStringType; +use wren_core::mdl::function::ByPassScalarUDF; +use wren_core::mdl::function::FunctionType; +use wren_core::mdl::function::RemoteFunction; +use wren_core::ScalarUDF; + +#[tokio::main] +async fn main() -> Result<()> { + env_logger::init(); + let sql = r#" + WITH inputs AS ( + SELECT + r.specific_name, + r.data_type as return_type, + pi.rid, + array_agg(pi.parameter_name order by pi.ordinal_position) as param_names, + array_agg(pi.data_type order by pi.ordinal_position) as param_types + FROM + information_schema.routines r + JOIN + information_schema.parameters pi ON r.specific_name = pi.specific_name AND pi.parameter_mode = 'IN' + GROUP BY 1, 2, 3 + ) + SELECT + r.routine_name as name, + i.param_names, + i.param_types, + r.data_type as return_type, + r.function_type, + r.description + FROM + information_schema.routines r + LEFT JOIN + inputs i ON r.specific_name = i.specific_name + "#; + let config = SessionConfig::new().with_information_schema(true); + let state: datafusion::execution::SessionState = SessionStateBuilder::new() + .with_default_features() + .with_config(config) + .build(); + let ctx = SessionContext::new_with_state(state); + // 
ctx.register_udaf(Arc::unwrap_or_clone(array_agg_udaf())); + // ctx.register_udf(ScalarUDF::new_from_impl(lower::LowerFunc::new())); + ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( + "add_two", + DataType::Int64, + ))); + let batches = ctx.sql(sql).await?.collect().await?; + let mut functions = vec![]; + + for batch in batches { + let name_array = batch.column(0).as_string::(); + let param_names_array = batch.column(1).as_list::(); + let param_types_array = batch.column(2).as_list::(); + let return_type_array = batch.column(3).as_string::(); + let function_type_array = batch.column(4).as_string::(); + let description_array = batch.column(5).as_string::(); + + for row in 0..batch.num_rows() { + let name = name_array.value(row).to_string(); + let _param_names = + to_string_vec(param_names_array.value(row).as_string::()); + let _param_types = + to_string_vec(param_types_array.value(row).as_string::()); + let return_type = return_type_array.value(row).to_string(); + let description = description_array.value(row).to_string(); + let function_type = function_type_array.value(row).to_string(); + + functions.push(RemoteFunction { + name, + param_names: None, + param_types: None, + return_type, + description: Some(description), + function_type: FunctionType::from_str(&function_type).unwrap(), + }); + } + } + functions + .iter() + .filter(|f| f.name == "add_two") + .for_each(|f| { + println!("{:?}", f); + }); + Ok(()) +} + +fn to_string_vec( + array: &GenericByteArray>, +) -> Vec> { + array + .iter() + .map(|s| s.map(|s| s.to_string())) + .collect::>>() +} + +fn read_remote_function_list(path: &str) -> Vec { + csv::Reader::from_path(path) + .unwrap() + .into_deserialize::() + .filter_map(Result::ok) + .map(|f| RemoteFunction::from(f)) + .collect::>() +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct PyRemoteFunction { + pub function_type: String, + pub name: String, + pub return_type: Option, + /// It's a comma separated string of parameter names + 
pub param_names: Option, + /// It's a comma separated string of parameter types + pub param_types: Option, + pub description: Option, +} + +impl From for wren_core::mdl::function::RemoteFunction { + fn from( + remote_function: PyRemoteFunction, + ) -> wren_core::mdl::function::RemoteFunction { + let param_names = remote_function.param_names.map(|names| { + names + .split(",") + .map(|name| { + if name.is_empty() { + None + } else { + Some(name.to_string()) + } + }) + .collect::>>() + }); + let param_types = remote_function.param_types.map(|types| { + types + .split(",") + .map(|t| { + if t.is_empty() { + None + } else { + Some(t.to_string()) + } + }) + .collect::>>() + }); + wren_core::mdl::function::RemoteFunction { + function_type: FunctionType::from_str(&remote_function.function_type) + .unwrap(), + name: remote_function.name, + return_type: remote_function.return_type.unwrap_or("string".to_string()), + param_names, + param_types, + description: remote_function.description, + } + } +} From d13e1b8fd4f7c6ee0d80fe13a5d9fa5b5d50f83f Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 22 Apr 2025 15:49:32 +0800 Subject: [PATCH 04/30] add seesion properties --- wren-core/benchmarks/src/tpch/run.rs | 10 +- .../src/logical_plan/analyze/model_anlayze.rs | 6 + wren-core/core/src/mdl/context.rs | 12 + wren-core/core/src/mdl/mod.rs | 240 ++++++++++++++---- wren-core/sqllogictest/src/test_context.rs | 16 +- .../calculation-invoke-calculation.rs | 2 + .../wren-example/examples/datafusion-apply.rs | 3 +- wren-core/wren-example/examples/plan-sql.rs | 11 +- .../examples/to-many-calculation.rs | 3 +- wren-core/wren-example/examples/view.rs | 11 +- 10 files changed, 258 insertions(+), 56 deletions(-) diff --git a/wren-core/benchmarks/src/tpch/run.rs b/wren-core/benchmarks/src/tpch/run.rs index 7a99e22ef..aa42fa6c3 100644 --- a/wren-core/benchmarks/src/tpch/run.rs +++ b/wren-core/benchmarks/src/tpch/run.rs @@ -3,6 +3,7 @@ use crate::util::options::CommonOpt; use 
crate::util::run::BenchmarkRun; use datafusion::common::Result; use datafusion::prelude::SessionContext; +use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::time::Instant; @@ -57,7 +58,14 @@ impl RunOpt { let start = Instant::now(); let sql = &get_query_sql(query_id)?; for query in sql { - transform_sql_with_ctx(&ctx, Arc::clone(&mdl), &[], query).await?; + transform_sql_with_ctx( + &ctx, + Arc::clone(&mdl), + &[], + HashMap::new(), + query, + ) + .await?; } let elapsed = start.elapsed(); //.as_secs_f64() * 1000.0; diff --git a/wren-core/core/src/logical_plan/analyze/model_anlayze.rs b/wren-core/core/src/logical_plan/analyze/model_anlayze.rs index 7b2ddec51..a978fe38b 100644 --- a/wren-core/core/src/logical_plan/analyze/model_anlayze.rs +++ b/wren-core/core/src/logical_plan/analyze/model_anlayze.rs @@ -1,5 +1,6 @@ use crate::logical_plan::analyze::plan::ModelPlanNode; use crate::logical_plan::utils::{belong_to_mdl, expr_to_columns}; +use crate::mdl::context::SessionPropertiesRef; use crate::mdl::utils::quoted; use crate::mdl::{AnalyzedWrenMDL, Dataset, SessionStateRef}; use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -33,6 +34,7 @@ use std::sync::Arc; pub struct ModelAnalyzeRule { analyzed_wren_mdl: Arc, session_state: SessionStateRef, + properties: SessionPropertiesRef, } impl Debug for ModelAnalyzeRule { @@ -67,10 +69,12 @@ impl ModelAnalyzeRule { pub fn new( analyzed_wren_mdl: Arc, session_state: SessionStateRef, + properties: SessionPropertiesRef, ) -> Self { Self { analyzed_wren_mdl, session_state, + properties, } } @@ -449,6 +453,7 @@ impl ModelAnalyzeRule { Some(LogicalPlan::TableScan(table_scan.clone())), Arc::clone(&self.analyzed_wren_mdl), Arc::clone(&self.session_state), + Arc::clone(&self.properties), )?), }); let subquery = LogicalPlanBuilder::from(model_plan) @@ -505,6 +510,7 @@ impl ModelAnalyzeRule { None, Arc::clone(&self.analyzed_wren_mdl), Arc::clone(&self.session_state), + 
Arc::clone(&self.properties), )?), }); let subquery = diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index b8cd9f2ae..5ab4f6027 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -1,4 +1,5 @@ use std::any::Any; +use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; @@ -41,10 +42,13 @@ use datafusion::prelude::SessionContext; use datafusion::sql::TableReference; use parking_lot::RwLock; +pub type SessionPropertiesRef = Arc>>; + /// Apply Wren Rules to the context for sql generation. pub async fn create_ctx_with_mdl( ctx: &SessionContext, analyzed_mdl: Arc, + properties: SessionPropertiesRef, is_local_runtime: bool, ) -> Result { let config = ctx @@ -68,6 +72,7 @@ pub async fn create_ctx_with_mdl( new_state.with_analyzer_rules(analyze_rule_for_local_runtime( Arc::clone(&analyzed_mdl), reset_default_catalog_schema.clone(), + properties, )) // The plan will be executed locally, so apply the default optimizer rules } else { @@ -75,6 +80,7 @@ pub async fn create_ctx_with_mdl( .with_analyzer_rules(analyze_rule_for_unparsing( Arc::clone(&analyzed_mdl), reset_default_catalog_schema.clone(), + Arc::clone(&properties), )) .with_optimizer_rules(optimize_rule_for_unparsing()) }; @@ -89,6 +95,7 @@ pub async fn create_ctx_with_mdl( fn analyze_rule_for_local_runtime( analyzed_mdl: Arc, session_state_ref: SessionStateRef, + properties: SessionPropertiesRef, ) -> Vec> { vec![ // To align the lastest change in datafusion, apply this this rule first. @@ -101,10 +108,12 @@ fn analyze_rule_for_local_runtime( Arc::new(ModelAnalyzeRule::new( Arc::clone(&analyzed_mdl), Arc::clone(&session_state_ref), + Arc::clone(&properties), )), Arc::new(ModelGenerationRule::new( Arc::clone(&analyzed_mdl), session_state_ref, + properties, )), Arc::new(InlineTableScan::new()), // Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule]. 
@@ -118,6 +127,7 @@ fn analyze_rule_for_local_runtime( fn analyze_rule_for_unparsing( analyzed_mdl: Arc, session_state_ref: SessionStateRef, + properties: SessionPropertiesRef, ) -> Vec> { vec![ // To align the lastest change in datafusion, apply this this rule first. @@ -130,10 +140,12 @@ fn analyze_rule_for_unparsing( Arc::new(ModelAnalyzeRule::new( Arc::clone(&analyzed_mdl), Arc::clone(&session_state_ref), + Arc::clone(&properties), )), Arc::new(ModelGenerationRule::new( Arc::clone(&analyzed_mdl), session_state_ref, + properties, )), Arc::new(InlineTableScan::new()), // Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule]. diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 0fa8c3e3b..bdf215c5b 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -1,7 +1,6 @@ use crate::logical_plan::utils::{from_qualified_name_str, try_map_data_type}; use crate::mdl::builder::ManifestBuilder; use crate::mdl::context::{create_ctx_with_mdl, WrenDataSource}; -use crate::mdl::dialect::WrenDialect; use crate::mdl::function::{ ByPassAggregateUDF, ByPassScalarUDF, ByPassWindowFunction, FunctionType, RemoteFunction, @@ -22,6 +21,7 @@ use datafusion::sql::sqlparser::dialect::dialect_from_str; use datafusion::sql::unparser::Unparser; use datafusion::sql::TableReference; pub use dataset::Dataset; +use dialect::WrenDialect; use log::{debug, info}; use manifest::Relationship; use parking_lot::RwLock; @@ -331,6 +331,7 @@ impl WrenMDL { pub fn transform_sql( analyzed_mdl: Arc, remote_functions: &[RemoteFunction], + properties: HashMap>, sql: &str, ) -> Result { let runtime = tokio::runtime::Runtime::new().unwrap(); @@ -338,6 +339,7 @@ pub fn transform_sql( &SessionContext::new(), analyzed_mdl, remote_functions, + properties, sql, )) } @@ -349,6 +351,7 @@ pub async fn transform_sql_with_ctx( ctx: &SessionContext, analyzed_mdl: Arc, remote_functions: &[RemoteFunction], + properties: HashMap>, 
sql: &str, ) -> Result { info!("wren-core received SQL: {}", sql); @@ -357,7 +360,9 @@ pub async fn transform_sql_with_ctx( register_remote_function(ctx, remote_function)?; Ok::<_, DataFusionError>(()) })?; - let ctx = create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl), false).await?; + let properties_ref = Arc::new(properties); + let ctx = create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl), properties_ref, false) + .await?; let plan = ctx.state().create_logical_plan(sql).await?; debug!("wren-core original plan:\n {plan}"); let analyzed = ctx.state().optimize(&plan)?; @@ -466,6 +471,7 @@ mod test { let _ = mdl::transform_sql( Arc::clone(&analyzed_mdl), &[], + HashMap::new(), "select o_orderkey + o_orderkey from test.test.orders", )?; Ok(()) @@ -503,6 +509,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -531,6 +538,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -558,6 +566,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -575,6 +584,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -613,6 +623,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -652,6 +663,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &functions, + HashMap::new(), r#"select add_two("Custkey") from "Customer""#, ) .await?; @@ -662,6 +674,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &functions, + HashMap::new(), r#"select median("Custkey") from "CTest"."STest"."Customer" group by "Name""#, ) .await?; @@ -721,6 +734,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -736,6 +750,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -748,6 +763,7 @@ mod test { &SessionContext::new(), 
Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -786,6 +802,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await @@ -801,6 +818,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await @@ -839,6 +857,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -849,6 +868,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -860,6 +880,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await.map_err(|e| { @@ -878,6 +899,7 @@ mod test { &SessionContext::new(), Arc::new(AnalyzedWrenMDL::default()), &[], + HashMap::new(), sql, ) .await?; @@ -907,6 +929,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -922,8 +945,14 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select count(*) from (select 1)"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; // TODO: BigQuery doesn't support the alias include invalid characters (e.g. `*`, `()`). // We should remove the invalid characters for the alias. 
assert_eq!(actual, "SELECT count(1) AS \"count(*)\" FROM (SELECT 1)"); @@ -956,18 +985,36 @@ mod test { let ctx = SessionContext::new(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select interval 1 day"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!(actual, "SELECT INTERVAL 1 DAY"); let sql = "SELECT INTERVAL '1 YEAR 1 MONTH'"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!(actual, "SELECT INTERVAL 13 MONTH"); let sql = "SELECT INTERVAL '1' YEAR + INTERVAL '2' MONTH + INTERVAL '3' DAY"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!( actual, "SELECT INTERVAL 12 MONTH + INTERVAL 2 MONTH + INTERVAL 3 DAY" @@ -981,7 +1028,8 @@ mod test { let manifest = ManifestBuilder::new().build(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); let sql = "select * from unnest([1, 2, 3])"; - let actual = transform_sql_with_ctx(&ctx, analyzed_mdl, &[], sql).await?; + let actual = + transform_sql_with_ctx(&ctx, analyzed_mdl, &[], HashMap::new(), sql).await?; assert_eq!(actual, "SELECT \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\" FROM (SELECT UNNEST([1, 2, 3]) AS \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\")"); let manifest = ManifestBuilder::new() @@ -989,7 +1037,8 @@ mod test { .build(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); let sql = "select * from unnest([1, 2, 3])"; - let actual = transform_sql_with_ctx(&ctx, analyzed_mdl, &[], sql).await?; + let actual = + 
transform_sql_with_ctx(&ctx, analyzed_mdl, &[], HashMap::new(), sql).await?; assert_eq!(actual, "SELECT \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\" FROM UNNEST([1, 2, 3])"); Ok(()) } @@ -999,13 +1048,25 @@ mod test { let ctx = SessionContext::new(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select timestamp '2011-01-01 18:00:00 +08:00'"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!(actual, "SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00 +08:00\"\")\""); let sql = "select timestamp '2011-01-01 18:00:00 Asia/Taipei'"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!(actual, "SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00 Asia/Taipei\"\")\""); Ok(()) } @@ -1018,14 +1079,26 @@ mod test { let ctx = SessionContext::new_with_config(session_config); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); let sql = "select timestamp '2011-01-01 18:00:00'"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; // TIMESTAMP doesn't have timezone, so the timezone will be ignored assert_eq!(actual, "SELECT CAST('2011-01-01 18:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00\"\")\""); let sql = "select timestamp with time zone '2011-01-01 18:00:00'"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; // 
TIMESTAMP WITH TIME ZONE will be converted to the session timezone assert_eq!(actual, "SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00\"\")\""); @@ -1036,14 +1109,26 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::default()); // TIMESTAMP WITH TIME ZONE will be converted to the session timezone with daylight saving (UTC -5) let sql = "select timestamp with time zone '2024-01-15 18:00:00'"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!(actual, "SELECT CAST('2024-01-15 23:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2024-01-15 18:00:00\"\")\""); // TIMESTAMP WITH TIME ZONE will be converted to the session timezone without daylight saving (UTC -4) let sql = "select timestamp with time zone '2024-07-15 18:00:00'"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!(actual, "SELECT CAST('2024-07-15 22:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2024-07-15 18:00:00\"\")\""); Ok(()) } @@ -1078,6 +1163,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -1121,7 +1207,10 @@ mod test { let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, registers)?); - let ctx = create_ctx_with_mdl(&ctx, Arc::clone(&analyzed_mdl), true).await?; + let properties_ref = Arc::new(HashMap::new()); + let ctx = + create_ctx_with_mdl(&ctx, Arc::clone(&analyzed_mdl), properties_ref, true) + .await?; let sql = r#"select arrow_typeof(timestamp_col), arrow_typeof(timestamptz_col) from wren.test.timestamp_table limit 1"#; let result = ctx.sql(sql).await?.collect().await?; let expected = vec![ @@ -1163,6 +1252,7 @@ mod test { &SessionContext::new(), 
Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -1177,6 +1267,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -1190,6 +1281,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -1204,6 +1296,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], + HashMap::new(), sql, ) .await?; @@ -1231,8 +1324,14 @@ mod test { .build(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); let sql = "select list_col[1] from wren.test.list_table"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!(actual, "SELECT list_table.list_col[1] FROM (SELECT list_table.list_col FROM \ (SELECT __source.list_col AS list_col FROM list_table AS __source) AS list_table) AS list_table"); Ok(()) @@ -1266,8 +1365,14 @@ mod test { .build(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); let sql = "select struct_col.float_field from wren.test.struct_table"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!( actual, "SELECT struct_table.struct_col.float_field FROM \ @@ -1276,16 +1381,28 @@ mod test { ); let sql = "select struct_array_col[1].float_field from wren.test.struct_table"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!(actual, "SELECT struct_table.struct_array_col[1].float_field FROM \ (SELECT struct_table.struct_array_col FROM (SELECT __source.struct_array_col AS struct_array_col \ FROM 
struct_table AS __source) AS struct_table) AS struct_table"); let sql = "select {float_field: 1.0, time_field: timestamp '2021-01-01 00:00:00'}"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!(actual, "SELECT {float_field: 1.0, time_field: CAST('2021-01-01 00:00:00' AS TIMESTAMP)}"); let manifest = ManifestBuilder::new() @@ -1300,7 +1417,7 @@ mod test { .build(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); let sql = "select struct_col.float_field from wren.test.struct_table"; - let _ = transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql) + let _ = transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) .await .map_err(|e| { assert_eq!( @@ -1317,9 +1434,14 @@ mod test { let sql = "SELECT CAST(TIMESTAMP '2021-01-01 00:00:00' as TIMESTAMP WITH TIME ZONE) = \ CAST(TIMESTAMP '2021-01-01 00:00:00' as TIMESTAMP WITH TIME ZONE)"; - let result = - transform_sql_with_ctx(&ctx, Arc::new(AnalyzedWrenMDL::default()), &[], sql) - .await?; + let result = transform_sql_with_ctx( + &ctx, + Arc::new(AnalyzedWrenMDL::default()), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!(result, "SELECT CAST(CAST('2021-01-01 00:00:00' AS TIMESTAMP) AS TIMESTAMP WITH TIME ZONE) = \ CAST(CAST('2021-01-01 00:00:00' AS TIMESTAMP) AS TIMESTAMP WITH TIME ZONE)"); Ok(()) @@ -1332,9 +1454,14 @@ mod test { SELECT 1 x, 'b' y UNION ALL SELECT 2 x, 'a' y UNION ALL SELECT 2 x, 'c' y)"#; - let result = - transform_sql_with_ctx(&ctx, Arc::new(AnalyzedWrenMDL::default()), &[], sql) - .await?; + let result = transform_sql_with_ctx( + &ctx, + Arc::new(AnalyzedWrenMDL::default()), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!( result, "SELECT x, y FROM (SELECT 1 AS x, 'a' AS y \ @@ -1352,7 +1479,8 @@ mod test { let ctx = SessionContext::new(); let expected = 
"SELECT trim(' abc')"; let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&mdl), &[], expected).await?; + transform_sql_with_ctx(&ctx, Arc::clone(&mdl), &[], HashMap::new(), expected) + .await?; assert_eq!(actual, expected); Ok(()) } @@ -1373,8 +1501,14 @@ mod test { .build(); let sql = r#"SELECT c_custkey, count(distinct c_name) FROM customer GROUP BY c_custkey"#; let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let result = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let result = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!( result, "SELECT customer.c_custkey, count(DISTINCT customer.c_name) FROM \ @@ -1401,8 +1535,14 @@ mod test { .build(); let sql = r#"SELECT c_custkey, (SELECT c_name FROM customer WHERE c_custkey = 1) FROM customer"#; let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let result = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let result = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!( result, "SELECT customer.c_custkey, (SELECT customer.c_name FROM (SELECT customer.c_custkey, customer.c_name \ @@ -1428,8 +1568,14 @@ mod test { .build(); let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let result = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let result = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!( result, "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM \ diff --git a/wren-core/sqllogictest/src/test_context.rs b/wren-core/sqllogictest/src/test_context.rs index c64d65272..8cf159da0 100644 --- a/wren-core/sqllogictest/src/test_context.rs +++ 
b/wren-core/sqllogictest/src/test_context.rs @@ -301,7 +301,13 @@ async fn register_ecommerce_mdl( manifest, register_tables, )?); - let ctx = create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl), true).await?; + let ctx = create_ctx_with_mdl( + ctx, + Arc::clone(&analyzed_mdl), + Arc::new(HashMap::new()), + true, + ) + .await?; Ok((ctx.to_owned(), analyzed_mdl)) } @@ -531,6 +537,12 @@ async fn register_tpch_mdl( manifest, register_tables, )?); - let ctx = create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl), true).await?; + let ctx = create_ctx_with_mdl( + ctx, + Arc::clone(&analyzed_mdl), + Arc::new(HashMap::new()), + true, + ) + .await?; Ok((ctx.to_owned(), analyzed_mdl)) } diff --git a/wren-core/wren-example/examples/calculation-invoke-calculation.rs b/wren-core/wren-example/examples/calculation-invoke-calculation.rs index 093d15764..f593fe75b 100644 --- a/wren-core/wren-example/examples/calculation-invoke-calculation.rs +++ b/wren-core/wren-example/examples/calculation-invoke-calculation.rs @@ -81,6 +81,7 @@ async fn main() -> Result<()> { &ctx, Arc::clone(&analyzed_mdl), &[], + HashMap::new(), "select totalprice from wrenai.public.customers", ) .await @@ -106,6 +107,7 @@ async fn main() -> Result<()> { &ctx, Arc::clone(&analyzed_mdl), &[], + HashMap::new(), "select customer_state_cf from wrenai.public.order_items", ) .await diff --git a/wren-core/wren-example/examples/datafusion-apply.rs b/wren-core/wren-example/examples/datafusion-apply.rs index 08aeff6b9..c4fedc1b5 100644 --- a/wren-core/wren-example/examples/datafusion-apply.rs +++ b/wren-core/wren-example/examples/datafusion-apply.rs @@ -78,7 +78,8 @@ async fn main() -> Result<()> { // TODO: there're some issue for optimize rules // let ctx = create_ctx_with_mdl(&ctx, analyzed_mdl).await?; let sql = "select * from wrenai.public.order_items"; - let sql = transform_sql_with_ctx(&ctx, analyzed_mdl, &[], sql).await?; + let sql = + transform_sql_with_ctx(&ctx, analyzed_mdl, &[], HashMap::new(), sql).await?; 
println!("Wren engine generated SQL: \n{}", sql); // create a plan to run a SQL query let df = match ctx.sql(&sql).await { diff --git a/wren-core/wren-example/examples/plan-sql.rs b/wren-core/wren-example/examples/plan-sql.rs index 7640681b4..f11f80b34 100644 --- a/wren-core/wren-example/examples/plan-sql.rs +++ b/wren-core/wren-example/examples/plan-sql.rs @@ -1,4 +1,5 @@ use datafusion::prelude::SessionContext; +use std::collections::HashMap; use std::sync::Arc; use wren_core::mdl::builder::{ ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, @@ -13,8 +14,14 @@ async fn main() -> datafusion::common::Result<()> { let sql = "select customer_state from wrenai.public.orders_model"; println!("Original SQL: \n{}", sql); - let sql = - transform_sql_with_ctx(&SessionContext::new(), analyzed_mdl, &[], sql).await?; + let sql = transform_sql_with_ctx( + &SessionContext::new(), + analyzed_mdl, + &[], + HashMap::new(), + sql, + ) + .await?; println!("Wren engine generated SQL: \n{}", sql); Ok(()) } diff --git a/wren-core/wren-example/examples/to-many-calculation.rs b/wren-core/wren-example/examples/to-many-calculation.rs index 2639487a2..cadc7bee0 100644 --- a/wren-core/wren-example/examples/to-many-calculation.rs +++ b/wren-core/wren-example/examples/to-many-calculation.rs @@ -76,7 +76,8 @@ async fn main() -> Result<()> { ]); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?); - let ctx = create_ctx_with_mdl(&ctx, analyzed_mdl, true).await?; + let ctx = + create_ctx_with_mdl(&ctx, analyzed_mdl, Arc::new(HashMap::new()), true).await?; let df = match ctx .sql("select totalprice from wrenai.public.customers") .await diff --git a/wren-core/wren-example/examples/view.rs b/wren-core/wren-example/examples/view.rs index c1edd8e79..7b763bde7 100644 --- a/wren-core/wren-example/examples/view.rs +++ b/wren-core/wren-example/examples/view.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::sync::Arc; use 
datafusion::prelude::SessionContext; @@ -15,8 +16,14 @@ async fn main() -> datafusion::common::Result<()> { let sql = "select * from wrenai.public.customers_view"; println!("Original SQL: \n{}", sql); - let sql = - transform_sql_with_ctx(&SessionContext::new(), analyzed_mdl, &[], sql).await?; + let sql = transform_sql_with_ctx( + &SessionContext::new(), + analyzed_mdl, + &[], + HashMap::new(), + sql, + ) + .await?; println!("Wren engine generated SQL: \n{}", sql); Ok(()) } From 412ef6d407d96828ed4bd14e06f812a80b29a689 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 22 Apr 2025 15:51:11 +0800 Subject: [PATCH 05/30] implement row level access control --- wren-core-base/src/mdl/manifest.rs | 4 + .../logical_plan/analyze/access_control.rs | 515 ++++++++++++++++++ .../core/src/logical_plan/analyze/mod.rs | 1 + .../logical_plan/analyze/model_generation.rs | 84 ++- .../core/src/logical_plan/analyze/plan.rs | 47 +- .../logical_plan/analyze/relation_chain.rs | 8 +- .../examples/row_level_access_control.rs | 74 +++ 7 files changed, 708 insertions(+), 25 deletions(-) create mode 100644 wren-core/core/src/logical_plan/analyze/access_control.rs create mode 100644 wren-core/wren-example/examples/row_level_access_control.rs diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index a732bc6b0..460a28672 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -277,6 +277,10 @@ impl Model { pub fn table_reference(&self) -> &str { self.table_reference.as_deref().unwrap_or("") } + + pub fn row_level_access_controls(&self) -> &[RowLevelAccessControl] { + &self.row_level_access_controls + } } impl Column { diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs new file mode 100644 index 000000000..df27fad3d --- /dev/null +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -0,0 +1,515 @@ +use std::{ + collections::{HashMap, 
HashSet}, + ops::ControlFlow, + sync::Arc, +}; + +use datafusion::{ + common::{plan_err, Result, Spans}, + error::DataFusionError, + prelude::Expr, + sql::{ + parser::DFParserBuilder, + sqlparser::{ + ast::{self, visit_expressions, visit_expressions_mut, ExprWithAlias}, + dialect::GenericDialect, + }, + TableReference, + }, +}; +use wren_core_base::mdl::RowLevelAccessControl; +use wren_core_base::mdl::{Model, SessionProperty}; + +use crate::mdl::{context::SessionPropertiesRef, Dataset, SessionStateRef}; + +/// Collect the required field from the condition of row level access control rules. +pub fn collect_condition( + model: &Model, + condition: &str, +) -> Result<(Vec, Vec)> { + let mut conditions = vec![]; + let mut seesion_properties: HashSet = HashSet::new(); + let mut error: Option> = None; + let dialect = GenericDialect {}; + let mut parser = DFParserBuilder::new(condition) + .with_dialect(&dialect) + .build()?; + let expr = parser.parse_expr()?; + visit_expressions(&expr, |expr| { + if let ast::Expr::Identifier(ast::Ident { value, .. }) = expr { + if !value.starts_with("@") { + if model.get_column(value).is_none() { + error = Some(plan_err!( + "The column {} is not in the model {}", + value, + model.name() + )); + return ControlFlow::Break(()); + } + conditions.push(Expr::Column(datafusion::common::Column { + relation: Some(TableReference::bare(model.name())), + name: value.to_string(), + spans: Spans::new(), + })); + } else { + let session_property = value.trim_start_matches("@").to_string(); + if !seesion_properties.contains(&session_property) { + seesion_properties.insert(session_property); + } + } + } + ControlFlow::Continue(()) + }); + + if let Some(err) = error { + return err; + } + + Ok(( + conditions, + seesion_properties.into_iter().collect::>(), + )) +} + +/// Build the filter expression for the row level access control rule. 
+pub fn build_filter_expression( + session_state: &SessionStateRef, + model: Arc, + properties: &SessionPropertiesRef, + rule: &RowLevelAccessControl, +) -> Result { + let RowLevelAccessControl { + condition, + required_properties, + .. + } = rule; + let mut error: Option> = None; + let dialect = GenericDialect {}; + let mut parser = DFParserBuilder::new(condition) + .with_dialect(&dialect) + .build()?; + let mut expr = parser.parse_expr()?; + + visit_expressions_mut(&mut expr, |expr| { + if let ast::Expr::Identifier(ast::Ident { value, .. }) = expr { + if value.starts_with("@") { + let property_name = value.trim_start_matches("@").to_string(); + let Some(property_value) = properties.get(&property_name).or_else(|| { + required_properties + .iter() + .filter(|r| r.name == property_name && !r.required) + .map(|r| &r.default_expr) + .next() + }) else { + error = Some(plan_err!( + "The session property {} is not found in the session properties", + property_name + )); + return ControlFlow::Break(()); + }; + + let Some(property_value) = property_value else { + error = Some(plan_err!( + "The session property {} should not be null", + property_name + )); + return ControlFlow::Break(()); + }; + + if property_value.is_empty() { + error = Some(plan_err!( + "The session property {} should not be empty", + property_name + )); + return ControlFlow::Break(()); + } + + match parse_expr(property_value) { + Ok(parsed_expr) => { + *expr = parsed_expr.expr; + } + Err(e) => { + error = Some(plan_err!( + "The session property {} is not valid: {}", + property_name, + e + )); + return ControlFlow::Break(()); + } + } + } + } + ControlFlow::Continue(()) + }); + + if let Some(error) = error { + return error; + } + let df_schema = Dataset::Model(Arc::clone(&model)).to_qualified_schema()?; + session_state + .read() + .create_logical_expr(&expr.to_string(), &df_schema) +} + +fn parse_expr(expr: &str) -> Result { + let dialect = GenericDialect {}; + let mut parser = 
DFParserBuilder::new(expr).with_dialect(&dialect).build()?; + let expr = parser.parse_expr()?; + Ok(expr) +} + +/// Validate the input headers with the required properties. +/// If the result is false, the rules are not satisfied and it should be ignored. +/// +/// If the required property is not found in the headers, return an error. +/// If the required property is found in the headers, return true. +/// If the optional property is found in the headers, return true. +/// If the optional property is not found in the headers but has a default value, return true. +/// If the optional property is not found in the headers and has no default value, return false. +pub fn validate_rule( + required_properties: &[SessionProperty], + headers: &HashMap>, +) -> Result { + let exists = required_properties.iter().map(|property| { + if property.required { + if !is_property_present(headers, &property.name) { + return plan_err!( + "Row level access control property {} is required, but not found in headers", + property.name + ); + } + Ok(true) + } else { + let exist = is_property_present(headers, &property.name); + if exist || property.default_expr.is_some() { + Ok(true) + } else { + Ok(false) + } + } + }).collect::>>()?; + + Ok(exists.iter().all(|x| *x)) +} + +/// Check if the property is present in the headers and not empty +/// If the property is present and not empty, return true. 
+fn is_property_present( + headers: &HashMap>, + property_name: &str, +) -> bool { + headers + .get(property_name) + .map(|v| v.as_ref().is_some_and(|value| !value.is_empty())) + .unwrap_or(false) +} + +#[cfg(test)] +mod test { + use std::{collections::HashMap, sync::Arc}; + + use datafusion::{ + error::Result, + prelude::{Expr, SessionContext}, + sql::unparser::Unparser, + }; + use insta::assert_snapshot; + use wren_core_base::mdl::{ + ColumnBuilder, ModelBuilder, RowLevelAccessControl, SessionProperty, + }; + + use crate::logical_plan::analyze::access_control::{ + collect_condition, validate_rule, + }; + + use super::build_filter_expression; + + #[test] + pub fn test_collect_condition() -> Result<()> { + let model = ModelBuilder::new("model1") + .column(ColumnBuilder::new("id", "int").build()) + .column(ColumnBuilder::new("name", "varchar").build()) + .build(); + + let conditions = vec![ + "id = @session_id AND name = 'test'", + "id = @session_id /* comment */ AND name = 'test'", + "id = @session_id \nAND name = 'test'", + ]; + for condition in conditions { + let (required_exprs, session_properties) = + collect_condition(&model, condition)?; + assert_eq!(required_exprs.len(), 2); + let name = required_exprs + .into_iter() + .map(|e| e.schema_name().to_string()) + .collect::>(); + assert_eq!(name, vec!["model1.id", "model1.name"]); + assert_eq!(session_properties.len(), 1); + assert_eq!(session_properties[0], "session_id"); + } + + let condition = "not_found = @session_id AND name = 'test'"; + match collect_condition(&model, condition) { + Err(error) + if error.message() + == "The column not_found is not in the model model1" => {} + _ => panic!("should be error"), + }; + + Ok(()) + } + + #[test] + pub fn test_validate_rule() -> Result<()> { + // required property + assert!(validate_rule( + &[SessionProperty::new_required("session_id")], + &build_headers(&[("session_id".to_string(), Some("1".to_string()))]) + )?); + + match validate_rule( + 
&[SessionProperty::new_required("session_id")], + &build_headers(&[("session_id".to_string(), None)]), + ) { + Err(error) => { + assert_snapshot!(error.message(), @"Row level access control property session_id is required, but not found in headers"); + } + _ => panic!("should be error"), + } + + match validate_rule( + &[SessionProperty::new_required("session_id")], + &build_headers(&[("session_id".to_string(), Some("".to_string()))]), + ) { + Err(error) => { + assert_snapshot!(error.message(), @"Row level access control property session_id is required, but not found in headers"); + } + _ => panic!("should be error"), + } + + match validate_rule( + &[SessionProperty::new_required("session_id")], + &build_headers(&[]), + ) { + Err(error) => { + assert_snapshot!(error.message(), @"Row level access control property session_id is required, but not found in headers"); + } + _ => panic!("should be error"), + } + + // optional property with default + assert!(validate_rule( + &[SessionProperty::new_optional( + "session_id", + Some("1".to_string()) + )], + &build_headers(&[("session_id".to_string(), Some("2".to_string()))]) + )?); + + assert!(validate_rule( + &[SessionProperty::new_optional( + "session_id", + Some("1".to_string()) + )], + &build_headers(&[("session_id".to_string(), None)]) + )?); + + assert!(validate_rule( + &[SessionProperty::new_optional( + "session_id", + Some("1".to_string()) + )], + &build_headers(&[("session_id".to_string(), Some("".to_string()))]) + )?); + + assert!(validate_rule( + &[SessionProperty::new_optional( + "session_id", + Some("1".to_string()) + )], + &build_headers(&[]) + )?); + + // optional property without default + assert!(validate_rule( + &[SessionProperty::new_optional("session_id", None)], + &build_headers(&[("session_id".to_string(), Some("2".to_string()))]) + )?); + + // expected false + assert!(!validate_rule( + &[SessionProperty::new_optional("session_id", None)], + &build_headers(&[("session_id".to_string(), None)]) + )?); + + 
// expected false + assert!(!validate_rule( + &[SessionProperty::new_optional("session_id", None)], + &build_headers(&[("session_id".to_string(), Some("".to_string()))]) + )?); + + // expected false + assert!(!validate_rule( + &[SessionProperty::new_optional("session_id", None)], + &build_headers(&[]) + )?); + + assert!(validate_rule( + &[ + SessionProperty::new_required("session_id"), + SessionProperty::new_optional("session_id_1", None), + SessionProperty::new_optional("session_id_2", Some("1".to_string())) + ], + &build_headers(&[ + ("session_id".to_string(), Some("1".to_string())), + ("session_id_1".to_string(), Some("1".to_string())), + ("session_id_2".to_string(), Some("2".to_string())), + ]) + )?); + + // expected false + assert!(!validate_rule( + &[ + SessionProperty::new_required("session_id"), + SessionProperty::new_optional("session_id_1", None), + SessionProperty::new_optional("session_id_2", Some("1".to_string())) + ], + &build_headers(&[ + ("session_id".to_string(), Some("1".to_string())), + ("session_id_1".to_string(), None), + ("session_id_2".to_string(), Some("2".to_string())), + ]) + )?); + + assert!(validate_rule( + &[ + SessionProperty::new_required("session_id"), + SessionProperty::new_optional("session_id_1", None), + SessionProperty::new_optional("session_id_2", Some("1".to_string())) + ], + &build_headers(&[ + ("session_id".to_string(), Some("1".to_string())), + ("session_id_1".to_string(), Some("1".to_string())), + ("session_id_2".to_string(), None), + ]) + )?); + + match validate_rule( + &[ + SessionProperty::new_required("session_id"), + SessionProperty::new_optional("session_id_1", None), + SessionProperty::new_optional("session_id_2", Some("1".to_string())), + ], + &build_headers(&[ + ("session_id".to_string(), None), + ("session_id_1".to_string(), Some("1".to_string())), + ("session_id_2".to_string(), None), + ]), + ) { + Err(error) => { + assert_snapshot!(error.message(), @"Row level access control property session_id is required, but 
not found in headers"); + } + _ => panic!("should be error"), + } + + Ok(()) + } + + fn build_headers( + field: &[(String, Option)], + ) -> HashMap> { + let mut headers = HashMap::new(); + for (key, value) in field { + headers.insert(key.clone(), value.clone()); + } + headers + } + + #[test] + pub fn test_build_filter_expression() -> Result<()> { + let ctx = SessionContext::new(); + let state = ctx.state_ref(); + let model = ModelBuilder::new("m1") + .column(ColumnBuilder::new("id", "int").build()) + .column(ColumnBuilder::new("name", "varchar").build()) + .build(); + + let headers = Arc::new(build_headers(&[ + ("session_id".to_string(), Some("1".to_string())), + ("session_name".to_string(), Some("'test'".to_string())), + ])); + + let rule = RowLevelAccessControl { + condition: "id = @session_id AND name = @session_name".to_string(), + required_properties: vec![ + SessionProperty::new_required("session_id"), + SessionProperty::new_required("session_name"), + ], + name: "test".to_string(), + }; + + let expr = build_filter_expression(&state, Arc::clone(&model), &headers, &rule)?; + assert_snapshot!(expr_to_sql(&expr)?, @"m1.id = 1 AND m1.\"name\" = 'test'"); + + let rule = RowLevelAccessControl { + condition: "id = @not_found AND name = @session_name".to_string(), + required_properties: vec![ + SessionProperty::new_required("session_id"), + SessionProperty::new_required("session_name"), + ], + name: "test".to_string(), + }; + + match build_filter_expression(&state, Arc::clone(&model), &headers, &rule) { + Err(error) => { + assert_snapshot!(error.to_string(), @"Error during planning: The session property not_found is not found in the session properties"); + } + _ => panic!("should be error"), + } + + let rule = RowLevelAccessControl { + condition: "id = @session_id AND name = @session_name".to_string(), + required_properties: vec![ + SessionProperty::new_required("session_id"), + SessionProperty::new_required("session_name"), + ], + name: "test".to_string(), + }; + + 
let headers = Arc::new(build_headers(&[( + "session_id".to_string(), + Some("1".to_string()), + )])); + match build_filter_expression(&state, Arc::clone(&model), &headers, &rule) { + Err(error) => { + assert_snapshot!(error.to_string(), @"Error during planning: The session property session_name is not found in the session properties"); + } + _ => panic!("should be error"), + } + + let rule = RowLevelAccessControl { + condition: "id = @session_id AND name = @session_name".to_string(), + required_properties: vec![ + SessionProperty::new_required("session_id"), + SessionProperty::new_optional("session_name", Some("'test'".to_string())), + ], + name: "test".to_string(), + }; + + let headers = Arc::new(build_headers(&[( + "session_id".to_string(), + Some("1".to_string()), + )])); + + let expr = build_filter_expression(&state, Arc::clone(&model), &headers, &rule)?; + assert_snapshot!(expr_to_sql(&expr)?, @"m1.id = 1 AND m1.\"name\" = 'test'"); + + Ok(()) + } + + fn expr_to_sql(expr: &Expr) -> Result { + let unparser = Unparser::default().with_pretty(true); + unparser.expr_to_sql(expr).map(|sql| sql.to_string()) + } +} diff --git a/wren-core/core/src/logical_plan/analyze/mod.rs b/wren-core/core/src/logical_plan/analyze/mod.rs index 283fd714d..483123777 100644 --- a/wren-core/core/src/logical_plan/analyze/mod.rs +++ b/wren-core/core/src/logical_plan/analyze/mod.rs @@ -1,3 +1,4 @@ +mod access_control; pub mod expand_view; pub mod model_anlayze; pub mod model_generation; diff --git a/wren-core/core/src/logical_plan/analyze/model_generation.rs b/wren-core/core/src/logical_plan/analyze/model_generation.rs index 96f89d444..9b08b4bec 100644 --- a/wren-core/core/src/logical_plan/analyze/model_generation.rs +++ b/wren-core/core/src/logical_plan/analyze/model_generation.rs @@ -7,6 +7,7 @@ use crate::logical_plan::analyze::plan::{ use crate::logical_plan::utils::{ create_remote_table_source, eliminate_ambiguous_columns, rebase_column, }; +use 
crate::mdl::context::SessionPropertiesRef; use crate::mdl::manifest::Model; use crate::mdl::utils::quoted; use crate::mdl::{AnalyzedWrenMDL, SessionStateRef}; @@ -20,6 +21,9 @@ use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion::optimizer::analyzer::AnalyzerRule; use datafusion::physical_plan::internal_err; use datafusion::sql::TableReference; +use wren_core_base::mdl::RowLevelAccessControl; + +use super::access_control::{build_filter_expression, validate_rule}; pub const SOURCE_ALIAS: &str = "__source"; @@ -27,13 +31,19 @@ pub const SOURCE_ALIAS: &str = "__source"; pub struct ModelGenerationRule { analyzed_wren_mdl: Arc, session_state: SessionStateRef, + properties: SessionPropertiesRef, } impl ModelGenerationRule { - pub fn new(mdl: Arc, session_state: SessionStateRef) -> Self { + pub fn new( + mdl: Arc, + session_state: SessionStateRef, + properties: SessionPropertiesRef, + ) -> Self { Self { analyzed_wren_mdl: mdl, session_state, + properties, } } @@ -51,6 +61,7 @@ impl ModelGenerationRule { ModelGenerationRule::new( Arc::clone(&self.analyzed_wren_mdl), Arc::clone(&self.session_state), + Arc::clone(&self.properties), ), &alias_generator, )?; @@ -65,22 +76,46 @@ impl ModelGenerationRule { model_plan.required_exprs.clone() }; let projections = eliminate_ambiguous_columns(projections); - let result = match source_plan { - Some(plan) => { - if model_plan.required_exprs.is_empty() { - plan + let mut builder = if let Some(plan) = source_plan { + LogicalPlanBuilder::from(plan) + } else { + return plan_err!("Failed to generate source plan"); + }; + + if !model_plan.required_exprs.is_empty() { + builder = builder.project(projections)? 
+ } + + let filters: Vec> = model_plan + .model + .row_level_access_controls() + .iter() + .map(|rule| { + self.generate_row_level_access_control_filter( + Arc::clone(&model_plan.model), + rule, + ) + }) + .collect::>()?; + let rls_filter = filters + .into_iter() + .reduce(|acc, filter| { + if acc.is_none() { + filter + } else if let Some(filter) = filter { + Some(acc.unwrap().and(filter)) } else { - LogicalPlanBuilder::from(plan) - .project(projections)? - .build()? + acc } - } - _ => { - return plan_err!("Failed to generate source plan"); - } - }; + }) + .flatten(); + + if let Some(filter) = rls_filter { + builder = builder.filter(filter)? + } + // calculated field scope - Ok(Transformed::yes(result)) + Ok(Transformed::yes(builder.build()?)) } else if let Some(model_plan) = extension.node.as_any().downcast_ref::() { @@ -150,6 +185,7 @@ impl ModelGenerationRule { ModelGenerationRule::new( Arc::clone(&self.analyzed_wren_mdl), Arc::clone(&self.session_state), + Arc::clone(&self.properties), ), &alias_generator, )?; @@ -210,7 +246,7 @@ impl ModelGenerationRule { let projection = eliminate_ambiguous_columns(projection); let alias = LogicalPlanBuilder::from(source_plan) .project(projection)? - .alias(quoted(&partial_model.model_node.plan_name))? + .alias(quoted(partial_model.model_node.plan_name()))? .build()?; Ok(Transformed::yes(alias)) } else { @@ -220,6 +256,24 @@ impl ModelGenerationRule { _ => Ok(Transformed::yes(plan.recompute_schema()?)), } } + + fn generate_row_level_access_control_filter( + &self, + model: Arc, + rule: &RowLevelAccessControl, + ) -> Result> { + if validate_rule(&rule.required_properties, &self.properties)? 
{ + let filter = build_filter_expression( + &self.session_state, + model, + &self.properties, + rule, + )?; + Ok(Some(filter)) + } else { + Ok(None) + } + } } impl Debug for ModelGenerationRule { diff --git a/wren-core/core/src/logical_plan/analyze/plan.rs b/wren-core/core/src/logical_plan/analyze/plan.rs index ef9857b52..2be75dce4 100644 --- a/wren-core/core/src/logical_plan/analyze/plan.rs +++ b/wren-core/core/src/logical_plan/analyze/plan.rs @@ -9,7 +9,7 @@ use datafusion::arrow::datatypes::Field; use datafusion::common::{ internal_err, plan_err, Column, DFSchema, DFSchemaRef, TableReference, }; -use datafusion::error::Result; +use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::expr::WildcardOptions; use datafusion::logical_expr::utils::find_aggregate_exprs; use datafusion::logical_expr::{ @@ -22,6 +22,7 @@ use crate::logical_plan::analyze::RelationChain; use crate::logical_plan::analyze::RelationChain::Start; use crate::logical_plan::utils::{from_qualified_name, try_map_data_type}; use crate::mdl; +use crate::mdl::context::SessionPropertiesRef; use crate::mdl::lineage::DatasetLink; use crate::mdl::manifest::{JoinType, Model}; use crate::mdl::utils::{ @@ -31,6 +32,8 @@ use crate::mdl::utils::{ use crate::mdl::Dataset; use crate::mdl::{AnalyzedWrenMDL, ColumnReference, SessionStateRef}; +use super::access_control::{collect_condition, validate_rule}; + #[derive(Debug)] pub(crate) enum WrenPlan { Calculation(Arc), @@ -55,7 +58,7 @@ impl WrenPlan { /// It only generates the top plan for the model, and the relation chain will generate the source plan. 
#[derive(PartialEq, Eq, Hash, Debug, Clone)] pub(crate) struct ModelPlanNode { - pub(crate) plan_name: String, + pub(crate) model: Arc, pub(crate) required_exprs: Vec, pub(crate) relation_chain: Box, schema_ref: DFSchemaRef, @@ -69,8 +72,9 @@ impl ModelPlanNode { original_table_scan: Option, analyzed_wren_mdl: Arc, session_state: SessionStateRef, + properties: SessionPropertiesRef, ) -> Result { - ModelPlanNodeBuilder::new(analyzed_wren_mdl, session_state).build( + ModelPlanNodeBuilder::new(analyzed_wren_mdl, session_state, properties).build( model, required_fields, original_table_scan, @@ -78,7 +82,7 @@ impl ModelPlanNode { } pub fn plan_name(&self) -> &str { - &self.plan_name + self.model.name() } } @@ -98,12 +102,14 @@ struct ModelPlanNodeBuilder { fields: VecDeque<(Option, Arc)>, analyzed_wren_mdl: Arc, session_state: SessionStateRef, + properties: SessionPropertiesRef, } impl ModelPlanNodeBuilder { fn new( analyzed_wren_mdl: Arc, session_state: SessionStateRef, + properties: SessionPropertiesRef, ) -> Self { Self { required_exprs_buffer: BTreeSet::new(), @@ -113,6 +119,7 @@ impl ModelPlanNodeBuilder { fields: VecDeque::new(), analyzed_wren_mdl, session_state, + properties, } } @@ -128,6 +135,9 @@ impl ModelPlanNodeBuilder { model.name(), ); + let required_fields = + self.add_required_columns_from_session_properties(&model, required_fields)?; + let required_columns = model.get_physical_columns().into_iter().filter(|column| { required_fields @@ -322,6 +332,7 @@ impl ModelPlanNodeBuilder { &self.model_required_fields.clone(), Arc::clone(&self.analyzed_wren_mdl), Arc::clone(&self.session_state), + Arc::clone(&self.properties), )?; for calculation_plan in calculate_iter { @@ -349,7 +360,7 @@ impl ModelPlanNodeBuilder { } Ok(ModelPlanNode { - plan_name: model.name().to_string(), + model, required_exprs: self .required_exprs_buffer .iter() @@ -362,6 +373,24 @@ impl ModelPlanNodeBuilder { }) } + fn add_required_columns_from_session_properties( + &self, + model: 
&Model, + required_fields: Vec, + ) -> Result> { + let mut required_fields = required_fields; + model + .row_level_access_controls() + .iter() + .try_for_each(|rule| { + if validate_rule(&rule.required_properties, &self.properties)? { + required_fields.extend(collect_condition(model, &rule.condition)?.0); + } + Ok::<_, DataFusionError>(()) + })?; + Ok(required_fields) + } + fn is_to_many_calculation(&self, expr: Expr) -> bool { !find_aggregate_exprs(&[expr]).is_empty() } @@ -448,6 +477,7 @@ impl ModelPlanNodeBuilder { &partial_model_required_fields, Arc::clone(&self.analyzed_wren_mdl), Arc::clone(&self.session_state), + Arc::clone(&self.properties), )?; let Some(column_rf) = self .analyzed_wren_mdl @@ -701,7 +731,8 @@ impl UserDefinedLogicalNodeCore for ModelPlanNode { write!( f, "Model: name={}, schema={}", - self.plan_name, self.schema_ref + self.model.name(), + self.schema_ref ) } @@ -711,7 +742,7 @@ impl UserDefinedLogicalNodeCore for ModelPlanNode { _: Vec, ) -> datafusion::common::Result { Ok(ModelPlanNode { - plan_name: self.plan_name.clone(), + model: self.model.clone(), required_exprs: self.required_exprs.clone(), relation_chain: self.relation_chain.clone(), schema_ref: self.schema_ref.clone(), @@ -1019,7 +1050,7 @@ impl UserDefinedLogicalNodeCore for PartialModelPlanNode { } fn fmt_for_explain(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "PartialModel: name={}", self.model_node.plan_name) + write!(f, "PartialModel: name={}", self.model_node.model.name()) } fn with_exprs_and_inputs( diff --git a/wren-core/core/src/logical_plan/analyze/relation_chain.rs b/wren-core/core/src/logical_plan/analyze/relation_chain.rs index 99d603080..4521b0f80 100644 --- a/wren-core/core/src/logical_plan/analyze/relation_chain.rs +++ b/wren-core/core/src/logical_plan/analyze/relation_chain.rs @@ -6,6 +6,7 @@ use crate::logical_plan::analyze::relation_chain::RelationChain::Start; use crate::logical_plan::utils::{ create_schema, eliminate_ambiguous_columns, 
rebase_column, }; +use crate::mdl::context::SessionPropertiesRef; use crate::mdl::lineage::DatasetLink; use crate::mdl::manifest::JoinType; use crate::mdl::utils::{qualify_name_from_column_name, quoted}; @@ -63,6 +64,7 @@ impl RelationChain { } } + #[allow(clippy::too_many_arguments)] pub fn with_chain( source: Self, mut start: NodeIndex, @@ -71,6 +73,7 @@ impl RelationChain { model_required_fields: &HashMap>, analyzed_wren_mdl: Arc, session_state_ref: SessionStateRef, + properties: SessionPropertiesRef, ) -> Result { let mut relation_chain = source; @@ -102,6 +105,7 @@ impl RelationChain { None, Arc::clone(&analyzed_wren_mdl), Arc::clone(&session_state_ref), + Arc::clone(&properties), )?; let df_schema = @@ -196,7 +200,7 @@ impl RelationChain { .map(|field| { col(format!( "{}.{}", - quoted(&model_plan.plan_name), + quoted(model_plan.plan_name()), quoted(field.name()), )) }) @@ -246,7 +250,7 @@ impl RelationChain { .map(|field| { col(format!( "{}.{}", - quoted(&partial_model_plan.model_node.plan_name), + quoted(partial_model_plan.model_node.plan_name()), quoted(field.name()), )) }) diff --git a/wren-core/wren-example/examples/row_level_access_control.rs b/wren-core/wren-example/examples/row_level_access_control.rs new file mode 100644 index 000000000..6123351a8 --- /dev/null +++ b/wren-core/wren-example/examples/row_level_access_control.rs @@ -0,0 +1,74 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::prelude::{CsvReadOptions, SessionContext}; +use wren_core::mdl::builder::{ColumnBuilder, ManifestBuilder, ModelBuilder}; +use wren_core::mdl::manifest::{Manifest, SessionProperty}; +use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; + +#[tokio::main] +async fn main() -> datafusion::common::Result<()> { + let manifest = init_manifest(); + let ctx = SessionContext::new(); + + ctx.register_csv( + "customers", + "sqllogictest/tests/resources/ecommerce/customers.csv", + CsvReadOptions::new(), + ) + .await?; + let customers_provider = 
ctx + .catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("customers") + .await? + .unwrap(); + let register = HashMap::from([( + "datafusion.public.customers".to_string(), + customers_provider, + )]); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?); + let sql = "SELECT * FROM customers"; + + // carry the seesion property + let mut properties = HashMap::new(); + properties.insert("session_city".to_string(), Some("'Santa Ana'".to_string())); + + let sql = transform_sql_with_ctx(&ctx, analyzed_mdl, &[], properties, sql).await?; + println!("Wren engine generated SQL: \n{}", sql); + let df = match ctx.sql(&sql).await { + Ok(df) => df, + Err(e) => { + eprintln!("Error: {}", e); + return Err(e); + } + }; + match df.show().await { + Ok(_) => {} + Err(e) => eprintln!("Error: {}", e), + } + + Ok(()) +} + +fn init_manifest() -> Manifest { + ManifestBuilder::new() + .model( + ModelBuilder::new("customers") + .table_reference("datafusion.public.customers") + .column(ColumnBuilder::new("city", "varchar").build()) + .column(ColumnBuilder::new("id", "varchar").build()) + .column(ColumnBuilder::new("state", "varchar").build()) + .add_row_level_access_control( + "city rule", + vec![SessionProperty::new_required("session_city")], + "city = @session_city", + ) + .primary_key("id") + .build(), + ) + .build() +} From 2ca197a7d9787183aa00a0d4bcc5b8aa0359cbae Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 22 Apr 2025 16:28:29 +0800 Subject: [PATCH 06/30] refactor test to use insta --- wren-core/core/src/mdl/mod.rs | 156 ++++++++++++++++++---------------- 1 file changed, 82 insertions(+), 74 deletions(-) diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index bdf215c5b..479dd7e83 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -448,12 +448,14 @@ mod test { use datafusion::arrow::array::{ ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray, 
}; - use datafusion::assert_batches_eq; + use datafusion::arrow::util::pretty::pretty_format_batches_with_options; + use datafusion::common::format::DEFAULT_FORMAT_OPTIONS; use datafusion::common::not_impl_err; use datafusion::common::Result; use datafusion::config::ConfigOptions; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion::sql::unparser::plan_to_sql; + use insta::assert_snapshot; use wren_core_base::mdl::DataSource; #[test] @@ -570,14 +572,13 @@ mod test { sql, ) .await?; - let expected = "SELECT \"profile\".totalcost FROM (SELECT totalcost.totalcost FROM \ + assert_snapshot!(result, @"SELECT \"profile\".totalcost FROM (SELECT totalcost.totalcost FROM \ (SELECT __relation__2.p_custkey AS p_custkey, sum(CAST(__relation__2.o_totalprice AS BIGINT)) AS totalcost FROM \ (SELECT __relation__1.c_custkey, orders.o_custkey, orders.o_totalprice, __relation__1.p_custkey FROM \ (SELECT __source.o_custkey AS o_custkey, __source.o_totalprice AS o_totalprice FROM orders AS __source) AS orders RIGHT JOIN \ (SELECT customer.c_custkey, \"profile\".p_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer RIGHT JOIN \ (SELECT __source.p_custkey AS p_custkey FROM \"profile\" AS __source) AS \"profile\" ON customer.c_custkey = \"profile\".p_custkey) AS __relation__1 \ - ON orders.o_custkey = __relation__1.c_custkey) AS __relation__2 GROUP BY __relation__2.p_custkey) AS totalcost) AS \"profile\""; - assert_eq!(result, expected); + ON orders.o_custkey = __relation__1.c_custkey) AS __relation__2 GROUP BY __relation__2.p_custkey) AS totalcost) AS \"profile\""); let sql = "select totalcost from profile where p_sex = 'M'"; let result = transform_sql_with_ctx( @@ -588,8 +589,8 @@ mod test { sql, ) .await?; - assert_eq!(result, - "SELECT \"profile\".totalcost FROM (SELECT __relation__1.p_sex, __relation__1.totalcost FROM \ + assert_snapshot!(result, + @"SELECT \"profile\".totalcost FROM (SELECT __relation__1.p_sex, 
__relation__1.totalcost FROM \ (SELECT totalcost.p_custkey, \"profile\".p_sex, totalcost.totalcost FROM (SELECT __relation__2.p_custkey AS p_custkey, \ sum(CAST(__relation__2.o_totalprice AS BIGINT)) AS totalcost FROM (SELECT __relation__1.c_custkey, orders.o_custkey, \ orders.o_totalprice, __relation__1.p_custkey FROM (SELECT __source.o_custkey AS o_custkey, __source.o_totalprice AS o_totalprice \ @@ -627,8 +628,8 @@ mod test { sql, ) .await?; - assert_eq!(actual, - "SELECT \"Customer\".\"Custkey\", \"Customer\".\"Name\" FROM \ + assert_snapshot!(actual, + @"SELECT \"Customer\".\"Custkey\", \"Customer\".\"Name\" FROM \ (SELECT \"Customer\".\"Custkey\", \"Customer\".\"Name\" FROM \ (SELECT __source.\"Custkey\" AS \"Custkey\", __source.\"Name\" AS \"Name\" FROM datafusion.\"public\".customer AS __source) AS \"Customer\") AS \"Customer\""); Ok(()) @@ -667,7 +668,7 @@ mod test { r#"select add_two("Custkey") from "Customer""#, ) .await?; - assert_eq!(actual, "SELECT add_two(\"Customer\".\"Custkey\") FROM (SELECT \"Customer\".\"Custkey\" \ + assert_snapshot!(actual, @"SELECT add_two(\"Customer\".\"Custkey\") FROM (SELECT \"Customer\".\"Custkey\" \ FROM (SELECT __source.\"Custkey\" AS \"Custkey\" FROM datafusion.\"public\".customer AS __source) AS \"Customer\") AS \"Customer\""); let actual = transform_sql_with_ctx( @@ -678,7 +679,7 @@ mod test { r#"select median("Custkey") from "CTest"."STest"."Customer" group by "Name""#, ) .await?; - assert_eq!(actual, "SELECT median(\"Customer\".\"Custkey\") FROM (SELECT \"Customer\".\"Custkey\", \"Customer\".\"Name\" \ + assert_snapshot!(actual, @"SELECT median(\"Customer\".\"Custkey\") FROM (SELECT \"Customer\".\"Custkey\", \"Customer\".\"Name\" \ FROM (SELECT __source.\"Custkey\" AS \"Custkey\", __source.\"Name\" AS \"Name\" FROM datafusion.\"public\".customer AS __source) AS \"Customer\") AS \"Customer\" \ GROUP BY \"Customer\".\"Name\""); @@ -738,11 +739,11 @@ mod test { sql, ) .await?; - assert_eq!(actual, - "SELECT 
artist.\"名字\", artist.name_append, artist.\"group\", artist.subscribe, artist.subscribe_plus FROM \ - (SELECT artist.\"group\", artist.name_append, artist.subscribe, artist.subscribe_plus, artist.\"名字\" FROM \ - (SELECT __source.\"名字\" AS \"名字\", __source.\"名字\" || __source.\"名字\" AS name_append, __source.\"組別\" AS \"group\", CAST(__source.\"訂閱數\" AS BIGINT) + 1 AS subscribe_plus, __source.\"訂閱數\" AS subscribe FROM artist AS __source) AS artist) AS artist" -); + assert_snapshot!(actual, + @"SELECT artist.\"名字\", artist.name_append, artist.\"group\", artist.subscribe, artist.subscribe_plus FROM \ + (SELECT artist.\"group\", artist.name_append, artist.subscribe, artist.subscribe_plus, artist.\"名字\" FROM \ + (SELECT __source.\"名字\" AS \"名字\", __source.\"名字\" || __source.\"名字\" AS name_append, __source.\"組別\" AS \"group\", CAST(__source.\"訂閱數\" AS BIGINT) + 1 AS subscribe_plus, __source.\"訂閱數\" AS subscribe FROM artist AS __source) AS artist) AS artist" + ); ctx.sql(&actual).await?.show().await?; let sql = r#"select group from wren.test.artist"#; @@ -754,8 +755,8 @@ mod test { sql, ) .await?; - assert_eq!(actual, - "SELECT artist.\"group\" FROM (SELECT artist.\"group\" FROM (SELECT __source.\"組別\" AS \"group\" FROM artist AS __source) AS artist) AS artist"); + assert_snapshot!(actual, + @"SELECT artist.\"group\" FROM (SELECT artist.\"group\" FROM (SELECT __source.\"組別\" AS \"group\" FROM artist AS __source) AS artist) AS artist"); ctx.sql(&actual).await?.show().await?; let sql = r#"select subscribe_plus from wren.test.artist"#; @@ -767,8 +768,8 @@ mod test { sql, ) .await?; - assert_eq!(actual, - "SELECT artist.subscribe_plus FROM (SELECT artist.subscribe_plus FROM (SELECT CAST(__source.\"訂閱數\" AS BIGINT) + 1 AS subscribe_plus FROM artist AS __source) AS artist) AS artist"); + assert_snapshot!(actual, + @"SELECT artist.subscribe_plus FROM (SELECT artist.subscribe_plus FROM (SELECT CAST(__source.\"訂閱數\" AS BIGINT) + 1 AS subscribe_plus FROM artist AS __source) AS 
artist) AS artist"); ctx.sql(&actual).await?.show().await } @@ -807,9 +808,9 @@ mod test { ) .await .map_err(|e| { - assert_eq!( + assert_snapshot!( e.to_string(), - "ModelAnalyzeRule\ncaused by\nSchema error: No field named \"名字\"." + @"ModelAnalyzeRule\ncaused by\nSchema error: No field named \"名字\"." ) }); @@ -823,9 +824,9 @@ mod test { ) .await .map_err(|e| { - assert_eq!( + assert_snapshot!( e.to_string(), - "ModelAnalyzeRule\ncaused by\nSchema error: No field named \"名字\"." + @"ModelAnalyzeRule\ncaused by\nSchema error: No field named \"名字\"." ) }); Ok(()) @@ -861,8 +862,8 @@ mod test { sql, ) .await?; - assert_eq!(actual, - "SELECT artist.\"串接名字\" FROM (SELECT artist.\"串接名字\" FROM (SELECT __source.\"名字\" || __source.\"名字\" AS \"串接名字\" FROM artist AS __source) AS artist) AS artist"); + assert_snapshot!(actual, + @"SELECT artist.\"串接名字\" FROM (SELECT artist.\"串接名字\" FROM (SELECT __source.\"名字\" || __source.\"名字\" AS \"串接名字\" FROM artist AS __source) AS artist) AS artist"); let sql = r#"select * from wren.test.artist"#; let actual = transform_sql_with_ctx( &SessionContext::new(), @@ -872,8 +873,8 @@ mod test { sql, ) .await?; - assert_eq!(actual, - "SELECT artist.\"串接名字\" FROM (SELECT artist.\"串接名字\" FROM (SELECT __source.\"名字\" || __source.\"名字\" AS \"串接名字\" FROM artist AS __source) AS artist) AS artist"); + assert_snapshot!(actual, + @"SELECT artist.\"串接名字\" FROM (SELECT artist.\"串接名字\" FROM (SELECT __source.\"名字\" || __source.\"名字\" AS \"串接名字\" FROM artist AS __source) AS artist) AS artist"); let sql = r#"select "名字" from wren.test.artist"#; let _ = transform_sql_with_ctx( @@ -884,9 +885,9 @@ mod test { sql, ) .await.map_err(|e| { - assert_eq!( + assert_snapshot!( e.to_string(), - "Schema error: No field named \"名字\". Valid fields are wren.test.artist.\"串接名字\"." + @"Schema error: No field named \"名字\". Valid fields are wren.test.artist.\"串接名字\"." 
) }); Ok(()) @@ -903,7 +904,7 @@ mod test { sql, ) .await?; - assert_eq!(actual, "SELECT current_date()"); + assert_snapshot!(actual, @"SELECT current_date()"); Ok(()) } @@ -933,8 +934,8 @@ mod test { sql, ) .await?; - assert_eq!(actual, - "SELECT CAST(current_date() AS TIMESTAMP) > artist.\"出道時間\" FROM \ + assert_snapshot!(actual, + @"SELECT CAST(current_date() AS TIMESTAMP) > artist.\"出道時間\" FROM \ (SELECT artist.\"出道時間\" FROM (SELECT __source.\"出道時間\" AS \"出道時間\" FROM artist AS __source) AS artist) AS artist"); Ok(()) } @@ -955,7 +956,7 @@ mod test { .await?; // TODO: BigQuery doesn't support the alias include invalid characters (e.g. `*`, `()`). // We should remove the invalid characters for the alias. - assert_eq!(actual, "SELECT count(1) AS \"count(*)\" FROM (SELECT 1)"); + assert_snapshot!(actual, @"SELECT count(1) AS \"count(*)\" FROM (SELECT 1)"); Ok(()) } @@ -993,7 +994,7 @@ mod test { sql, ) .await?; - assert_eq!(actual, "SELECT INTERVAL 1 DAY"); + assert_snapshot!(actual, @"SELECT INTERVAL 1 DAY"); let sql = "SELECT INTERVAL '1 YEAR 1 MONTH'"; let actual = transform_sql_with_ctx( @@ -1004,7 +1005,7 @@ mod test { sql, ) .await?; - assert_eq!(actual, "SELECT INTERVAL 13 MONTH"); + assert_snapshot!(actual, @"SELECT INTERVAL 13 MONTH"); let sql = "SELECT INTERVAL '1' YEAR + INTERVAL '2' MONTH + INTERVAL '3' DAY"; let actual = transform_sql_with_ctx( @@ -1015,9 +1016,9 @@ mod test { sql, ) .await?; - assert_eq!( + assert_snapshot!( actual, - "SELECT INTERVAL 12 MONTH + INTERVAL 2 MONTH + INTERVAL 3 DAY" + @"SELECT INTERVAL 12 MONTH + INTERVAL 2 MONTH + INTERVAL 3 DAY" ); Ok(()) } @@ -1030,7 +1031,7 @@ mod test { let sql = "select * from unnest([1, 2, 3])"; let actual = transform_sql_with_ctx(&ctx, analyzed_mdl, &[], HashMap::new(), sql).await?; - assert_eq!(actual, "SELECT \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\" FROM (SELECT UNNEST([1, 2, 3]) AS \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\")"); + assert_snapshot!(actual, @"SELECT 
\"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\" FROM (SELECT UNNEST([1, 2, 3]) AS \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\")"); let manifest = ManifestBuilder::new() .data_source(DataSource::BigQuery) @@ -1039,7 +1040,7 @@ mod test { let sql = "select * from unnest([1, 2, 3])"; let actual = transform_sql_with_ctx(&ctx, analyzed_mdl, &[], HashMap::new(), sql).await?; - assert_eq!(actual, "SELECT \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\" FROM UNNEST([1, 2, 3])"); + assert_snapshot!(actual, @"SELECT \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\" FROM UNNEST([1, 2, 3])"); Ok(()) } @@ -1056,7 +1057,7 @@ mod test { sql, ) .await?; - assert_eq!(actual, "SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00 +08:00\"\")\""); + assert_snapshot!(actual, @"SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00 +08:00\"\")\""); let sql = "select timestamp '2011-01-01 18:00:00 Asia/Taipei'"; let actual = transform_sql_with_ctx( @@ -1067,7 +1068,7 @@ mod test { sql, ) .await?; - assert_eq!(actual, "SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00 Asia/Taipei\"\")\""); + assert_snapshot!(actual, @"SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00 Asia/Taipei\"\")\""); Ok(()) } @@ -1088,7 +1089,7 @@ mod test { ) .await?; // TIMESTAMP doesn't have timezone, so the timezone will be ignored - assert_eq!(actual, "SELECT CAST('2011-01-01 18:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00\"\")\""); + assert_snapshot!(actual, @"SELECT CAST('2011-01-01 18:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00\"\")\""); let sql = "select timestamp with time zone '2011-01-01 18:00:00'"; let actual = transform_sql_with_ctx( @@ -1100,7 +1101,7 @@ mod test { ) .await?; // TIMESTAMP WITH TIME ZONE will be converted to the session timezone - assert_eq!(actual, "SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 
18:00:00\"\")\""); + assert_snapshot!(actual, @"SELECT CAST('2011-01-01 10:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2011-01-01 18:00:00\"\")\""); let mut config = ConfigOptions::new(); config.execution.time_zone = Some("America/New_York".to_string()); @@ -1117,7 +1118,7 @@ mod test { sql, ) .await?; - assert_eq!(actual, "SELECT CAST('2024-01-15 23:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2024-01-15 18:00:00\"\")\""); + assert_snapshot!(actual, @"SELECT CAST('2024-01-15 23:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2024-01-15 18:00:00\"\")\""); // TIMESTAMP WITH TIME ZONE will be converted to the session timezone without daylight saving (UTC -4) let sql = "select timestamp with time zone '2024-07-15 18:00:00'"; @@ -1129,7 +1130,7 @@ mod test { sql, ) .await?; - assert_eq!(actual, "SELECT CAST('2024-07-15 22:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2024-07-15 18:00:00\"\")\""); + assert_snapshot!(actual, @"SELECT CAST('2024-07-15 22:00:00' AS TIMESTAMP) AS \"Utf8(\"\"2024-07-15 18:00:00\"\")\""); Ok(()) } @@ -1167,10 +1168,10 @@ mod test { sql, ) .await?; - assert_eq!(actual, + assert_snapshot!(actual, // TODO: BigQuery doesn't support the alias include invalid characters (e.g. `*`, `()`). // We should remove the invalid characters for the alias. 
- "SELECT count(1) AS \"count(*)\" FROM (SELECT artist.cast_timestamptz FROM \ + @"SELECT count(1) AS \"count(*)\" FROM (SELECT artist.cast_timestamptz FROM \ (SELECT CAST(__source.\"出道時間\" AS TIMESTAMP WITH TIME ZONE) AS cast_timestamptz \ FROM artist AS __source) AS artist) AS artist WHERE CAST(artist.cast_timestamptz AS TIMESTAMP) > CAST('2011-01-01 21:00:00' AS TIMESTAMP)"); Ok(()) @@ -1213,14 +1214,13 @@ mod test { .await?; let sql = r#"select arrow_typeof(timestamp_col), arrow_typeof(timestamptz_col) from wren.test.timestamp_table limit 1"#; let result = ctx.sql(sql).await?.collect().await?; - let expected = vec![ - "+---------------------------------------------+-----------------------------------------------+", - "| arrow_typeof(timestamp_table.timestamp_col) | arrow_typeof(timestamp_table.timestamptz_col) |", - "+---------------------------------------------+-----------------------------------------------+", - "| Timestamp(Nanosecond, None) | Timestamp(Nanosecond, Some(\"UTC\")) |", - "+---------------------------------------------+-----------------------------------------------+", - ]; - assert_batches_eq!(&expected, &result); + assert_snapshot!(batches_to_string(&result), @r#" + +---------------------------------------------+-----------------------------------------------+ + | arrow_typeof(timestamp_table.timestamp_col) | arrow_typeof(timestamp_table.timestamptz_col) | + +---------------------------------------------+-----------------------------------------------+ + | Timestamp(Nanosecond, None) | Timestamp(Nanosecond, Some("UTC")) | + +---------------------------------------------+-----------------------------------------------+ + "#); Ok(()) } @@ -1332,7 +1332,7 @@ mod test { sql, ) .await?; - assert_eq!(actual, "SELECT list_table.list_col[1] FROM (SELECT list_table.list_col FROM \ + assert_snapshot!(actual, @"SELECT list_table.list_col[1] FROM (SELECT list_table.list_col FROM \ (SELECT __source.list_col AS list_col FROM list_table AS __source) AS 
list_table) AS list_table"); Ok(()) } @@ -1373,9 +1373,9 @@ mod test { sql, ) .await?; - assert_eq!( + assert_snapshot!( actual, - "SELECT struct_table.struct_col.float_field FROM \ + @"SELECT struct_table.struct_col.float_field FROM \ (SELECT struct_table.struct_col FROM (SELECT __source.struct_col AS struct_col \ FROM struct_table AS __source) AS struct_table) AS struct_table" ); @@ -1389,7 +1389,7 @@ mod test { sql, ) .await?; - assert_eq!(actual, "SELECT struct_table.struct_array_col[1].float_field FROM \ + assert_snapshot!(actual, @"SELECT struct_table.struct_array_col[1].float_field FROM \ (SELECT struct_table.struct_array_col FROM (SELECT __source.struct_array_col AS struct_array_col \ FROM struct_table AS __source) AS struct_table) AS struct_table"); @@ -1403,7 +1403,7 @@ mod test { sql, ) .await?; - assert_eq!(actual, "SELECT {float_field: 1.0, time_field: CAST('2021-01-01 00:00:00' AS TIMESTAMP)}"); + assert_snapshot!(actual, @"SELECT {float_field: 1.0, time_field: CAST('2021-01-01 00:00:00' AS TIMESTAMP)}"); let manifest = ManifestBuilder::new() .catalog("wren") @@ -1420,9 +1420,9 @@ mod test { let _ = transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) .await .map_err(|e| { - assert_eq!( + assert_snapshot!( e.to_string(), - "Execution error: The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got Utf8" + @"Execution error: The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got Utf8" ) }); Ok(()) @@ -1442,7 +1442,7 @@ mod test { sql, ) .await?; - assert_eq!(result, "SELECT CAST(CAST('2021-01-01 00:00:00' AS TIMESTAMP) AS TIMESTAMP WITH TIME ZONE) = \ + assert_snapshot!(result, @"SELECT CAST(CAST('2021-01-01 00:00:00' AS TIMESTAMP) AS TIMESTAMP WITH TIME ZONE) = \ CAST(CAST('2021-01-01 00:00:00' AS TIMESTAMP) AS TIMESTAMP WITH TIME ZONE)"); Ok(()) } @@ -1462,9 +1462,9 @@ mod test { sql, ) .await?; - assert_eq!( + assert_snapshot!( result, - "SELECT 
x, y FROM (SELECT 1 AS x, 'a' AS y \ + @"SELECT x, y FROM (SELECT 1 AS x, 'a' AS y \ UNION ALL SELECT 1 AS x, 'b' AS y \ UNION ALL SELECT 2 AS x, 'a' AS y \ UNION ALL SELECT 2 AS x, 'c' AS y)" @@ -1477,11 +1477,11 @@ mod test { let manifest = ManifestBuilder::default().data_source(MySQL).build(); let mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); let ctx = SessionContext::new(); - let expected = "SELECT trim(' abc')"; + let sql = "SELECT trim(' abc')"; let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&mdl), &[], HashMap::new(), expected) + transform_sql_with_ctx(&ctx, Arc::clone(&mdl), &[], HashMap::new(), sql) .await?; - assert_eq!(actual, expected); + assert_snapshot!(actual, @"SELECT trim(' abc')"); Ok(()) } @@ -1509,9 +1509,9 @@ mod test { sql, ) .await?; - assert_eq!( + assert_snapshot!( result, - "SELECT customer.c_custkey, count(DISTINCT customer.c_name) FROM \ + @"SELECT customer.c_custkey, count(DISTINCT customer.c_name) FROM \ (SELECT customer.c_custkey, customer.c_name FROM \ (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer \ GROUP BY customer.c_custkey" @@ -1543,9 +1543,9 @@ mod test { sql, ) .await?; - assert_eq!( + assert_snapshot!( result, - "SELECT customer.c_custkey, (SELECT customer.c_name FROM (SELECT customer.c_custkey, customer.c_name \ + @"SELECT customer.c_custkey, (SELECT customer.c_name FROM (SELECT customer.c_custkey, customer.c_name \ FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer \ WHERE customer.c_custkey = 1) FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer) AS customer" ); @@ -1576,9 +1576,9 @@ mod test { sql, ) .await?; - assert_eq!( + assert_snapshot!( result, - "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM \ + @"SELECT customer.c_custkey, customer.c_name FROM 
(SELECT customer.c_custkey, customer.c_name FROM \ (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer \ WHERE customer.c_custkey = 1" ); @@ -1734,4 +1734,12 @@ mod test { ]) .unwrap() } + + fn batches_to_string(batches: &[RecordBatch]) -> String { + let actual = pretty_format_batches_with_options(batches, &DEFAULT_FORMAT_OPTIONS) + .unwrap() + .to_string(); + + actual.trim().to_string() + } } From fc5b970301196c315619e77db080e781013259ce Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 22 Apr 2025 23:42:17 +0800 Subject: [PATCH 07/30] add test and fix filter position --- wren-core-py/Cargo.lock | 36 ++++ .../logical_plan/analyze/model_generation.rs | 11 +- wren-core/core/src/mdl/mod.rs | 175 +++++++++++++++++- 3 files changed, 217 insertions(+), 5 deletions(-) diff --git a/wren-core-py/Cargo.lock b/wren-core-py/Cargo.lock index e94697b30..648712b41 100644 --- a/wren-core-py/Cargo.lock +++ b/wren-core-py/Cargo.lock @@ -594,6 +594,18 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "windows-sys", +] + [[package]] name = "const-random" version = "0.1.18" @@ -1229,6 +1241,12 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "env_filter" version = "0.1.3" @@ -1714,6 +1732,17 @@ version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" +[[package]] +name = "insta" +version = "1.43.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "154934ea70c58054b556dd430b99a98c2a7ff5309ac9891597e339b5c28f4371" +dependencies = [ + "console", + "once_cell", + "similar", +] + [[package]] name = "integer-encoding" version = "3.0.4" @@ -2672,6 +2701,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "siphasher" version = "1.0.1" @@ -3348,6 +3383,7 @@ dependencies = [ "csv", "datafusion", "env_logger", + "insta", "log", "parking_lot", "petgraph 0.7.1", diff --git a/wren-core/core/src/logical_plan/analyze/model_generation.rs b/wren-core/core/src/logical_plan/analyze/model_generation.rs index 9b08b4bec..093e1dc15 100644 --- a/wren-core/core/src/logical_plan/analyze/model_generation.rs +++ b/wren-core/core/src/logical_plan/analyze/model_generation.rs @@ -75,6 +75,7 @@ impl ModelGenerationRule { } else { model_plan.required_exprs.clone() }; + let projections = eliminate_ambiguous_columns(projections); let mut builder = if let Some(plan) = source_plan { LogicalPlanBuilder::from(plan) @@ -82,10 +83,6 @@ impl ModelGenerationRule { return plan_err!("Failed to generate source plan"); }; - if !model_plan.required_exprs.is_empty() { - builder = builder.project(projections)? - } - let filters: Vec> = model_plan .model .row_level_access_controls() @@ -110,10 +107,16 @@ impl ModelGenerationRule { }) .flatten(); + // follow the logical plan of DataFusion. + // the filter should be placed top on the relation. if let Some(filter) = rls_filter { builder = builder.filter(filter)? 
} + if !model_plan.required_exprs.is_empty() { + builder = builder.project(projections)? + } + // calculated field scope Ok(Transformed::yes(builder.build()?)) } else if let Some(model_plan) = diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 479dd7e83..1997aba46 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -456,7 +456,7 @@ mod test { use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion::sql::unparser::plan_to_sql; use insta::assert_snapshot; - use wren_core_base::mdl::DataSource; + use wren_core_base::mdl::{DataSource, SessionProperty}; #[test] fn test_sync_transform() -> Result<()> { @@ -1671,6 +1671,169 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_rlac_with_requried_properties() -> Result<()> { + env_logger::init(); + let ctx = SessionContext::new(); + + // test required property + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .add_row_level_access_control( + "nation", + vec![SessionProperty::new_required("session_nation")], + "c_nationkey = @session_nation", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let sql = "SELECT * FROM customer"; + let headers = + build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + let result = + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await?; + assert_snapshot!( + result, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1) AS customer" + ); + + match transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + 
&[], + HashMap::new(), + sql, + ) + .await + { + Err(e) => { + assert_snapshot!( + e.to_string(), + @r" + ModelAnalyzeRule + caused by + Error during planning: Row level access control property session_nation is required, but not found in headers + " + ) + } + _ => panic!("Expected error"), + } + + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .add_row_level_access_control( + "nation", + vec![SessionProperty::new_required("session_nation")], + "c_nationkey = @session_nation", + ) + .add_row_level_access_control( + "name", + vec![SessionProperty::new_required("session_user")], + "c_name = @session_user", + ) + .build(), + ) + .model( + ModelBuilder::new("orders") + .table_reference("orders") + .column(ColumnBuilder::new("o_orderkey", "int").build()) + .column(ColumnBuilder::new("o_custkey", "int").build()) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let sql = "SELECT * FROM customer"; + let headers = build_headers(&[ + ("session_nation".to_string(), Some("1".to_string())), + ("session_user".to_string(), Some("'Gura'".to_string())), + ]); + let result = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + headers.clone(), + sql, + ) + .await?; + assert_snapshot!( + result, + @"SELECT customer.c_custkey, customer.c_nationkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer" + ); + + let sql = "SELECT * FROM customer WHERE c_custkey = 1"; + let result = + 
transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await?; + assert_snapshot!( + result, + @"SELECT customer.c_custkey, customer.c_nationkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer WHERE customer.c_custkey = 1" + ); + + // test other model won't be affected + let sql = "SELECT o_orderkey FROM orders"; + let result = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; + assert_snapshot!( + result, + @"SELECT orders.o_orderkey FROM (SELECT orders.o_orderkey FROM (SELECT __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders) AS orders" + ); + + let sql = "SELECT o_orderkey FROM customer JOIN orders ON customer.c_custkey = orders.o_custkey"; + let headers = build_headers(&[ + ("session_nation".to_string(), Some("1".to_string())), + ("session_user".to_string(), Some("'Gura'".to_string())), + ]); + let result = + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await?; + assert_snapshot!( + result, + @"SELECT orders.o_orderkey FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer JOIN (SELECT orders.o_custkey, orders.o_orderkey FROM (SELECT __source.o_custkey AS o_custkey, __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders) AS orders ON customer.c_custkey = orders.o_custkey" + ); + + // test property is required + let headers = + build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + let sql = "SELECT * FROM 
customer"; + match transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await + { + Err(e) => { + assert_snapshot!( + e.to_string(), + @r" + ModelAnalyzeRule + caused by + Error during planning: Row level access control property session_user is required, but not found in headers + " + ) + } + _ => panic!("Expected error"), + } + Ok(()) + } + /// Return a RecordBatch with made up data about customer fn customer() -> RecordBatch { let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3])); @@ -1742,4 +1905,14 @@ mod test { actual.trim().to_string() } + + fn build_headers( + field: &[(String, Option)], + ) -> HashMap> { + let mut headers = HashMap::new(); + for (key, value) in field { + headers.insert(key.clone(), value.clone()); + } + headers + } } From c3e21880122886e3edc141855b03b3eb7b036fa2 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 23 Apr 2025 00:19:18 +0800 Subject: [PATCH 08/30] add optional property test --- wren-core/core/src/mdl/mod.rs | 208 +++++++++++++++++++++++++++++----- 1 file changed, 177 insertions(+), 31 deletions(-) diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 1997aba46..0f2326425 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -1673,7 +1673,6 @@ mod test { #[tokio::test] async fn test_rlac_with_requried_properties() -> Result<()> { - env_logger::init(); let ctx = SessionContext::new(); // test required property @@ -1697,11 +1696,8 @@ mod test { let sql = "SELECT * FROM customer"; let headers = build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); - let result = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) - .await?; assert_snapshot!( - result, + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS 
c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1) AS customer" ); @@ -1762,40 +1758,21 @@ mod test { ("session_nation".to_string(), Some("1".to_string())), ("session_user".to_string(), Some("'Gura'".to_string())), ]); - let result = transform_sql_with_ctx( - &ctx, - Arc::clone(&analyzed_mdl), - &[], - headers.clone(), - sql, - ) - .await?; assert_snapshot!( - result, + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers.clone(), sql,).await?, @"SELECT customer.c_custkey, customer.c_nationkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer" ); let sql = "SELECT * FROM customer WHERE c_custkey = 1"; - let result = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) - .await?; assert_snapshot!( - result, + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT customer.c_custkey, customer.c_nationkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer WHERE customer.c_custkey = 1" ); // test other model won't be affected let sql = "SELECT o_orderkey FROM orders"; - let result = transform_sql_with_ctx( - &ctx, - Arc::clone(&analyzed_mdl), - &[], - HashMap::new(), - sql, - ) - .await?; assert_snapshot!( - result, + transform_sql_with_ctx(&ctx,Arc::clone(&analyzed_mdl),&[],HashMap::new(),sql).await?, @"SELECT orders.o_orderkey FROM (SELECT orders.o_orderkey FROM (SELECT __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders) AS orders" ); 
@@ -1804,11 +1781,8 @@ mod test { ("session_nation".to_string(), Some("1".to_string())), ("session_user".to_string(), Some("'Gura'".to_string())), ]); - let result = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) - .await?; assert_snapshot!( - result, + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT orders.o_orderkey FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer JOIN (SELECT orders.o_custkey, orders.o_orderkey FROM (SELECT __source.o_custkey AS o_custkey, __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders) AS orders ON customer.c_custkey = orders.o_custkey" ); @@ -1831,6 +1805,178 @@ mod test { } _ => panic!("Expected error"), } + + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .add_row_level_access_control( + "nation", + vec![ + SessionProperty::new_required("session_nation"), + SessionProperty::new_optional("session_user", None), + ], + "c_nationkey = @session_nation AND c_name = @session_user", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let sql = "SELECT * FROM customer"; + + let headers = build_headers(&[ + ("session_nation".to_string(), Some("1".to_string())), + ("session_user".to_string(), Some("'Peko'".to_string())), + ]); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT 
__source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Peko') AS customer" + ); + + // expect ignore the rule because session_user is optional without default value + let headers = + build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" + ); + // expect error because session_user is required + let headers = + build_headers(&[("session_user".to_string(), Some("'Peko'".to_string()))]); + match transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await + { + Err(e) => { + assert_snapshot!( + e.to_string(), + @r" + ModelAnalyzeRule + caused by + Error during planning: Row level access control property session_nation is required, but not found in headers + " + ) + } + _ => panic!("Expected error"), + } + Ok(()) + } + + #[tokio::test] + async fn test_rlac_with_optional_properties() -> Result<()> { + let ctx = SessionContext::new(); + + // test required property + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .add_row_level_access_control( + "nation", + vec![SessionProperty::new_optional( + "session_nation", + Some("3".to_string()), + )], + "c_nationkey = @session_nation", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let sql = "SELECT * FROM customer"; + let headers = + build_headers(&[("session_nation".to_string(), 
Some("1".to_string()))]); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1) AS customer" + ); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) + .await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 3) AS customer" + ); + + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .add_row_level_access_control( + "nation", + vec![SessionProperty::new_optional("session_nation", None)], + "c_nationkey = @session_nation", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let headers = + build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1) AS customer" + ); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) + .await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name 
AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" + ); + + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .add_row_level_access_control( + "nation", + vec![ + SessionProperty::new_optional("session_nation", None), + SessionProperty::new_optional( + "session_user", + Some("'Gura'".to_string()), + ), + ], + "c_nationkey = @session_nation and c_name = @session_user", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let headers = + build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer" + ); + // the rule is expected to be skipped because the optional property is None without default value + let headers = + build_headers(&[("session_user".to_string(), Some("'Peko'".to_string()))]); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" + ); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) + .await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT 
__source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" + ); Ok(()) } From f3517f0683c91ea78cbfda07e7fc8c85f1ca3d46 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 23 Apr 2025 12:19:58 +0800 Subject: [PATCH 09/30] support for calcaulted field as the condition --- .../logical_plan/analyze/model_generation.rs | 12 +- wren-core/core/src/mdl/mod.rs | 145 ++++++++++++++++-- 2 files changed, 142 insertions(+), 15 deletions(-) diff --git a/wren-core/core/src/logical_plan/analyze/model_generation.rs b/wren-core/core/src/logical_plan/analyze/model_generation.rs index 093e1dc15..d6f92219a 100644 --- a/wren-core/core/src/logical_plan/analyze/model_generation.rs +++ b/wren-core/core/src/logical_plan/analyze/model_generation.rs @@ -107,16 +107,18 @@ impl ModelGenerationRule { }) .flatten(); - // follow the logical plan of DataFusion. - // the filter should be placed top on the relation. - if let Some(filter) = rls_filter { - builder = builder.filter(filter)? - } if !model_plan.required_exprs.is_empty() { builder = builder.project(projections)? } + // apply the rule for row level access control + // The filter should be on on the top of the model plan + // and the model plan should be another subquery alias + if let Some(filter) = rls_filter { + builder = builder.alias(model_plan.plan_name())?.filter(filter)? 
+ } + // calculated field scope Ok(Transformed::yes(builder.build()?)) } else if let Some(model_plan) = diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 0f2326425..1514eaac6 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -456,7 +456,9 @@ mod test { use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion::sql::unparser::plan_to_sql; use insta::assert_snapshot; - use wren_core_base::mdl::{DataSource, SessionProperty}; + use wren_core_base::mdl::{ + DataSource, JoinType, RelationshipBuilder, SessionProperty, + }; #[test] fn test_sync_transform() -> Result<()> { @@ -1698,7 +1700,7 @@ mod test { build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, - @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1) AS customer" + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1" ); match transform_sql_with_ctx( @@ -1760,13 +1762,13 @@ mod test { ]); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers.clone(), sql,).await?, - @"SELECT customer.c_custkey, customer.c_nationkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer" + @"SELECT customer.c_custkey, customer.c_nationkey, customer.c_name FROM 
(SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura'" ); let sql = "SELECT * FROM customer WHERE c_custkey = 1"; assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, - @"SELECT customer.c_custkey, customer.c_nationkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer WHERE customer.c_custkey = 1" + @"SELECT customer.c_custkey, customer.c_nationkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_custkey = 1 AND customer.c_nationkey = 1 AND customer.c_name = 'Gura'" ); // test other model won't be affected @@ -1783,7 +1785,7 @@ mod test { ]); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, - @"SELECT orders.o_orderkey FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer JOIN (SELECT orders.o_custkey, orders.o_orderkey FROM (SELECT __source.o_custkey AS o_custkey, __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders) AS orders ON customer.c_custkey = orders.o_custkey" + @"SELECT orders.o_orderkey FROM (SELECT 
customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer JOIN (SELECT orders.o_custkey, orders.o_orderkey FROM (SELECT __source.o_custkey AS o_custkey, __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders) AS orders ON customer.c_custkey = orders.o_custkey WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura'" ); // test property is required @@ -1834,7 +1836,7 @@ mod test { ]); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, - @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Peko') AS customer" + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Peko'" ); // expect ignore the rule because session_user is optional without default value @@ -1896,12 +1898,12 @@ mod test { assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await?, - @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1) AS customer" + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1" ); 
assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) .await?, - @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 3) AS customer" + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 3" ); let manifest = ManifestBuilder::new() @@ -1926,7 +1928,7 @@ mod test { assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await?, - @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1) AS customer" + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1" ); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) @@ -1962,7 +1964,7 @@ mod test { assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await?, - @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura') AS customer" + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, 
__source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura'" ); // the rule is expected to be skipped because the optional property is None without default value let headers = @@ -1980,6 +1982,129 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_rlac_on_calculated_field() -> Result<()> { + let ctx = SessionContext::new(); + + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .primary_key("c_custkey") + .build(), + ) + .model( + ModelBuilder::new("orders") + .table_reference("orders") + .column(ColumnBuilder::new("o_orderkey", "int").build()) + .column(ColumnBuilder::new("o_custkey", "int").build()) + .column( + ColumnBuilder::new("customer", "customer") + .relationship("customer_orders") + .build(), + ) + .column( + ColumnBuilder::new("customer_name", "string") + .calculated(true) + .expression("customer.c_name") + .build(), + ) + .primary_key("o_orderkey") + .add_row_level_access_control( + "customer name", + vec![SessionProperty::new_required("session_user")], + "customer_name = @session_user", + ) + .build(), + ) + .relationship( + RelationshipBuilder::new("customer_orders") + .model("customer") + .model("orders") + .join_type(JoinType::OneToMany) + .condition("customer.c_custkey = orders.o_custkey") + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let headers = + build_headers(&[("session_user".to_string(), Some("'Gura'".to_string()))]); + let sql = "SELECT * FROM orders"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers.clone(), sql).await?, + @"SELECT orders.o_orderkey, orders.o_custkey, 
orders.customer_name FROM (SELECT __relation__1.c_name AS customer_name, __relation__1.o_custkey, __relation__1.o_orderkey FROM (SELECT customer.c_custkey, customer.c_name, orders.o_custkey, orders.o_orderkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer RIGHT JOIN (SELECT __source.o_custkey AS o_custkey, __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders ON customer.c_custkey = orders.o_custkey) AS __relation__1) AS orders WHERE orders.customer_name = 'Gura'" + ); + + let sql = "SELECT * FROM orders where o_orderkey > 10"; + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT orders.o_orderkey, orders.o_custkey, orders.customer_name FROM (SELECT __relation__1.c_name AS customer_name, __relation__1.o_custkey, __relation__1.o_orderkey FROM (SELECT customer.c_custkey, customer.c_name, orders.o_custkey, orders.o_orderkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer RIGHT JOIN (SELECT __source.o_custkey AS o_custkey, __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders ON customer.c_custkey = orders.o_custkey) AS __relation__1) AS orders WHERE orders.o_orderkey > 10 AND orders.customer_name = 'Gura'" + ); + + // TODO: the rlac rule should be applied for the model used by the calculated field + // both to_one or to_many relationship should be supported + // + // let manifest = ManifestBuilder::new() + // .catalog("wren") + // .schema("test") + // .model( + // ModelBuilder::new("customer") + // .table_reference("customer") + // .column(ColumnBuilder::new("c_custkey", "int").build()) + // .column(ColumnBuilder::new("c_nationkey", "int").build()) + // .column(ColumnBuilder::new("c_name", "string").build()) + // .primary_key("c_custkey") + // .add_row_level_access_control( + // "nation rule", + // vec![SessionProperty::new_optional("session_nation", 
None)], + // "c_nationkey = @session_nation", + // ) + // .build(), + // ) + // .model( + // ModelBuilder::new("orders") + // .table_reference("orders") + // .column(ColumnBuilder::new("o_orderkey", "int").build()) + // .column(ColumnBuilder::new("o_custkey", "int").build()) + // .column( + // ColumnBuilder::new("customer", "customer") + // .relationship("customer_orders") + // .build(), + // ) + // .column( + // ColumnBuilder::new("customer_name", "string") + // .calculated(true) + // .expression("customer.c_name") + // .build(), + // ) + // .primary_key("o_orderkey") + // .build(), + // ) + // .relationship( + // RelationshipBuilder::new("customer_orders") + // .model("customer") + // .model("orders") + // .join_type(JoinType::OneToMany) + // .condition("customer.c_custkey = orders.o_custkey") + // .build(), + // ) + // .build(); + // let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + // let headers = + // build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + // let sql = "SELECT customer_name FROM orders"; + // assert_snapshot!( + // transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + // @"SELECT orders.customer_name FROM (SELECT __relation__1.c_name AS customer_name FROM (SELECT customer.c_custkey, customer.c_name, orders.o_custkey, orders.o_orderkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer RIGHT JOIN (SELECT __source.o_custkey AS o_custkey, __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders ON customer.c_custkey = orders.o_custkey) AS __relation__1) AS orders" + // ); + Ok(()) + } + /// Return a RecordBatch with made up data about customer fn customer() -> RecordBatch { let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3])); From 5378278cd820ce43e8feebdf233b6f197f0d383c Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 23 Apr 2025 12:20:40 +0800 Subject: [PATCH 10/30] fmt --- 
wren-core/core/src/logical_plan/analyze/model_generation.rs | 1 - wren-core/core/src/mdl/mod.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/wren-core/core/src/logical_plan/analyze/model_generation.rs b/wren-core/core/src/logical_plan/analyze/model_generation.rs index d6f92219a..dbbd6d4ed 100644 --- a/wren-core/core/src/logical_plan/analyze/model_generation.rs +++ b/wren-core/core/src/logical_plan/analyze/model_generation.rs @@ -107,7 +107,6 @@ impl ModelGenerationRule { }) .flatten(); - if !model_plan.required_exprs.is_empty() { builder = builder.project(projections)? } diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 1514eaac6..de6dde868 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -2048,7 +2048,7 @@ mod test { // TODO: the rlac rule should be applied for the model used by the calculated field // both to_one or to_many relationship should be supported - // + // // let manifest = ManifestBuilder::new() // .catalog("wren") // .schema("test") From 5978cac99d08e40a4e44e4472079467e96e53cba Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 24 Apr 2025 14:54:20 +0800 Subject: [PATCH 11/30] expose to python binding --- wren-core-py/Cargo.lock | 57 +++++++++++++++++++++++- wren-core-py/src/context.rs | 22 ++++++++- wren-core-py/src/manifest.rs | 2 + wren-core-py/tests/test_modeling_core.py | 25 +++++++++++ 4 files changed, 102 insertions(+), 4 deletions(-) diff --git a/wren-core-py/Cargo.lock b/wren-core-py/Cargo.lock index 648712b41..46e6872bc 100644 --- a/wren-core-py/Cargo.lock +++ b/wren-core-py/Cargo.lock @@ -606,6 +606,18 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "windows-sys", +] + [[package]] name = "const-random" version 
= "0.1.18" @@ -1247,6 +1259,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "env_filter" version = "0.1.3" @@ -1734,12 +1752,14 @@ checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "insta" -version = "1.43.1" +version = "1.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "154934ea70c58054b556dd430b99a98c2a7ff5309ac9891597e339b5c28f4371" +checksum = "50259abbaa67d11d2bcafc7ba1d094ed7a0c70e3ce893f0d0997f73558cb3084" dependencies = [ "console", + "linked-hash-map", "once_cell", + "pin-project", "similar", ] @@ -1905,6 +1925,12 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "linux-raw-sys" version = "0.9.3" @@ -2262,6 +2288,26 @@ dependencies = [ "siphasher", ] +[[package]] +name = "pin-project" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pin-project-lite" version = 
"0.2.16" @@ -2707,6 +2753,12 @@ version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "siphasher" version = "1.0.1" @@ -3384,6 +3436,7 @@ dependencies = [ "datafusion", "env_logger", "insta", + "insta", "log", "parking_lot", "petgraph 0.7.1", diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index a91d39e37..62b638eeb 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -20,6 +20,7 @@ use crate::manifest::to_manifest; use crate::remote_functions::PyRemoteFunction; use log::debug; use pyo3::{pyclass, pymethods, PyErr, PyResult}; +use std::collections::HashMap; use std::hash::Hash; use std::ops::ControlFlow; use std::str::FromStr; @@ -123,8 +124,14 @@ impl PySessionContext { let analyzed_mdl = Arc::new(analyzed_mdl); + // the headers won't be used in the context. Provide an empty map. let ctx = runtime - .block_on(create_ctx_with_mdl(&ctx, Arc::clone(&analyzed_mdl), false)) + .block_on(create_ctx_with_mdl( + &ctx, + Arc::clone(&analyzed_mdl), + Arc::new(HashMap::new()), + false, + )) .map_err(CoreError::from)?; Ok(Self { @@ -135,12 +142,23 @@ impl PySessionContext { } /// Transform the given Wren SQL to the equivalent Planned SQL. 
- pub fn transform_sql(&self, sql: &str) -> PyResult { + #[pyo3(signature = (sql=None, properties=None))] + pub fn transform_sql( + &self, + sql: Option<&str>, + properties: Option>>, + ) -> PyResult { + let Some(sql) = sql else { + return Err(CoreError::new("SQL is required").into()); + }; self.runtime .block_on(mdl::transform_sql_with_ctx( &self.ctx, Arc::clone(&self.mdl), + // the ctx has been initialized when PySessionContext is created + // so we can pass the empty array here &[], + properties.unwrap_or_default(), sql, )) .map_err(|e| PyErr::from(CoreError::from(e))) diff --git a/wren-core-py/src/manifest.rs b/wren-core-py/src/manifest.rs index 9749b1303..b1fe2e975 100644 --- a/wren-core-py/src/manifest.rs +++ b/wren-core-py/src/manifest.rs @@ -43,6 +43,7 @@ mod tests { primary_key: None, cached: false, refresh_time: None, + row_level_access_controls: vec![], }), Arc::from(Model { name: "model_2".to_string(), @@ -53,6 +54,7 @@ mod tests { primary_key: None, cached: false, refresh_time: None, + row_level_access_controls: vec![], }), ], relationships: vec![], diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index cedec501e..c6f116367 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -25,6 +25,18 @@ {"name": "c_name", "type": "varchar"}, {"name": "orders", "type": "orders", "relationship": "orders_customer"}, ], + "rowLevelAccessControls": [ + { + "name": "customer_access", + "requiredProperties": [ + { + "name": "session_user", + "required": False, + } + ], + "condition": "c_name = @session_user", + }, + ], "primaryKey": "c_custkey", }, { @@ -273,3 +285,16 @@ def test_limit_pushdown(): session_context.pushdown_limit(sql, 10) == "SELECT * FROM my_catalog.my_schema.customer LIMIT 10 OFFSET 5" ) + + +def test_rlac(): + headers = { + "session_user": "'test_user'", + } + session_context = SessionContext(manifest_str, None) + sql = "SELECT * FROM 
my_catalog.my_schema.customer" + rewritten_sql = session_context.transform_sql(sql, headers) + assert ( + rewritten_sql + == "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM main.customer AS __source) AS customer) AS customer WHERE customer.c_name = 'test_user'" + ) From c86c1bf494a183a69527a01f8f3b10d669e4cad2 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 24 Apr 2025 15:31:39 +0800 Subject: [PATCH 12/30] ensure the name is case insensitive --- .../logical_plan/analyze/access_control.rs | 34 +++++++++++++++++-- wren-core/core/src/mdl/context.rs | 9 +++++ wren-core/core/src/mdl/mod.rs | 31 +++++++++++++++++ 3 files changed, 71 insertions(+), 3 deletions(-) diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs index df27fad3d..464131546 100644 --- a/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -93,11 +93,11 @@ pub fn build_filter_expression( visit_expressions_mut(&mut expr, |expr| { if let ast::Expr::Identifier(ast::Ident { value, .. 
}) = expr { if value.starts_with("@") { - let property_name = value.trim_start_matches("@").to_string(); + let property_name = value.trim_start_matches("@").to_string().to_lowercase(); let Some(property_value) = properties.get(&property_name).or_else(|| { required_properties .iter() - .filter(|r| r.name == property_name && !r.required) + .filter(|r| r.name.to_lowercase() == property_name && !r.required) .map(|r| &r.default_expr) .next() }) else { @@ -199,7 +199,7 @@ fn is_property_present( property_name: &str, ) -> bool { headers - .get(property_name) + .get(&property_name.to_lowercase()) .map(|v| v.as_ref().is_some_and(|value| !value.is_empty())) .unwrap_or(false) } @@ -512,4 +512,32 @@ mod test { let unparser = Unparser::default().with_pretty(true); unparser.expr_to_sql(expr).map(|sql| sql.to_string()) } + + #[test] + pub fn test_match_case_insensitive() -> Result<()> { + let ctx = SessionContext::new(); + let state = ctx.state_ref(); + let model = ModelBuilder::new("m1") + .column(ColumnBuilder::new("id", "int").build()) + .column(ColumnBuilder::new("name", "varchar").build()) + .build(); + + let headers: Arc>> = Arc::new(build_headers(&[ + ("session_id".to_string(), Some("1".to_string())), + ("session_name".to_string(), Some("'test'".to_string())), + ])); + + let rule = RowLevelAccessControl { + condition: "id = @session_id AND name = @SESSION_NAME".to_string(), + required_properties: vec![ + SessionProperty::new_required("SESSION_ID"), + SessionProperty::new_required("session_name"), + ], + name: "test".to_string(), + }; + + let expr = build_filter_expression(&state, Arc::clone(&model), &headers, &rule)?; + assert_snapshot!(expr_to_sql(&expr)?, @"m1.id = 1 AND m1.\"name\" = 'test'"); + Ok(()) + } } diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index 5ab4f6027..22d465541 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -68,6 +68,15 @@ pub async fn create_ctx_with_mdl( 
reset_default_catalog_schema.clone().read().deref().clone(), ); + // ensure all the key in properties is lowercase + let properties = Arc::new(properties + .iter() + .map(|(k, v)| { + let k = k.to_lowercase(); + (k, v.clone()) + }) + .collect::>()); + let new_state = if is_local_runtime { new_state.with_analyzer_rules(analyze_rule_for_local_runtime( Arc::clone(&analyzed_mdl), diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index de6dde868..d6bc0f4e8 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -2105,6 +2105,37 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_rlac_case_insensitive() -> Result<()> { + let ctx = SessionContext::new(); + + // test required property + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .add_row_level_access_control( + "nation", + vec![SessionProperty::new_required("session_nation")], + "c_nationkey = @session_nation", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let sql = "SELECT * FROM customer"; + let headers = + build_headers(&[("SESSION_NATION".to_string(), Some("1".to_string()))]); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1" + ); + Ok(()) + } /// Return a RecordBatch with made up data about customer fn customer() -> RecordBatch { let custkey: ArrayRef = Arc::new(Int64Array::from(vec![1, 2, 3])); From 13c74abbc426b1e8c0dfb2b823c4843c3d35837d Mon Sep 17 00:00:00 2001 From: Jax Liu 
Date: Thu, 24 Apr 2025 16:55:38 +0800 Subject: [PATCH 13/30] expose to the wren engine api --- ibis-server/app/mdl/rewriter.py | 30 ++++++++++-- ibis-server/app/routers/v3/connector.py | 38 +++++++++++---- ibis-server/app/util.py | 8 ++++ .../v3/connector/postgres/test_fallback_v2.py | 13 ++++++ .../v3/connector/postgres/test_query.py | 46 +++++++++++++++++++ 5 files changed, 123 insertions(+), 12 deletions(-) diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index e0182013b..468c72929 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -32,10 +32,12 @@ def __init__( data_source: DataSource = None, java_engine_connector: JavaEngineConnector = None, experiment=False, + properties: dict | None = None, ): self.manifest_str = manifest_str self.data_source = data_source self.experiment = experiment + self.properties = properties if experiment: function_path = get_config().get_remote_function_list_path(data_source) self._rewriter = EmbeddedEngineRewriter(function_path) @@ -54,7 +56,7 @@ async def rewrite(self, sql: str) -> str: self._extract_manifest(self.manifest_str, sql) or self.manifest_str ) logger.debug("Extracted manifest: {}", manifest_str) - planned_sql = await self._rewriter.rewrite(manifest_str, sql) + planned_sql = await self._rewriter.rewrite(manifest_str, sql, self.properties) logger.debug("Planned SQL: {}", planned_sql) dialect_sql = self._transpile(planned_sql) if self.data_source else planned_sql logger.debug("Dialect SQL: {}", dialect_sql) @@ -93,7 +95,9 @@ def __init__(self, java_engine_connector: JavaEngineConnector): self.java_engine_connector = java_engine_connector @tracer.start_as_current_span("external_rewrite", kind=trace.SpanKind.CLIENT) - async def rewrite(self, manifest_str: str, sql: str) -> str: + async def rewrite( + self, manifest_str: str, sql: str, properties: dict | None = None + ) -> str: try: return await self.java_engine_connector.dry_plan(manifest_str, sql) except 
httpx.ConnectError as e: @@ -113,13 +117,31 @@ def __init__(self, function_path: str): self.function_path = function_path @tracer.start_as_current_span("embedded_rewrite", kind=trace.SpanKind.INTERNAL) - async def rewrite(self, manifest_str: str, sql: str) -> str: + async def rewrite( + self, manifest_str: str, sql: str, properties: dict | None = None + ) -> str: try: session_context = get_session_context(manifest_str, self.function_path) - return await to_thread.run_sync(session_context.transform_sql, sql) + return await to_thread.run_sync( + session_context.transform_sql, + sql, + self.get_session_properties(properties), + ) except Exception as e: raise RewriteError(str(e)) + def get_session_properties(self, properties: dict) -> dict | None: + if properties is None: + return None + # filter the properties which name starts with "x-wren-variables-" + # and remove the prefix "x-wren-variables-" + + return { + k.replace("x-wren-variables-", ""): v + for k, v in properties.items() + if k.startswith("x-wren-variables-") + } + @staticmethod def handle_extract_exception(e: Exception): raise RewriteError(str(e)) diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 5562c4326..81416fe0c 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -27,6 +27,7 @@ from app.util import ( append_fallback_context, build_context, + exist_wren_variables_header, pushdown_limit, safe_strtobool, to_json, @@ -72,7 +73,10 @@ async def query( if dry_run: sql = pushdown_limit(dto.sql, limit) rewritten_sql = await Rewriter( - dto.manifest_str, data_source=data_source, experiment=True + dto.manifest_str, + data_source=data_source, + experiment=True, + properties=dict(headers), ).rewrite(sql) connector = Connector(data_source, dto.connection_info) connector.dry_run(rewritten_sql) @@ -145,7 +149,9 @@ async def query( headers.get(X_WREN_FALLBACK_DISABLE) and 
safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) ) - if is_fallback_disable: + # because the v2 API doesn't support row-level access control, + # we don't fallback to v2 if the header include row-level access control properties. + if is_fallback_disable or exist_wren_variables_header(headers): raise e logger.warning( @@ -176,13 +182,17 @@ async def dry_plan( name="dry_plan", kind=trace.SpanKind.SERVER, context=build_context(headers) ) as span: try: - return await Rewriter(dto.manifest_str, experiment=True).rewrite(dto.sql) + return await Rewriter( + dto.manifest_str, experiment=True, properties=dict(headers) + ).rewrite(dto.sql) except Exception as e: is_fallback_disable = bool( headers.get(X_WREN_FALLBACK_DISABLE) and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) ) - if is_fallback_disable: + # because the v2 API doesn't support row-level access control, + # we don't fallback to v2 if the header include row-level access control properties. + if is_fallback_disable or exist_wren_variables_header(headers): raise e logger.warning( @@ -213,14 +223,19 @@ async def dry_plan_for_data_source( ) as span: try: return await Rewriter( - dto.manifest_str, data_source=data_source, experiment=True + dto.manifest_str, + data_source=data_source, + experiment=True, + properties=dict(headers), ).rewrite(dto.sql) except Exception as e: is_fallback_disable = bool( headers.get(X_WREN_FALLBACK_DISABLE) and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) ) - if is_fallback_disable: + # because the v2 API doesn't support row-level access control, + # we don't fallback to v2 if the header include row-level access control properties. 
+ if is_fallback_disable or exist_wren_variables_header(headers): raise e logger.warning( @@ -254,7 +269,12 @@ async def validate( try: validator = Validator( Connector(data_source, dto.connection_info), - Rewriter(dto.manifest_str, data_source=data_source, experiment=True), + Rewriter( + dto.manifest_str, + data_source=data_source, + experiment=True, + properties=dict(headers), + ), ) await validator.validate(rule_name, dto.parameters, dto.manifest_str) return Response(status_code=204) @@ -263,7 +283,9 @@ async def validate( headers.get(X_WREN_FALLBACK_DISABLE) and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) ) - if is_fallback_disable: + # because the v2 API doesn't support row-level access control, + # we don't fallback to v2 if the header include row-level access control properties. + if is_fallback_disable or exist_wren_variables_header(headers): raise e logger.warning( diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index b5e2ccb55..c70d224e7 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -147,3 +147,11 @@ def get_fallback_message( def safe_strtobool(val: str) -> bool: return val.lower() in {"1", "true", "yes", "y"} + + +def exist_wren_variables_header( + headers: Header, +) -> bool: + if headers is None: + return False + return any(key.startswith("x-wren-variables-") for key in headers.keys()) diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py b/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py index 0f6b89c9b..a37ba363a 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py @@ -383,3 +383,16 @@ async def test_validate(client, manifest_str, connection_info): }, ) assert response.status_code == 422 + + +async def test_query_rlac(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": 
connection_info, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={"x-wren-variables-session_user": "1"}, + ) + assert response.status_code == 422 diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_query.py b/ibis-server/tests/routers/v3/connector/postgres/test_query.py index d65dacafa..3864fc291 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_query.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_query.py @@ -78,6 +78,18 @@ "expression": "sum(orders.o_totalprice_double)", }, ], + "rowLevelAccessControls": [ + { + "name": "customer_access", + "requiredProperties": [ + { + "name": "session_user", + "required": False, + } + ], + "condition": "c_name = @session_user", + }, + ], "primaryKey": "c_custkey", }, ], @@ -478,3 +490,37 @@ async def test_limit_pushdown(client, manifest_str, connection_info): assert response.status_code == 200 result = response.json() assert len(result["data"]) == 10 + + +async def test_rlac_query(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT c_name FROM customer", + }, + headers={ + "x-wren-variables-session_user": "'Customer#000000001'", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert result["data"][0][0] == "Customer#000000001" + + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT c_name FROM customer", + }, + headers={ + "X-WREN-VARIABLES-SESSION_USER": "'Customer#000000001'", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert result["data"][0][0] == "Customer#000000001" From 9685b2802cc948a60b2ca258c1f1352b2d75cac4 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 24 Apr 2025 
17:01:24 +0800 Subject: [PATCH 14/30] enhance example --- .../wren-example/data/company/documents.csv | 31 +++ .../wren-example/data/company/tenants.csv | 3 + wren-core/wren-example/data/company/users.csv | 10 + .../examples/row_level_access_control.rs | 74 ------- .../examples/row_level_access_control.rs.rs | 194 ++++++++++++++++++ 5 files changed, 238 insertions(+), 74 deletions(-) create mode 100644 wren-core/wren-example/data/company/documents.csv create mode 100644 wren-core/wren-example/data/company/tenants.csv create mode 100644 wren-core/wren-example/data/company/users.csv delete mode 100644 wren-core/wren-example/examples/row_level_access_control.rs create mode 100644 wren-core/wren-example/examples/row_level_access_control.rs.rs diff --git a/wren-core/wren-example/data/company/documents.csv b/wren-core/wren-example/data/company/documents.csv new file mode 100644 index 000000000..375642d95 --- /dev/null +++ b/wren-core/wren-example/data/company/documents.csv @@ -0,0 +1,31 @@ +id,tenant_id,department,created_by,title,content,status,created_at +d001,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,sales,1001-u1,Sales Q1,Report for Q1,DRAFT,2025-04-01 10:00:00 +d002,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,sales,1001-u1,Sales Q2,Final Q2 numbers,PUBLISHED,2025-04-10 12:00:00 +d003,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,hr,1002-u2,HR Policy,Draft HR policy,DRAFT,2025-04-05 09:30:00 +d004,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,engineering,1003-u3,Infra Notes,Infra scaling plan,PUBLISHED,2025-04-12 15:00:00 +d005,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,finance,1004-u4,Q1 Financials,Q1 Revenue,PUBLISHED,2025-04-03 11:20:00 +d006,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,product,1005-u5,Feature Spec,Spec for v2,REJECTED,2025-04-07 14:00:00 +d007,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,sales,2001-u6,Q1 Leads,Lead list,PUBLISHED,2025-04-08 14:20:00 +d008,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,engineering,2002-u7,Dev Plan,Backend roadmap,DRAFT,2025-04-09 09:00:00 
+d009,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,hr,2003-u8,Policy Memo,Hiring freeze,PUBLISHED,2025-04-11 16:00:00 +d010,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,marketing,2004-u9,Campaign A,Ad campaign A,REJECTED,2025-04-13 17:45:00 +d011,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,finance,2002-u7,Doc 11,Content of doc 11,PUBLISHED,2025-04-15 16:00:00 +d012,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,product,1001-u1,Doc 12,Content of doc 12,PUBLISHED,2025-04-09 14:00:00 +d013,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,engineering,1003-u3,Doc 13,Content of doc 13,PUBLISHED,2025-04-14 14:00:00 +d014,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,marketing,1004-u4,Doc 14,Content of doc 14,PUBLISHED,2025-04-12 10:00:00 +d015,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,engineering,1003-u3,Doc 15,Content of doc 15,DRAFT,2025-04-08 18:00:00 +d016,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,sales,2002-u7,Doc 16,Content of doc 16,PUBLISHED,2025-04-14 16:00:00 +d017,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,hr,2003-u8,Doc 17,Content of doc 17,PUBLISHED,2025-04-12 08:00:00 +d018,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,finance,2003-u8,Doc 18,Content of doc 18,REJECTED,2025-04-11 18:00:00 +d019,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,finance,1003-u3,Doc 19,Content of doc 19,PUBLISHED,2025-04-10 11:00:00 +d020,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,hr,2003-u8,Doc 20,Content of doc 20,REJECTED,2025-04-06 16:00:00 +d021,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,engineering,2002-u7,Doc 21,Content of doc 21,REJECTED,2025-04-02 12:00:00 +d022,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,hr,1001-u1,Doc 22,Content of doc 22,REJECTED,2025-04-08 17:00:00 +d023,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,product,1002-u2,Doc 23,Content of doc 23,DRAFT,2025-04-01 18:00:00 +d024,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,product,1001-u1,Doc 24,Content of doc 24,PUBLISHED,2025-04-11 09:00:00 +d025,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,finance,1004-u4,Doc 25,Content of doc 25,REJECTED,2025-04-14 13:00:00 +d026,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,engineering,2002-u7,Doc 26,Content of 
doc 26,REJECTED,2025-04-02 09:00:00 +d027,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,sales,2004-u9,Doc 27,Content of doc 27,REJECTED,2025-04-13 09:00:00 +d028,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,engineering,2002-u7,Doc 28,Content of doc 28,DRAFT,2025-04-07 10:00:00 +d029,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,hr,1002-u2,Doc 29,Content of doc 29,REJECTED,2025-04-04 14:00:00 +d030,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,product,1002-u2,Doc 30,Content of doc 30,REJECTED,2025-04-04 08:00:00 diff --git a/wren-core/wren-example/data/company/tenants.csv b/wren-core/wren-example/data/company/tenants.csv new file mode 100644 index 000000000..e6cd2d1f9 --- /dev/null +++ b/wren-core/wren-example/data/company/tenants.csv @@ -0,0 +1,3 @@ +id,name +1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,Acme Corp +2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,Globex Inc \ No newline at end of file diff --git a/wren-core/wren-example/data/company/users.csv b/wren-core/wren-example/data/company/users.csv new file mode 100644 index 000000000..99888bb0d --- /dev/null +++ b/wren-core/wren-example/data/company/users.csv @@ -0,0 +1,10 @@ +id,email,name,tenant_id,department,role +1001-u1,user1@acme.com,Alice,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,sales,MEMBER +1002-u2,user2@acme.com,Bob,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,hr,MEMBER +1003-u3,admin@acme.com,Carol,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,engineering,ADMIN +1004-u4,user3@acme.com,David,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,finance,MEMBER +1005-u5,user4@acme.com,Eve,1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa,product,MEMBER +2001-u6,user5@globex.com,Frank,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,sales,MEMBER +2002-u7,user6@globex.com,Grace,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,engineering,MEMBER +2003-u8,admin@globex.com,Hank,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,hr,ADMIN +2004-u9,user7@globex.com,Ivy,2bcdef02-bbbb-bbbb-bbbb-bbbbbbbbbbbb,marketing,MEMBER diff --git a/wren-core/wren-example/examples/row_level_access_control.rs 
b/wren-core/wren-example/examples/row_level_access_control.rs deleted file mode 100644 index 6123351a8..000000000 --- a/wren-core/wren-example/examples/row_level_access_control.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; - -use datafusion::prelude::{CsvReadOptions, SessionContext}; -use wren_core::mdl::builder::{ColumnBuilder, ManifestBuilder, ModelBuilder}; -use wren_core::mdl::manifest::{Manifest, SessionProperty}; -use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; - -#[tokio::main] -async fn main() -> datafusion::common::Result<()> { - let manifest = init_manifest(); - let ctx = SessionContext::new(); - - ctx.register_csv( - "customers", - "sqllogictest/tests/resources/ecommerce/customers.csv", - CsvReadOptions::new(), - ) - .await?; - let customers_provider = ctx - .catalog("datafusion") - .unwrap() - .schema("public") - .unwrap() - .table("customers") - .await? - .unwrap(); - let register = HashMap::from([( - "datafusion.public.customers".to_string(), - customers_provider, - )]); - let analyzed_mdl = - Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?); - let sql = "SELECT * FROM customers"; - - // carry the seesion property - let mut properties = HashMap::new(); - properties.insert("session_city".to_string(), Some("'Santa Ana'".to_string())); - - let sql = transform_sql_with_ctx(&ctx, analyzed_mdl, &[], properties, sql).await?; - println!("Wren engine generated SQL: \n{}", sql); - let df = match ctx.sql(&sql).await { - Ok(df) => df, - Err(e) => { - eprintln!("Error: {}", e); - return Err(e); - } - }; - match df.show().await { - Ok(_) => {} - Err(e) => eprintln!("Error: {}", e), - } - - Ok(()) -} - -fn init_manifest() -> Manifest { - ManifestBuilder::new() - .model( - ModelBuilder::new("customers") - .table_reference("datafusion.public.customers") - .column(ColumnBuilder::new("city", "varchar").build()) - .column(ColumnBuilder::new("id", "varchar").build()) - 
.column(ColumnBuilder::new("state", "varchar").build()) - .add_row_level_access_control( - "city rule", - vec![SessionProperty::new_required("session_city")], - "city = @session_city", - ) - .primary_key("id") - .build(), - ) - .build() -} diff --git a/wren-core/wren-example/examples/row_level_access_control.rs.rs b/wren-core/wren-example/examples/row_level_access_control.rs.rs new file mode 100644 index 000000000..ac68031f7 --- /dev/null +++ b/wren-core/wren-example/examples/row_level_access_control.rs.rs @@ -0,0 +1,194 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use datafusion::prelude::{CsvReadOptions, SessionContext}; +use wren_core::mdl::builder::{ + ColumnBuilder, ManifestBuilder, ModelBuilder, RelationshipBuilder, +}; +use wren_core::mdl::manifest::{JoinType, Manifest, SessionProperty}; +use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; + +/// It's an example to show how to use wren engine to set up row level access control +/// for a multi-tenant application. +#[tokio::main] +async fn main() -> datafusion::common::Result<()> { + let manifest = init_manifest(); + let ctx = SessionContext::new(); + + ctx.register_csv( + "documents", + "wren-example/data/company/documents.csv", + CsvReadOptions::new(), + ) + .await?; + let documents = ctx + .catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("documents") + .await? + .unwrap(); + + ctx.register_csv( + "tenants", + "wren-example/data/company/tenants.csv", + CsvReadOptions::new(), + ) + .await?; + let tenants = ctx + .catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("tenants") + .await? + .unwrap(); + + ctx.register_csv( + "users", + "wren-example/data/company/users.csv", + CsvReadOptions::new(), + ) + .await?; + let users = ctx + .catalog("datafusion") + .unwrap() + .schema("public") + .unwrap() + .table("users") + .await? 
+ .unwrap(); + + let register = HashMap::from([ + ("datafusion.public.documents".to_string(), documents), + ("datafusion.public.tenants".to_string(), tenants), + ("datafusion.public.users".to_string(), users), + ]); + + let json_str = serde_json::to_string(&manifest).unwrap(); + println!("Manifest JSON: \n{}", json_str); + + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze_with_tables(manifest, register)?); + // carry the session property + let mut properties = HashMap::new(); + properties.insert( + "session_tenant_id".to_string(), + Some("'1acdef01-aaaa-aaaa-aaaa-aaaaaaaaaaaa'".to_string()), + ); + properties.insert( + "session_department".to_string(), + Some("'engineering'".to_string()), + ); + properties.insert("session_user_id".to_string(), Some("'1003-u3'".to_string())); + properties.insert("session_role".to_string(), Some("'ADMIN'".to_string())); + + println!("#####################"); + println!( + "session_tenant_id: {}", + &properties + .get("session_tenant_id") + .unwrap() + .clone() + .unwrap() + ); + println!( + "session_department: {}", + &properties + .get("session_department") + .unwrap() + .clone() + .unwrap() + ); + println!( + "session_user_id: {}", + &properties.get("session_user_id").unwrap().clone().unwrap() + ); + println!( + "session_role: {}", + &properties.get("session_role").unwrap().clone().unwrap() + ); + + let sql = "select * from wren.test.documents"; + let sql = transform_sql_with_ctx(&ctx, analyzed_mdl, &[], properties, sql).await?; + let df = match ctx.sql(&sql).await { + Ok(df) => df, + Err(e) => { + eprintln!("Error: {}", e); + return Err(e); + } + }; + match df.show().await { + Ok(_) => {} + Err(e) => eprintln!("Error: {}", e), + } + + println!("#####################"); + println!("Wren engine generated SQL: \n{}", sql); + + Ok(()) +} + +fn init_manifest() -> Manifest { + ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model(ModelBuilder::new("tenants") + .table_reference("datafusion.public.tenants") +
.column(ColumnBuilder::new("id", "string").build()) + .column(ColumnBuilder::new("name", "string").build()) + .primary_key("id") + .build()) + .model(ModelBuilder::new("users") + .table_reference("datafusion.public.users") + .column(ColumnBuilder::new("id", "string").build()) + .column(ColumnBuilder::new("email", "string").build()) + .column(ColumnBuilder::new("tenant_id", "string").build()) + .column(ColumnBuilder::new("name", "string").build()) + .column(ColumnBuilder::new("role", "string").build()) + .column(ColumnBuilder::new("department", "string").build()) + .column(ColumnBuilder::new("tenants", "tenants") + .relationship("tenants_users") + .build()) + .primary_key("id") + .build()) + .model(ModelBuilder::new("documents") + .table_reference("datafusion.public.documents") + .column(ColumnBuilder::new("id", "string").build()) + .column(ColumnBuilder::new("tenant_id", "string").build()) + .column(ColumnBuilder::new("department", "string").build()) + .column(ColumnBuilder::new("created_by", "string").build()) + .column(ColumnBuilder::new("title", "string").build()) + .column(ColumnBuilder::new("content", "string").build()) + .column(ColumnBuilder::new("status", "string").build()) + .column(ColumnBuilder::new("created_at", "timestamp").build()) + // Row-level access control that restricts which documents a user can see: + // 1. A user can only see the documents in their own tenant + .add_row_level_access_control("multitenant", vec![SessionProperty::new_required("session_tenant_id")], "tenant_id = @session_tenant_id") + // Row-level access control that restricts which documents a user can see: + // 1. A member can only see documents they created, or documents with status 'PUBLIC' in their department (NOTE(review): documents.csv uses 'PUBLISHED', not 'PUBLIC' - confirm the intended status) + // 2.
Admin can see all the documents + .add_row_level_access_control("auth", vec![ + SessionProperty::new_optional("session_role", Some("MEMBER".to_string())), + SessionProperty::new_required("session_department"), + SessionProperty::new_required("session_user_id")], + "@session_role = 'ADMIN' OR (department = @session_department AND (created_by = @session_user_id OR status = 'PUBLIC'))") + .build()) + .relationship(RelationshipBuilder::new("tenants_users").model("tenants") + .model("users") + .join_type(JoinType::OneToMany) + .condition("tenants.id = users.tenant_id") + .build()) + .relationship(RelationshipBuilder::new("users_documents").model("users") + .model("documents") + .join_type(JoinType::OneToMany) + .condition("users.id = documents.created_by") + .build()) + .relationship(RelationshipBuilder::new("tenants_documents").model("tenants") + .model("documents") + .join_type(JoinType::OneToMany) + .condition("tenants.id = documents.tenant_id") + .build()) + .build() +} From 4265c57869f266a890abb31b36bdb1b6de33fe3a Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 24 Apr 2025 17:04:21 +0800 Subject: [PATCH 15/30] remove unused file --- ibis-server/tools/query_local_run-v2.py | 90 ------------------------- 1 file changed, 90 deletions(-) delete mode 100644 ibis-server/tools/query_local_run-v2.py diff --git a/ibis-server/tools/query_local_run-v2.py b/ibis-server/tools/query_local_run-v2.py deleted file mode 100644 index 31a8732b8..000000000 --- a/ibis-server/tools/query_local_run-v2.py +++ /dev/null @@ -1,90 +0,0 @@ -# -# The script below is a standalone script that can be used to run a SQL query locally. 
-# -# Argements: -# - sql: stdin input a SQL query -# -# Environment variables: -# - WREN_MANIFEST_JSON_PATH: path to the manifest JSON file -# - REMOTE_FUNCTION_LIST_PATH: path to the function list file -# - CONNECTION_INFO_PATH: path to the connection info file -# - DATA_SOURCE: data source name -# - -import base64 -import json -import os -import sqlglot -import sys - -from dotenv import load_dotenv -from wren_core import SessionContext -from app.mdl.java_engine import JavaEngineConnector -from app.model.data_source import BigQueryConnectionInfo, DataSource -from app.model.data_source import DataSourceExtension -from app.mdl.rewriter import Rewriter - -if sys.stdin.isatty(): - print("please provide the SQL query via stdin, e.g. `python query_local_run.py < test.sql`", file=sys.stderr) - sys.exit(1) - -sql = sys.stdin.read() - - -load_dotenv() -manifest_json_path = os.getenv("WREN_MANIFEST_JSON_PATH") -function_list_path = os.getenv("REMOTE_FUNCTION_LIST_PATH") -connection_info_path = os.getenv("CONNECTION_INFO_PATH") -data_source = os.getenv("DATA_SOURCE") - -# Welcome message -print("### Welcome to the Wren Core Query Runner ###") -print("#") -print("# Manifest JSON Path:", manifest_json_path) - -async def main(): - print("# Function List Path:", function_list_path) - print("# Connection Info Path:", connection_info_path) - print("# Data Source:", data_source) - print("# SQL Query:\n", sql) - print("#") - - # Read and encode the JSON data - with open(manifest_json_path) as file: - mdl = json.load(file) - # Convert to JSON string - json_str = json.dumps(mdl) - # Encode to base64 - encoded_str = base64.b64encode(json_str.encode("utf-8")).decode("utf-8") - - with open(connection_info_path) as file: - connection_info = json.load(file) - - print("### Starting the session context ###") - print("#") - rewriter = Rewriter(encoded_str, - data_source=DataSource[data_source], - java_engine_connector=JavaEngineConnector(os.getenv("WREN_ENGINE_ENDPOINT"))) - # 
session_context = SessionContext(encoded_str, function_list_path) - # planned_sql = session_context.transform_sql(sql) - planned_sql = await rewriter.rewrite(sql) - print("# Planned SQL:\n", planned_sql) - - # Transpile the planned SQL - # dialect_sql = sqlglot.transpile(planned_sql, read="trino", write=data_source)[0] - # print("# Dialect SQL:\n", dialect_sql) - print("#") - - if data_source == "bigquery": - connection_info = BigQueryConnectionInfo.model_validate_json(json.dumps(connection_info)) - connection = DataSourceExtension.get_bigquery_connection(connection_info) - df = connection.sql(planned_sql).limit(10).to_pandas() - print("### Result ###") - print("") - print(df) - else: - print("Unsupported data source:", data_source) - -if __name__ == "__main__": - import asyncio - asyncio.run(main()) \ No newline at end of file From 8a4035d3d6830b0113e26d1325f69bc534db0bc1 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 24 Apr 2025 17:07:23 +0800 Subject: [PATCH 16/30] remove unused file --- wren-core/wren-example/examples/demo_site.rs | 128 ------------- .../wren-example/examples/plan-sql-json.rs | 169 ------------------ 2 files changed, 297 deletions(-) delete mode 100644 wren-core/wren-example/examples/demo_site.rs delete mode 100644 wren-core/wren-example/examples/plan-sql-json.rs diff --git a/wren-core/wren-example/examples/demo_site.rs b/wren-core/wren-example/examples/demo_site.rs deleted file mode 100644 index 9ea30d829..000000000 --- a/wren-core/wren-example/examples/demo_site.rs +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use datafusion::common::Result; -use datafusion::config::ConfigOptions; -use datafusion::execution::{FunctionRegistry, SessionStateBuilder}; -use datafusion::logical_expr::ScalarUDF; -use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion::sql::sqlparser::ast::Visit; -use datafusion::sql::sqlparser::dialect::GenericDialect; -use std::ops::ControlFlow; -use std::sync::Arc; -use std::{fs, io}; -use wren_core::logical_plan::utils::try_map_data_type; -use wren_core::mdl::function::ByPassScalarUDF; -use wren_core::mdl::manifest::Manifest; -use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; - -#[tokio::main] -async fn main() -> Result<()> { - env_logger::init(); - - let mdl_json = "/Users/jax/git/wren-engine/ibis-server/etc.local/local_mdl.json"; - let json_string = fs::read_to_string(mdl_json).unwrap(); - let manifest: Manifest = serde_json::from_str(&json_string).unwrap(); - let mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - // - // let sql = "SELECT orders_key FROM (select * from orders limit 1000) as t"; - // let mut statements = wren_core::parser::Parser::parse_sql(&GenericDialect {}, sql)?; - // let pushdown_limit = 100; - // - // visit_statements_mut(&mut statements, |stmt| { - // if let Statement::Query(q) = stmt { - // if let Some(limit) = &q.limit { - // if let Expr::Value(Value::Number(n, is)) = limit { - // if n.parse::().unwrap() > pushdown_limit { - // q.limit = Some(Expr::Value(Value::Number( - // pushdown_limit.to_string(), - // is.clone(), - // ))); - // } - // } - // } else { - // q.limit = 
Some(Expr::Value(Value::Number( - // pushdown_limit.to_string(), - // false, - // ))); - // } - // } - // ControlFlow::<()>::Continue(()) - // }); - // print!("{}", statements[0]); - // let ctx = SessionContext::new(); - // let unparsed = match transform_sql_with_ctx(&ctx, mdl, &[], sql).await { - // Ok(sql) => println!("{}", sql), - // Err(e) => { - // eprintln!("Error: {}", e); - // return Ok(()); - // } - // }; - // let mut config = ConfigOptions::new(); - // config.execution.time_zone = Some("+03:00".to_string()); - // let session_config = SessionConfig::from(config); - // let state = SessionStateBuilder::new() - // .with_default_features() - // .with_config(session_config) - // .build(); - // let ctx = SessionContext::from(state); - // ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( - // "date_diff", - // map_data_type("bigint")?, - // ))); - // ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( - // "year", - // map_data_type("bigint")?, - // ))); - // ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( - // "month", - // map_data_type("bigint")?, - // ))); - // ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( - // "day", - // map_data_type("bigint")?, - // ))); - // ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( - // "age", - // map_data_type("interval")?, - // ))); - // - // // let sqls = fs::read_to_string("wren-example/data/demo_site.sql").unwrap(); - // let s = r#" - // select timestamp with time zone '2011-01-01' - // "#; - // // for (i, sql) in sqls.lines().enumerate() { - // let sqls = vec![s]; - // for (i, sql) in sqls.into_iter().enumerate() { - // if sql.starts_with("--") || sql.is_empty() { - // continue; - // } - // match transform_sql_with_ctx(&ctx, Arc::clone(&mdl), &[], sql).await { - // Ok(sql) => { - // println!("{}", sql); - // } - // Err(e) => { - // println!("{}: {}", i + 1, sql); - // eprintln!("Error: {}", e); - // return Ok(()); - // } - // }; - // } - Ok(()) -} 
diff --git a/wren-core/wren-example/examples/plan-sql-json.rs b/wren-core/wren-example/examples/plan-sql-json.rs deleted file mode 100644 index 92adc0f0d..000000000 --- a/wren-core/wren-example/examples/plan-sql-json.rs +++ /dev/null @@ -1,169 +0,0 @@ -use datafusion::common::Result; -use datafusion::execution::SessionStateBuilder; -use datafusion::functions::string::lower; -use datafusion::functions_aggregate::array_agg::array_agg_udaf; -use datafusion::prelude::SessionConfig; -use datafusion::prelude::SessionContext; -use serde::{Deserialize, Serialize}; -use std::str::FromStr; -use std::sync::Arc; -use wren_core::array::AsArray; -use wren_core::array::GenericByteArray; -use wren_core::array::GenericListArray; -use wren_core::datatypes::DataType; -use wren_core::datatypes::GenericStringType; -use wren_core::mdl::function::ByPassScalarUDF; -use wren_core::mdl::function::FunctionType; -use wren_core::mdl::function::RemoteFunction; -use wren_core::ScalarUDF; - -#[tokio::main] -async fn main() -> Result<()> { - env_logger::init(); - let sql = r#" - WITH inputs AS ( - SELECT - r.specific_name, - r.data_type as return_type, - pi.rid, - array_agg(pi.parameter_name order by pi.ordinal_position) as param_names, - array_agg(pi.data_type order by pi.ordinal_position) as param_types - FROM - information_schema.routines r - JOIN - information_schema.parameters pi ON r.specific_name = pi.specific_name AND pi.parameter_mode = 'IN' - GROUP BY 1, 2, 3 - ) - SELECT - r.routine_name as name, - i.param_names, - i.param_types, - r.data_type as return_type, - r.function_type, - r.description - FROM - information_schema.routines r - LEFT JOIN - inputs i ON r.specific_name = i.specific_name - "#; - let config = SessionConfig::new().with_information_schema(true); - let state: datafusion::execution::SessionState = SessionStateBuilder::new() - .with_default_features() - .with_config(config) - .build(); - let ctx = SessionContext::new_with_state(state); - // 
ctx.register_udaf(Arc::unwrap_or_clone(array_agg_udaf())); - // ctx.register_udf(ScalarUDF::new_from_impl(lower::LowerFunc::new())); - ctx.register_udf(ScalarUDF::new_from_impl(ByPassScalarUDF::new( - "add_two", - DataType::Int64, - ))); - let batches = ctx.sql(sql).await?.collect().await?; - let mut functions = vec![]; - - for batch in batches { - let name_array = batch.column(0).as_string::(); - let param_names_array = batch.column(1).as_list::(); - let param_types_array = batch.column(2).as_list::(); - let return_type_array = batch.column(3).as_string::(); - let function_type_array = batch.column(4).as_string::(); - let description_array = batch.column(5).as_string::(); - - for row in 0..batch.num_rows() { - let name = name_array.value(row).to_string(); - let _param_names = - to_string_vec(param_names_array.value(row).as_string::()); - let _param_types = - to_string_vec(param_types_array.value(row).as_string::()); - let return_type = return_type_array.value(row).to_string(); - let description = description_array.value(row).to_string(); - let function_type = function_type_array.value(row).to_string(); - - functions.push(RemoteFunction { - name, - param_names: None, - param_types: None, - return_type, - description: Some(description), - function_type: FunctionType::from_str(&function_type).unwrap(), - }); - } - } - functions - .iter() - .filter(|f| f.name == "add_two") - .for_each(|f| { - println!("{:?}", f); - }); - Ok(()) -} - -fn to_string_vec( - array: &GenericByteArray>, -) -> Vec> { - array - .iter() - .map(|s| s.map(|s| s.to_string())) - .collect::>>() -} - -fn read_remote_function_list(path: &str) -> Vec { - csv::Reader::from_path(path) - .unwrap() - .into_deserialize::() - .filter_map(Result::ok) - .map(|f| RemoteFunction::from(f)) - .collect::>() -} - -#[derive(Serialize, Deserialize, Clone)] -pub struct PyRemoteFunction { - pub function_type: String, - pub name: String, - pub return_type: Option, - /// It's a comma separated string of parameter names - 
pub param_names: Option, - /// It's a comma separated string of parameter types - pub param_types: Option, - pub description: Option, -} - -impl From for wren_core::mdl::function::RemoteFunction { - fn from( - remote_function: PyRemoteFunction, - ) -> wren_core::mdl::function::RemoteFunction { - let param_names = remote_function.param_names.map(|names| { - names - .split(",") - .map(|name| { - if name.is_empty() { - None - } else { - Some(name.to_string()) - } - }) - .collect::>>() - }); - let param_types = remote_function.param_types.map(|types| { - types - .split(",") - .map(|t| { - if t.is_empty() { - None - } else { - Some(t.to_string()) - } - }) - .collect::>>() - }); - wren_core::mdl::function::RemoteFunction { - function_type: FunctionType::from_str(&remote_function.function_type) - .unwrap(), - name: remote_function.name, - return_type: remote_function.return_type.unwrap_or("string".to_string()), - param_names, - param_types, - description: remote_function.description, - } - } -} From a1fb5e9dda2c2b4992566bfea67f18547175eb4e Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 24 Apr 2025 17:33:08 +0800 Subject: [PATCH 17/30] rename example --- ..._level_access_control.rs.rs => row-level-access-control.rs.rs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename wren-core/wren-example/examples/{row_level_access_control.rs.rs => row-level-access-control.rs.rs} (100%) diff --git a/wren-core/wren-example/examples/row_level_access_control.rs.rs b/wren-core/wren-example/examples/row-level-access-control.rs.rs similarity index 100% rename from wren-core/wren-example/examples/row_level_access_control.rs.rs rename to wren-core/wren-example/examples/row-level-access-control.rs.rs From a9747e2f8856f007c4cfa0b04079acd235d80cf5 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 24 Apr 2025 17:41:53 +0800 Subject: [PATCH 18/30] fix fmt --- wren-core-base/manifest-macro/src/lib.rs | 4 +++- .../src/logical_plan/analyze/access_control.rs | 3 ++- 
wren-core/core/src/mdl/context.rs | 16 +++++++++------- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index 24be3c3b1..93606eaae 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -358,7 +358,9 @@ pub fn view(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream } #[proc_macro] -pub fn row_level_access_control(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { +pub fn row_level_access_control( + python_binding: proc_macro::TokenStream, +) -> proc_macro::TokenStream { let input = parse_macro_input!(python_binding as LitBool); let python_binding = if input.value { quote! { diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs index 464131546..c798a7915 100644 --- a/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -93,7 +93,8 @@ pub fn build_filter_expression( visit_expressions_mut(&mut expr, |expr| { if let ast::Expr::Identifier(ast::Ident { value, .. 
}) = expr { if value.starts_with("@") { - let property_name = value.trim_start_matches("@").to_string().to_lowercase(); + let property_name = + value.trim_start_matches("@").to_string().to_lowercase(); let Some(property_value) = properties.get(&property_name).or_else(|| { required_properties .iter() diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index 22d465541..9ea841e25 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -69,13 +69,15 @@ pub async fn create_ctx_with_mdl( ); // ensure all the key in properties is lowercase - let properties = Arc::new(properties - .iter() - .map(|(k, v)| { - let k = k.to_lowercase(); - (k, v.clone()) - }) - .collect::>()); + let properties = Arc::new( + properties + .iter() + .map(|(k, v)| { + let k = k.to_lowercase(); + (k, v.clone()) + }) + .collect::>(), + ); let new_state = if is_local_runtime { new_state.with_analyzer_rules(analyze_rule_for_local_runtime( From fae8664f1ae3dda5b269602aa76d101b40b04ec1 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 24 Apr 2025 17:52:14 +0800 Subject: [PATCH 19/30] fix file name --- ...row-level-access-control.rs.rs => row-level-access-control.rs} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename wren-core/wren-example/examples/{row-level-access-control.rs.rs => row-level-access-control.rs} (100%) diff --git a/wren-core/wren-example/examples/row-level-access-control.rs.rs b/wren-core/wren-example/examples/row-level-access-control.rs similarity index 100% rename from wren-core/wren-example/examples/row-level-access-control.rs.rs rename to wren-core/wren-example/examples/row-level-access-control.rs From 8e9898a84e9bb332c42fa1217bb3f2152613251f Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 25 Apr 2025 12:02:27 +0800 Subject: [PATCH 20/30] move insta to dev --- wren-core/core/Cargo.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/wren-core/core/Cargo.toml b/wren-core/core/Cargo.toml 
index 86fe341ca..cda8bf9aa 100644 --- a/wren-core/core/Cargo.toml +++ b/wren-core/core/Cargo.toml @@ -24,7 +24,6 @@ datafusion = { workspace = true, features = [ "unicode_expressions", ] } env_logger = { workspace = true } -insta = { workspace = true } log = { workspace = true } parking_lot = "0.12.3" petgraph = "0.7.1" @@ -35,3 +34,6 @@ serde_json = { workspace = true } serde_with = { workspace = true } tokio = { workspace = true, features = ["rt", "rt-multi-thread", "macros"] } wren-core-base = { workspace = true } + +[dev-dependencies] +insta = { workspace = true } From cb563ef450927189e0dcb8ee8fb9623467350b97 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 25 Apr 2025 12:03:14 +0800 Subject: [PATCH 21/30] prevent the invalid property value --- .../logical_plan/analyze/access_control.rs | 101 +++++++++++++++++- 1 file changed, 100 insertions(+), 1 deletion(-) diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs index c798a7915..7c1aae679 100644 --- a/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -11,7 +11,10 @@ use datafusion::{ sql::{ parser::DFParserBuilder, sqlparser::{ - ast::{self, visit_expressions, visit_expressions_mut, ExprWithAlias}, + ast::{ + self, visit_expressions, visit_expressions_mut, Array, ExprWithAlias, + Map, MapEntry, + }, dialect::GenericDialect, }, TableReference, @@ -156,9 +159,38 @@ fn parse_expr(expr: &str) -> Result { let dialect = GenericDialect {}; let mut parser = DFParserBuilder::new(expr).with_dialect(&dialect).build()?; let expr = parser.parse_expr()?; + prevent_invalid_expr(&expr.expr)?; Ok(expr) } +/// Prevent invalid expression for the session property. +/// Only literal values are allowed. +fn prevent_invalid_expr(expr: &ast::Expr) -> Result<()> { + match &expr { + ast::Expr::Value(_) | ast::Expr::Interval(_) => Ok(()), + ast::Expr::Array(Array { elem, .. 
}) => { + for e in elem { + prevent_invalid_expr(e)?; + } + Ok(()) + } + ast::Expr::Map(Map { entries }) => { + for MapEntry { key, value } in entries { + prevent_invalid_expr(key)?; + prevent_invalid_expr(value)?; + } + Ok(()) + } + ast::Expr::Dictionary(fileds) => { + for field in fileds { + prevent_invalid_expr(&field.value)?; + } + Ok(()) + } + _ => plan_err!("The session property {} allow only literal value", expr), + } +} + /// Validate the input headers with the required properties. /// If the result is false, the rules are not satisfied and it should be ignored. /// @@ -541,4 +573,71 @@ mod test { assert_snapshot!(expr_to_sql(&expr)?, @"m1.id = 1 AND m1.\"name\" = 'test'"); Ok(()) } + + #[test] + pub fn test_property_value() -> Result<()> { + let ctx = SessionContext::new(); + let state = ctx.state_ref(); + let model = ModelBuilder::new("m1") + .column(ColumnBuilder::new("id", "int").build()) + .column(ColumnBuilder::new("name", "varchar").build()) + .build(); + + let rule = RowLevelAccessControl { + condition: "id = @session_id".to_string(), + required_properties: vec![SessionProperty::new_required("SESSION_ID")], + name: "test".to_string(), + }; + + let valid_values = vec![ + "1", + "'aaa'", + "1.0", + "true", + "false", + "[1,2,3]", + "{'key': 'value'}", + "{key: 'value'}", + "INTERVAL '1' YEAR", + ]; + + for value in valid_values { + let headers: Arc>> = Arc::new(build_headers( + &[("session_id".to_string(), Some(value.to_string()))], + )); + + let expr = + build_filter_expression(&state, Arc::clone(&model), &headers, &rule)?; + expr_to_sql(&expr)?; + } + + let invalid_values = vec![ + "1 + 1", + "upper('aaa')", + "(select 1)", + "1 or 1", + "aaa", + "is null", + "is not null", + "case when 1 then 1 else 2 end", + "[upper('aaa'), upper('aaa')]", + "{'key': upper('aaa')}", + "{ key: upper('aaa') }", + ]; + + for value in invalid_values { + let headers: Arc>> = Arc::new(build_headers( + &[("session_id".to_string(), Some(value.to_string()))], + )); + + 
match build_filter_expression(&state, Arc::clone(&model), &headers, &rule) { + Err(_) => {} + _ => panic!( + "should be error: {}", + &headers.get("session_id").unwrap().as_ref().unwrap() + ), + } + } + Ok(()) + } } From 6ca5719373160c6b3334d2828c8bd8a5d2d7c76a Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 25 Apr 2025 12:06:33 +0800 Subject: [PATCH 22/30] use reference for rlac --- wren-core-base/manifest-macro/src/lib.rs | 2 +- wren-core-base/src/mdl/builder.rs | 2 +- wren-core-base/src/mdl/manifest.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index 93606eaae..4d1e502aa 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -141,7 +141,7 @@ pub fn model(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream #[serde(default)] pub refresh_time: Option, #[serde(default)] - pub row_level_access_controls: Vec, + pub row_level_access_controls: Vec>, } }; proc_macro::TokenStream::from(expanded) diff --git a/wren-core-base/src/mdl/builder.rs b/wren-core-base/src/mdl/builder.rs index da8261193..01298d322 100644 --- a/wren-core-base/src/mdl/builder.rs +++ b/wren-core-base/src/mdl/builder.rs @@ -162,7 +162,7 @@ impl ModelBuilder { required_properties, condition: condition.to_string(), }; - self.model.row_level_access_controls.push(rule); + self.model.row_level_access_controls.push(Arc::new(rule)); self } diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index 460a28672..e168d1e25 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -278,7 +278,7 @@ impl Model { self.table_reference.as_deref().unwrap_or("") } - pub fn row_level_access_controls(&self) -> &[RowLevelAccessControl] { + pub fn row_level_access_controls(&self) -> &[Arc] { &self.row_level_access_controls } } From 6f86e14c78804ee38292babff264cd30804d4f5e Mon Sep 17 
00:00:00 2001 From: Jax Liu Date: Fri, 25 Apr 2025 12:19:39 +0800 Subject: [PATCH 23/30] refactor and update lock --- ibis-server/app/dependencies.py | 10 +++ ibis-server/app/mdl/rewriter.py | 9 +- ibis-server/app/routers/v3/connector.py | 7 +- .../v3/connector/postgres/test_fallback_v2.py | 4 +- .../v3/connector/postgres/test_query.py | 5 +- wren-core-py/Cargo.lock | 89 ------------------- 6 files changed, 26 insertions(+), 98 deletions(-) diff --git a/ibis-server/app/dependencies.py b/ibis-server/app/dependencies.py index 84268e163..f5a93a65f 100644 --- a/ibis-server/app/dependencies.py +++ b/ibis-server/app/dependencies.py @@ -5,6 +5,7 @@ from app.model.data_source import DataSource X_WREN_FALLBACK_DISABLE = "x-wren-fallback_disable" +X_WREN_VARIABLE_PREFIX = "x-wren-variable-" # Rebuild model to validate the dto is correct via validation of the pydantic @@ -35,3 +36,12 @@ def _filter_headers(header_string: str) -> bool: elif header_string == "sentry-trace": return True return False + return request.headers + + +def exist_wren_variables_header( + headers: Headers, +) -> bool: + if headers is None: + return False + return any(key.startswith(X_WREN_VARIABLE_PREFIX) for key in headers.keys()) diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index 468c72929..bfb8353d7 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -7,6 +7,7 @@ from opentelemetry import trace from app.config import get_config +from app.dependencies import X_WREN_VARIABLE_PREFIX from app.mdl.core import ( get_manifest_extractor, get_session_context, @@ -133,13 +134,13 @@ async def rewrite( def get_session_properties(self, properties: dict) -> dict | None: if properties is None: return None - # filter the properties which name starts with "x-wren-variables-" - # and remove the prefix "x-wren-variables-" + # filter the properties which name starts with "x-wren-variable-" + # and remove the prefix "x-wren-variable-" return { - 
k.replace("x-wren-variables-", ""): v + k.replace(X_WREN_VARIABLE_PREFIX, ""): v for k, v in properties.items() - if k.startswith("x-wren-variables-") + if k.startswith(X_WREN_VARIABLE_PREFIX) } @staticmethod diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 81416fe0c..a2c633c7b 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -7,7 +7,12 @@ from starlette.datastructures import Headers from app.config import get_config -from app.dependencies import X_WREN_FALLBACK_DISABLE, get_wren_headers, verify_query_dto +from app.dependencies import ( + X_WREN_FALLBACK_DISABLE, + exist_wren_variables_header, + get_wren_headers, + verify_query_dto, +) from app.mdl.core import get_session_context from app.mdl.java_engine import JavaEngineConnector from app.mdl.rewriter import Rewriter diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py b/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py index a37ba363a..bf2c1226e 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py @@ -3,7 +3,7 @@ import orjson import pytest -from app.dependencies import X_WREN_FALLBACK_DISABLE +from app.dependencies import X_WREN_FALLBACK_DISABLE, X_WREN_VARIABLE_PREFIX from tests.routers.v3.connector.postgres.conftest import base_url # It's not a valid manifest for v3. We expect the query to fail and fallback to v2. 
@@ -393,6 +393,6 @@ async def test_query_rlac(client, manifest_str, connection_info): "manifestStr": manifest_str, "sql": "SELECT orderkey FROM orders LIMIT 1", }, - headers={"x-wren-variables-session_user": "1"}, + headers={X_WREN_VARIABLE_PREFIX + "session_user": "1"}, ) assert response.status_code == 422 diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_query.py b/ibis-server/tests/routers/v3/connector/postgres/test_query.py index 3864fc291..cf6f51e70 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_query.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_query.py @@ -3,6 +3,7 @@ import orjson import pytest +from app.dependencies import X_WREN_VARIABLE_PREFIX from tests.routers.v3.connector.postgres.conftest import base_url manifest = { @@ -501,7 +502,7 @@ async def test_rlac_query(client, manifest_str, connection_info): "sql": "SELECT c_name FROM customer", }, headers={ - "x-wren-variables-session_user": "'Customer#000000001'", + X_WREN_VARIABLE_PREFIX + "session_user": "'Customer#000000001'", }, ) assert response.status_code == 200 @@ -517,7 +518,7 @@ async def test_rlac_query(client, manifest_str, connection_info): "sql": "SELECT c_name FROM customer", }, headers={ - "X-WREN-VARIABLES-SESSION_USER": "'Customer#000000001'", + X_WREN_VARIABLE_PREFIX + "SESSION_USER": "'Customer#000000001'", }, ) assert response.status_code == 200 diff --git a/wren-core-py/Cargo.lock b/wren-core-py/Cargo.lock index 46e6872bc..e94697b30 100644 --- a/wren-core-py/Cargo.lock +++ b/wren-core-py/Cargo.lock @@ -594,30 +594,6 @@ dependencies = [ "unicode-width", ] -[[package]] -name = "console" -version = "0.15.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" -dependencies = [ - "encode_unicode", - "libc", - "once_cell", - "windows-sys", -] - -[[package]] -name = "console" -version = "0.15.11" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" -dependencies = [ - "encode_unicode", - "libc", - "once_cell", - "windows-sys", -] - [[package]] name = "const-random" version = "0.1.18" @@ -1253,18 +1229,6 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" -[[package]] -name = "encode_unicode" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" - -[[package]] -name = "encode_unicode" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" - [[package]] name = "env_filter" version = "0.1.3" @@ -1750,19 +1714,6 @@ version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" -[[package]] -name = "insta" -version = "1.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50259abbaa67d11d2bcafc7ba1d094ed7a0c70e3ce893f0d0997f73558cb3084" -dependencies = [ - "console", - "linked-hash-map", - "once_cell", - "pin-project", - "similar", -] - [[package]] name = "integer-encoding" version = "3.0.4" @@ -1925,12 +1876,6 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" - [[package]] name = "linux-raw-sys" version = "0.9.3" @@ -2288,26 +2233,6 @@ dependencies = [ "siphasher", ] -[[package]] -name = "pin-project" -version = 
"1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "pin-project-lite" version = "0.2.16" @@ -2747,18 +2672,6 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" -[[package]] -name = "similar" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" - -[[package]] -name = "similar" -version = "2.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" - [[package]] name = "siphasher" version = "1.0.1" @@ -3435,8 +3348,6 @@ dependencies = [ "csv", "datafusion", "env_logger", - "insta", - "insta", "log", "parking_lot", "petgraph 0.7.1", From 22bc47938518e4fb34cc3332076987fc37ed4704 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 25 Apr 2025 13:07:41 +0800 Subject: [PATCH 24/30] fix example default --- wren-core/wren-example/examples/row-level-access-control.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wren-core/wren-example/examples/row-level-access-control.rs b/wren-core/wren-example/examples/row-level-access-control.rs index ac68031f7..523bf21b1 100644 --- a/wren-core/wren-example/examples/row-level-access-control.rs +++ b/wren-core/wren-example/examples/row-level-access-control.rs @@ -170,7 +170,7 @@ fn init_manifest() -> Manifest { // 1. 
Member only can see the documents created by himself or the documents with status 'PUBLIC' in his department // 2. Admin can see all the documents .add_row_level_access_control("auth", vec![ - SessionProperty::new_optional("session_role", Some("MEMBER".to_string())), + SessionProperty::new_optional("session_role", Some("'MEMBER'".to_string())), SessionProperty::new_required("session_department"), SessionProperty::new_required("session_user_id")], "@session_role = 'ADMIN' OR (department = @session_department AND (created_by = @session_user_id OR status = 'PUBLIC'))") From 6e627ae595ae0971219177f60d0daf168a9e372f Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 25 Apr 2025 13:16:14 +0800 Subject: [PATCH 25/30] add todo comment --- wren-core/core/src/logical_plan/analyze/access_control.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs index 7c1aae679..c7ca2fef9 100644 --- a/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -39,6 +39,7 @@ pub fn collect_condition( .build()?; let expr = parser.parse_expr()?; visit_expressions(&expr, |expr| { + // TODO: consider CompoundIdentifier and CompoundFieldAccess if let ast::Expr::Identifier(ast::Ident { value, .. 
}) = expr { if !value.starts_with("@") { if model.get_column(value).is_none() { From b587e3feeb52a2d4c2bf5f9d72b8413a87c17cc8 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 25 Apr 2025 13:26:58 +0800 Subject: [PATCH 26/30] fix typo and check for default is empty --- .../logical_plan/analyze/access_control.rs | 10 ++-- wren-core/core/src/mdl/mod.rs | 48 +++++++++++++++++++ 2 files changed, 53 insertions(+), 5 deletions(-) diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs index c7ca2fef9..ea78b9524 100644 --- a/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -31,7 +31,7 @@ pub fn collect_condition( condition: &str, ) -> Result<(Vec, Vec)> { let mut conditions = vec![]; - let mut seesion_properties: HashSet = HashSet::new(); + let mut session_properties: HashSet = HashSet::new(); let mut error: Option> = None; let dialect = GenericDialect {}; let mut parser = DFParserBuilder::new(condition) @@ -57,8 +57,8 @@ pub fn collect_condition( })); } else { let session_property = value.trim_start_matches("@").to_string(); - if !seesion_properties.contains(&session_property) { - seesion_properties.insert(session_property); + if !session_properties.contains(&session_property) { + session_properties.insert(session_property); } } } @@ -71,7 +71,7 @@ pub fn collect_condition( Ok(( conditions, - seesion_properties.into_iter().collect::>(), + session_properties.into_iter().collect::>(), )) } @@ -215,7 +215,7 @@ pub fn validate_rule( Ok(true) } else { let exist = is_property_present(headers, &property.name); - if exist || property.default_expr.is_some() { + if exist || property.default_expr.as_ref().is_some_and(|expr| !expr.is_empty()) { Ok(true) } else { Ok(false) diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index d6bc0f4e8..48a8930c5 100644 --- a/wren-core/core/src/mdl/mod.rs +++ 
b/wren-core/core/src/mdl/mod.rs @@ -1979,6 +1979,54 @@ mod test { .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" ); + + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_nationkey", "int").build()) + .column(ColumnBuilder::new("c_name", "string").build()) + .add_row_level_access_control( + "nation", + vec![ + // if the default value is empty, it will be skipped + SessionProperty::new_optional( + "session_nation", + Some("".to_string()), + ), + SessionProperty::new_optional( + "session_user", + Some("'Gura'".to_string()), + ), + ], + "c_nationkey = @session_nation and c_name = @session_user", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let headers = + build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura'" + ); + // the rule is expected to be skipped because the optional property is None without default value + let headers = + build_headers(&[("session_user".to_string(), Some("'Peko'".to_string()))]); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM 
customer AS __source) AS customer) AS customer" + ); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) + .await?, + @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" + ); Ok(()) } From 59dcc7949cd9f2aa735cdd11ea9af24bd7b3b27c Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Wed, 30 Apr 2025 08:47:39 +0800 Subject: [PATCH 27/30] fix compile --- wren-core/core/src/mdl/mod.rs | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 48a8930c5..74f09fb9d 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -1619,8 +1619,14 @@ mod test { let ctx = SessionContext::new(); let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let result = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let result = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!( result, "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM \ @@ -1662,8 +1668,14 @@ mod test { let ctx = SessionContext::new(); let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let result = - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], sql).await?; + let result = transform_sql_with_ctx( + &ctx, + Arc::clone(&analyzed_mdl), + &[], + HashMap::new(), + sql, + ) + .await?; assert_eq!( result, "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM \ From e9bc62337579301e0ee959ee579e1feec4c07c64 Mon Sep 17 00:00:00 2001 From: Jax Liu 
Date: Wed, 30 Apr 2025 09:13:51 +0800 Subject: [PATCH 28/30] use hashset to avoid duplicate result --- .../logical_plan/analyze/access_control.rs | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs index ea78b9524..64ecb13dd 100644 --- a/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -30,7 +30,7 @@ pub fn collect_condition( model: &Model, condition: &str, ) -> Result<(Vec, Vec)> { - let mut conditions = vec![]; + let mut conditions = HashSet::new(); let mut session_properties: HashSet = HashSet::new(); let mut error: Option> = None; let dialect = GenericDialect {}; @@ -50,7 +50,7 @@ pub fn collect_condition( )); return ControlFlow::Break(()); } - conditions.push(Expr::Column(datafusion::common::Column { + conditions.insert(Expr::Column(datafusion::common::Column { relation: Some(TableReference::bare(model.name())), name: value.to_string(), spans: Spans::new(), @@ -70,7 +70,7 @@ pub fn collect_condition( } Ok(( - conditions, + conditions.into_iter().collect(), session_properties.into_iter().collect::>(), )) } @@ -102,7 +102,9 @@ pub fn build_filter_expression( let Some(property_value) = properties.get(&property_name).or_else(|| { required_properties .iter() - .filter(|r| r.name.to_lowercase() == property_name && !r.required) + .filter(|r| { + !r.required && r.name.eq_ignore_ascii_case(&property_name) + }) .map(|r| &r.default_expr) .next() }) else { @@ -121,7 +123,7 @@ pub fn build_filter_expression( return ControlFlow::Break(()); }; - if property_value.is_empty() { + if property_value.trim().is_empty() { error = Some(plan_err!( "The session property {} should not be empty", property_name @@ -240,7 +242,10 @@ fn is_property_present( #[cfg(test)] mod test { - use std::{collections::HashMap, sync::Arc}; + use std::{ + collections::{HashMap, 
HashSet}, + sync::Arc, + }; use datafusion::{ error::Result, @@ -278,7 +283,13 @@ mod test { .into_iter() .map(|e| e.schema_name().to_string()) .collect::>(); - assert_eq!(name, vec!["model1.id", "model1.name"]); + let expected: HashSet<&str> = + ["model1.name", "model1.id"].iter().cloned().collect(); + let all_match = name.iter().all(|n| expected.contains(n.as_str())); + + if !all_match { + panic!("should be all match, but got: {:?}", name); + } assert_eq!(session_properties.len(), 1); assert_eq!(session_properties[0], "session_id"); } From 15e307bca5fbd13266d96b64a7f3d1ac153b753d Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Mon, 5 May 2025 10:22:07 +0800 Subject: [PATCH 29/30] fix header check --- ibis-server/app/dependencies.py | 1 - ibis-server/app/routers/v3/connector.py | 1 - ibis-server/app/util.py | 8 -------- 3 files changed, 10 deletions(-) diff --git a/ibis-server/app/dependencies.py b/ibis-server/app/dependencies.py index f5a93a65f..69f3a9e5d 100644 --- a/ibis-server/app/dependencies.py +++ b/ibis-server/app/dependencies.py @@ -36,7 +36,6 @@ def _filter_headers(header_string: str) -> bool: elif header_string == "sentry-trace": return True return False - return request.headers def exist_wren_variables_header( diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index a2c633c7b..52300ba0d 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -32,7 +32,6 @@ from app.util import ( append_fallback_context, build_context, - exist_wren_variables_header, pushdown_limit, safe_strtobool, to_json, diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index c70d224e7..b5e2ccb55 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -147,11 +147,3 @@ def get_fallback_message( def safe_strtobool(val: str) -> bool: return val.lower() in {"1", "true", "yes", "y"} - - -def exist_wren_variables_header( - headers: Header, -) -> bool: - if headers is None: - 
return False - return any(key.startswith("x-wren-variables-") for key in headers.keys()) From 1c2b05d8dce4ab0c06ad50dd886ac72c45ba2bc9 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 6 May 2025 10:20:04 +0800 Subject: [PATCH 30/30] fix missing header --- ibis-server/app/routers/v3/connector.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 52300ba0d..8c0da3f9a 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -110,7 +110,10 @@ async def query( else: sql = pushdown_limit(dto.sql, limit) rewritten_sql = await Rewriter( - dto.manifest_str, data_source=data_source, experiment=True + dto.manifest_str, + data_source=data_source, + experiment=True, + properties=dict(headers), ).rewrite(sql) connector = Connector(data_source, dto.connection_info) result = connector.query(rewritten_sql, limit=limit)