diff --git a/ibis-server/app/mdl/core.py b/ibis-server/app/mdl/core.py index 27df4d10e..e23d7b750 100644 --- a/ibis-server/app/mdl/core.py +++ b/ibis-server/app/mdl/core.py @@ -5,9 +5,9 @@ @cache def get_session_context( - manifest_str: str | None, function_path: str + manifest_str: str | None, function_path: str, properties: frozenset | None = None ) -> wren_core.SessionContext: - return wren_core.SessionContext(manifest_str, function_path) + return wren_core.SessionContext(manifest_str, function_path, properties) def get_manifest_extractor(manifest_str: str) -> wren_core.ManifestExtractor: diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index bfb8353d7..7155e8573 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -122,27 +122,31 @@ async def rewrite( self, manifest_str: str, sql: str, properties: dict | None = None ) -> str: try: - session_context = get_session_context(manifest_str, self.function_path) + processed_properties = self.get_session_properties(properties) + session_context = get_session_context( + manifest_str, self.function_path, processed_properties + ) return await to_thread.run_sync( session_context.transform_sql, sql, - self.get_session_properties(properties), ) except Exception as e: raise RewriteError(str(e)) - def get_session_properties(self, properties: dict) -> dict | None: + def get_session_properties(self, properties: dict) -> frozenset | None: if properties is None: return None # filter the properties which name starts with "x-wren-variable-" # and remove the prefix "x-wren-variable-" - return { + processed_properties = { k.replace(X_WREN_VARIABLE_PREFIX, ""): v for k, v in properties.items() if k.startswith(X_WREN_VARIABLE_PREFIX) } + return frozenset(processed_properties.items()) + @staticmethod def handle_extract_exception(e: Exception): raise RewriteError(str(e)) diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_query.py b/ibis-server/tests/routers/v3/connector/postgres/test_query.py index cf6f51e70..e82d38881 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_query.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_query.py @@ -70,7 +70,21 @@ }, "columns": [ {"name": "c_custkey", "type": "integer"}, - {"name": "c_name", "type": "varchar"}, + { + "name": "c_name", + "type": "varchar", + "columnLevelAccessControl": { + "name": "c_name_access", + "requiredProperties": [ + { + "name": "session_level", + "required": False, + } + ], + "operator": "EQUALS", + "threshold": "1", + }, + }, {"name": "orders", "type": "orders", "relationship": "orders_customer"}, { "name": "sum_totalprice", @@ -525,3 +539,50 @@ async def test_rlac_query(client, manifest_str, connection_info): result = response.json() assert len(result["data"]) == 1 assert result["data"][0][0] == "Customer#000000001" + + +async def test_clac_query(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM customer limit 1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert len(result["data"][0]) == 3 + + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM customer limit 1", + }, + headers={ + X_WREN_VARIABLE_PREFIX + "session_level": "1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert len(result["data"][0]) == 3 + + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM customer limit 1", + }, + headers={ + X_WREN_VARIABLE_PREFIX + "session_level": "2", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert len(result["data"][0]) == 2 diff --git a/wren-core-base/Cargo.toml b/wren-core-base/Cargo.toml index d55cc168e..05fc1d398 100644 --- a/wren-core-base/Cargo.toml +++ b/wren-core-base/Cargo.toml @@ -8,7 +8,7 @@ python-binding = ["dep:pyo3"] default = [] [dependencies] -pyo3 = { version = "0.24.1", features = ["extension-module"], optional = true } +pyo3 = { version = "0.25.0", features = ["extension-module"], optional = true } serde = { version = "1.0.201", features = ["derive", "rc"] } wren-manifest-macro = { path = "manifest-macro" } serde_json = { version = "1.0.117" } diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index ba48d6823..e74e28f28 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -184,7 +184,9 @@ pub fn column(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStrea pub is_hidden: bool, #[deprecated] pub rls: Option, + #[deprecated] pub cls: Option, + pub column_level_access_control: Option>, } }; proc_macro::TokenStream::from(expanded) @@ -466,6 +468,7 @@ pub fn row_level_operator(python_binding: proc_macro::TokenStream) -> proc_macro } #[proc_macro] +#[deprecated] pub fn column_level_security(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(python_binding as LitBool); let python_binding = if input.value { @@ -487,6 +490,32 @@ pub fn column_level_security(python_binding: proc_macro::TokenStream) -> proc_ma proc_macro::TokenStream::from(expanded) } +#[proc_macro] +pub fn column_level_access_control( + python_binding: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let input = parse_macro_input!(python_binding as LitBool); + let python_binding = if input.value { + quote! { + #[pyclass] + } + } else { + quote! {} + }; + let expanded = quote! { + #python_binding + #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)] + #[serde(rename_all = "camelCase")] + pub struct ColumnLevelAccessControl { + pub name: String, + pub required_properties: Vec, + pub operator: ColumnLevelOperator, + pub threshold: NormalizedExpr, + } + }; + proc_macro::TokenStream::from(expanded) +} + #[proc_macro] pub fn column_level_operator(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream { let input = parse_macro_input!(python_binding as LitBool); diff --git a/wren-core-base/src/mdl/builder.rs b/wren-core-base/src/mdl/builder.rs index 01298d322..9f8eaa29c 100644 --- a/wren-core-base/src/mdl/builder.rs +++ b/wren-core-base/src/mdl/builder.rs @@ -29,6 +29,8 @@ use crate::mdl::{ }; use std::sync::Arc; +use super::ColumnLevelAccessControl; + /// A builder for creating a Manifest pub struct ManifestBuilder { pub manifest: Manifest, @@ -205,6 +207,7 @@ impl ColumnBuilder { expression: None, rls: None, cls: None, + column_level_access_control: None, }, } } @@ -264,6 +267,23 @@ impl ColumnBuilder { }); self } + + pub fn column_level_access_control( + mut self, + name: &str, + required_properties: Vec, + operator: ColumnLevelOperator, + threshold: &str, + ) -> Self { + self.column.column_level_access_control = Some(Arc::new(ColumnLevelAccessControl { + name: name.to_string(), + required_properties, + operator, + threshold: NormalizedExpr::new(threshold), + })); + self + } + pub fn build(self) -> Arc { Arc::new(self.column) } @@ -437,6 +457,12 @@ mod test { .expression("test") .row_level_security("SESSION_STATUS", RowLevelOperator::Equals) .column_level_security("SESSION_LEVEL", ColumnLevelOperator::Equals, "'NORMAL'") + .column_level_access_control( + "rlac", + vec![SessionProperty::new_required("session_id")], + ColumnLevelOperator::Equals, + "'NORMAL'", + ) .build(); let json_str = serde_json::to_string(&expected).unwrap(); diff --git a/wren-core-base/src/mdl/cls.rs b/wren-core-base/src/mdl/cls.rs index 93c3f7de4..c153bbb8a 100644 --- a/wren-core-base/src/mdl/cls.rs +++ b/wren-core-base/src/mdl/cls.rs @@ -16,7 +16,9 @@ * specific language governing permissions and limitations * under the License. */ -use crate::mdl::manifest::{ColumnLevelSecurity, NormalizedExpr, NormalizedExprType}; +use crate::mdl::manifest::{ + ColumnLevelAccessControl, ColumnLevelSecurity, NormalizedExpr, NormalizedExprType, +}; use crate::mdl::ColumnLevelOperator; use std::fmt::{Display, Formatter}; use std::str::FromStr; @@ -37,6 +39,22 @@ impl ColumnLevelSecurity { } } +impl ColumnLevelAccessControl { + /// Evaluate the input against the column level access control. + /// If the type of the input is different from the type of the value, the result is always false except for NOT_EQUALS. + pub fn eval(&self, input: &str) -> bool { + let input_expr = NormalizedExpr::new(input); + match self.operator { + ColumnLevelOperator::Equals => input_expr.eq(&self.threshold), + ColumnLevelOperator::NotEquals => input_expr.neq(&self.threshold), + ColumnLevelOperator::GreaterThan => input_expr.gt(&self.threshold), + ColumnLevelOperator::LessThan => input_expr.lt(&self.threshold), + ColumnLevelOperator::GreaterThanOrEquals => input_expr.gte(&self.threshold), + ColumnLevelOperator::LessThanOrEquals => input_expr.lte(&self.threshold), + } + } +} + impl NormalizedExpr { pub fn new(expr: &str) -> Self { assert!(!expr.is_empty(), "expr is null or empty"); diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index 6bfb8b2a1..fa67cc2d7 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -25,10 +25,10 @@ mod manifest_impl { use crate::mdl::manifest::bool_from_int; use crate::mdl::manifest::table_reference; use manifest_macro::{ - column, column_level_operator, column_level_security, data_source, join_type, manifest, - metric, model, normalized_expr, normalized_expr_type, relationship, - row_level_access_control, row_level_operator, row_level_security, session_property, - time_grain, time_unit, view, + column, column_level_access_control, column_level_operator, column_level_security, + data_source, join_type, manifest, metric, model, normalized_expr, normalized_expr_type, + relationship, row_level_access_control, row_level_operator, row_level_security, + session_property, time_grain, time_unit, view, }; use serde::{Deserialize, Serialize}; use serde_with::serde_as; @@ -47,6 +47,7 @@ mod manifest_impl { time_grain!(false); time_unit!(false); row_level_access_control!(false); + column_level_access_control!(false); session_property!(false); row_level_security!(false); row_level_operator!(false); @@ -62,10 +63,10 @@ mod manifest_impl { use crate::mdl::manifest::bool_from_int; use crate::mdl::manifest::table_reference; use manifest_macro::{ - column, column_level_operator, column_level_security, data_source, join_type, manifest, - metric, model, normalized_expr, normalized_expr_type, relationship, - row_level_access_control, row_level_operator, row_level_security, session_property, - time_grain, time_unit, view, + column, column_level_access_control, column_level_operator, column_level_security, + data_source, join_type, manifest, metric, model, normalized_expr, normalized_expr_type, + relationship, row_level_access_control, row_level_operator, row_level_security, + session_property, time_grain, time_unit, view, }; use pyo3::pyclass; use serde::{Deserialize, Serialize}; @@ -86,6 +87,7 @@ mod manifest_impl { time_unit!(true); manifest!(true); row_level_access_control!(true); + column_level_access_control!(true); session_property!(true); row_level_security!(true); row_level_operator!(true); @@ -294,6 +296,14 @@ impl Column { pub fn expression(&self) -> Option<&str> { self.expression.as_deref() } + + pub fn column_level_access_control(&self) -> Option> { + if let Some(ref cla) = &self.column_level_access_control { + Some(Arc::clone(cla)) + } else { + None + } + } } impl Metric { diff --git a/wren-core-py/Cargo.lock b/wren-core-py/Cargo.lock index fbb053e9a..b811bc411 100644 --- a/wren-core-py/Cargo.lock +++ b/wren-core-py/Cargo.lock @@ -529,9 +529,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.24" +version = "1.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16595d3be041c03b09d08d0858631facccee9221e579704070e6e9e4915d3bc7" +checksum = "d0fc897dc1e865cc67c0e05a836d9d3f1df3cbe442aa4a9473b18e12624a4951" dependencies = [ "jobserver", "libc", @@ -1858,9 +1858,9 @@ checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -2059,9 +2059,9 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -2069,9 +2069,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", @@ -2293,17 +2293,16 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.24.2" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5203598f366b11a02b13aa20cab591229ff0a89fd121a308a5df751d5fc9219" +checksum = "f239d656363bcee73afef85277f1b281e8ac6212a1d42aa90e55b90ed43c47a4" dependencies = [ - "cfg-if", "indoc", "libc", "memoffset", "once_cell", "portable-atomic", - "pyo3-build-config 0.24.2", + "pyo3-build-config 0.25.0", "pyo3-ffi", "pyo3-macros", "unindent", @@ -2321,9 +2320,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.24.2" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99636d423fa2ca130fa5acde3059308006d46f98caac629418e53f7ebb1e9999" +checksum = "755ea671a1c34044fa165247aaf6f419ca39caa6003aee791a0df2713d8f1b6d" dependencies = [ "once_cell", "target-lexicon 0.13.2", @@ -2331,19 +2330,19 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.24.2" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78f9cf92ba9c409279bc3305b5409d90db2d2c22392d443a87df3a1adad59e33" +checksum = "fc95a2e67091e44791d4ea300ff744be5293f394f1bafd9f78c080814d35956e" dependencies = [ "libc", - "pyo3-build-config 0.24.2", + "pyo3-build-config 0.25.0", ] [[package]] name = "pyo3-macros" -version = "0.24.2" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b999cb1a6ce21f9a6b147dcf1be9ffedf02e0043aec74dc390f3007047cecd9" +checksum = "a179641d1b93920829a62f15e87c0ed791b6c8db2271ba0fd7c2686090510214" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -2353,13 +2352,13 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.24.2" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "822ece1c7e1012745607d5cf0bcb2874769f0f7cb34c4cde03b9358eb9ef911a" +checksum = "9dff85ebcaab8c441b0e3f7ae40a6963ecea8a9f5e74f647e33fcf5ec9a1e89e" dependencies = [ "heck", "proc-macro2", - "pyo3-build-config 0.24.2", + "pyo3-build-config 0.25.0", "quote", "syn", ] @@ -2678,18 +2677,18 @@ checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "snafu" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "223891c85e2a29c3fe8fb900c1fae5e69c2e42415e3177752e8718475efa5019" +checksum = "320b01e011bf8d5d7a4a4a4be966d9160968935849c83b918827f6a435e7f627" dependencies = [ "snafu-derive", ] [[package]] name = "snafu-derive" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917" +checksum = "1961e2ef424c1424204d3a5d6975f934f56b6d50ff5732382d84ebf460e147f7" dependencies = [ "heck", "proc-macro2", diff --git a/wren-core-py/Cargo.toml b/wren-core-py/Cargo.toml index bc59577b4..2c3ff1c44 100644 --- a/wren-core-py/Cargo.toml +++ b/wren-core-py/Cargo.toml @@ -9,7 +9,7 @@ name = "wren_core_py" crate-type = ["cdylib"] [dependencies] -pyo3 = { version = "0.24.1", features = ["extension-module"] } +pyo3 = { version = "0.25.0", features = ["extension-module"] } wren-core = { path = "../wren-core/core" } wren-core-base = { path = "../wren-core-base", features = ["python-binding"] } base64 = "0.22.1" diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index 62b638eeb..311f7ac7b 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -19,6 +19,8 @@ use crate::errors::CoreError; use crate::manifest::to_manifest; use crate::remote_functions::PyRemoteFunction; use log::debug; +use pyo3::types::{PyAnyMethods, PyFrozenSet, PyFrozenSetMethods, PyTuple}; +use pyo3::{PyObject, Python}; use pyo3::{pyclass, pymethods, PyErr, PyResult}; use std::collections::HashMap; use std::hash::Hash; @@ -45,6 +47,7 @@ use wren_core::{ pub struct PySessionContext { ctx: wren_core::SessionContext, mdl: Arc, + properties: Arc>>, runtime: Arc, } @@ -59,6 +62,7 @@ impl Default for PySessionContext { Self { ctx: wren_core::SessionContext::new(), mdl: Arc::new(AnalyzedWrenMDL::default()), + properties: Arc::new(HashMap::new()), runtime: Arc::new(Runtime::new().unwrap()), } } @@ -71,10 +75,11 @@ impl PySessionContext { /// if `mdl_base64` is provided, the session context will be created with the given MDL. Otherwise, an empty MDL will be created. /// if `remote_functions_path` is provided, the session context will be created with the remote functions defined in the CSV file. #[new] - #[pyo3(signature = (mdl_base64=None, remote_functions_path=None))] + #[pyo3(signature = (mdl_base64=None, remote_functions_path=None, properties=None))] pub fn new( mdl_base64: Option<&str>, remote_functions_path: Option<&str>, + properties: Option, ) -> PyResult { let remote_functions = Self::read_remote_function_list(remote_functions_path) .map_err(CoreError::from)?; @@ -112,42 +117,81 @@ impl PySessionContext { return Ok(Self { ctx, mdl: Arc::new(AnalyzedWrenMDL::default()), + properties: Arc::new(HashMap::new()), runtime: Arc::new(runtime), }); }; - let manifest = to_manifest(mdl_base64)?; - - let Ok(analyzed_mdl) = AnalyzedWrenMDL::analyze(manifest) else { - return Err(CoreError::new("Failed to analyze manifest").into()); - }; + Python::with_gil(|py| { + let properties_map = if let Some(obj) = properties { + let obj = obj.as_ref(); + if obj.is_none(py) { + HashMap::new() + } else { + let frozenset = obj.downcast_bound::(py)?; + let mut map = HashMap::new(); + for item in frozenset.iter() { + match item.as_any().clone().downcast_into::() { + Ok(tuple) => { + if tuple.len()? != 2 { + return Err(CoreError::new( + "Properties must be a tuple of (key, value)", + ) + .into()); + } + let key = tuple.get_item(0)?.to_string(); + let value = tuple.get_item(1)?.to_string(); + map.insert(key, Some(value)); + } + Err(_) => { + return Err(CoreError::new( + "Properties must be a tuple of (key, value)", + ) + .into()); + } + } + } + map + } + } else { + HashMap::new() + }; + let manifest = to_manifest(mdl_base64)?; + let properties_ref = Arc::new(properties_map); + let Ok(analyzed_mdl) = + AnalyzedWrenMDL::analyze(manifest, Arc::clone(&properties_ref)) + else { + return Err(CoreError::new("Failed to analyze manifest").into()); + }; - let analyzed_mdl = Arc::new(analyzed_mdl); + let analyzed_mdl = Arc::new(analyzed_mdl); - // the headers won't be used in the context. Provide an empty map. - let ctx = runtime - .block_on(create_ctx_with_mdl( - &ctx, - Arc::clone(&analyzed_mdl), - Arc::new(HashMap::new()), - false, - )) - .map_err(CoreError::from)?; + // the headers won't be used in the context. Provide an empty map. + let ctx = runtime + .block_on(create_ctx_with_mdl( + &ctx, + Arc::clone(&analyzed_mdl), + Arc::new(HashMap::new()), + false, + )) + .map_err(CoreError::from)?; - Ok(Self { - ctx, - mdl: analyzed_mdl, - runtime: Arc::new(runtime), + Ok(Self { + ctx, + mdl: analyzed_mdl, + runtime: Arc::new(runtime), + properties: properties_ref, + }) }) } /// Transform the given Wren SQL to the equivalent Planned SQL. - #[pyo3(signature = (sql=None, properties=None))] + #[pyo3(signature = (sql=None))] pub fn transform_sql( &self, sql: Option<&str>, - properties: Option>>, ) -> PyResult { + env_logger::try_init().ok(); let Some(sql) = sql else { return Err(CoreError::new("SQL is required").into()); }; @@ -158,7 +202,7 @@ impl PySessionContext { // the ctx has been initialized when PySessionContext is created // so we can pass the empty array here &[], - properties.unwrap_or_default(), + Arc::clone(&self.properties), sql, )) .map_err(|e| PyErr::from(CoreError::from(e))) @@ -192,7 +236,7 @@ impl PySessionContext { if statements.len() != 1 { return Err(CoreError::new("Only one statement is allowed").into()); } - visit_statements_mut(&mut statements, |stmt| { + let _ = visit_statements_mut(&mut statements, |stmt| { if let Statement::Query(q) = stmt { if let Some(limit) = &q.limit { if let Expr::Value(Value::Number(n, is)) = limit { diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index 1d7aeea9e..93610ce45 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -26,7 +26,21 @@ }, "columns": [ {"name": "c_custkey", "type": "integer"}, - {"name": "c_name", "type": "varchar"}, + { + "name": "c_name", + "type": "varchar", + "columnLevelAccessControl": { + "name": "c_name_access", + "requiredProperties": [ + { + "name": "session_level", + "required": False, + } + ], + "operator": "EQUALS", + "threshold": "1", + }, + }, {"name": "orders", "type": "orders", "relationship": "orders_customer"}, ], "rowLevelAccessControls": [ @@ -295,9 +309,10 @@ def test_rlac(): headers = { "session_user": "'test_user'", } - session_context = SessionContext(manifest_str, None) + properties_hashable = frozenset(headers.items()) if headers else None + session_context = SessionContext(manifest_str, None, properties_hashable) sql = "SELECT * FROM my_catalog.my_schema.customer" - rewritten_sql = session_context.transform_sql(sql, headers) + rewritten_sql = session_context.transform_sql(sql) assert ( rewritten_sql == "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM main.customer AS __source) AS customer) AS customer WHERE customer.c_name = 'test_user'" @@ -334,3 +349,18 @@ def test_validate_rlac_rule(): str(e.value) == "Exception: DataFusion error: Error during planning: The session property @session_user is used, but not found in the session properties" ) + + +def test_clac(): + headers = { + "session_level": "2", + } + properties_hashable = frozenset(headers.items()) if headers else None + + session_context = SessionContext(manifest_str, None, properties_hashable) + sql = "SELECT * FROM my_catalog.my_schema.customer" + rewritten_sql = session_context.transform_sql(sql) + assert ( + rewritten_sql + == "SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM main.customer AS __source) AS customer) AS customer" + ) diff --git a/wren-core/benchmarks/src/tpch/run.rs b/wren-core/benchmarks/src/tpch/run.rs index aa42fa6c3..7a2669b07 100644 --- a/wren-core/benchmarks/src/tpch/run.rs +++ b/wren-core/benchmarks/src/tpch/run.rs @@ -50,7 +50,10 @@ impl RunOpt { async fn benchmark_query(&self, query_id: usize) -> Result> { let ctx = SessionContext::new(); - let mdl = Arc::new(AnalyzedWrenMDL::analyze(tpch_manifest())?); + let mdl = Arc::new(AnalyzedWrenMDL::analyze( + tpch_manifest(), + Arc::new(HashMap::default()), + )?); let mut millis = vec![]; // run benchmark let mut query_results = vec![]; @@ -62,7 +65,7 @@ impl RunOpt { &ctx, Arc::clone(&mdl), &[], - HashMap::new(), + HashMap::new().into(), query, ) .await?; diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs index 16198924c..b9ff587df 100644 --- a/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -21,7 +21,7 @@ use datafusion::{ }, }; use wren_core_base::mdl::RowLevelAccessControl; -use wren_core_base::mdl::{Model, SessionProperty}; +use wren_core_base::mdl::{Column, Model, SessionProperty}; use crate::mdl::{context::SessionPropertiesRef, Dataset, SessionStateRef}; @@ -38,7 +38,7 @@ pub fn collect_condition( .with_dialect(&dialect) .build()?; let expr = parser.parse_expr()?; - visit_expressions(&expr, |expr| { + let _ = visit_expressions(&expr, |expr| { // TODO: consider CompoundIdentifier and CompoundFieldAccess if let ast::Expr::Identifier(ast::Ident { value, .. }) = expr { if !value.starts_with("@") { @@ -128,7 +128,7 @@ pub fn build_filter_expression( .build()?; let mut expr = parser.parse_expr()?; - visit_expressions_mut(&mut expr, |expr| { + let _ = visit_expressions_mut(&mut expr, |expr| { if let ast::Expr::Identifier(ast::Ident { value, .. }) = expr { if value.starts_with("@") { let property_name = @@ -240,28 +240,70 @@ pub fn validate_rule( required_properties: &[SessionProperty], headers: &HashMap>, ) -> Result { - let exists = required_properties.iter().map(|property| { - if property.required { - if !is_property_present(headers, &property.name) { - return plan_err!( - "Row level access control property {} is required, but not found in headers", - property.name - ); - } - Ok(true) - } else { - let exist = is_property_present(headers, &property.name); - if exist || property.default_expr.as_ref().is_some_and(|expr| !expr.is_empty()) { + let exists = required_properties + .iter() + .map(|property| { + if property.required { + if !is_property_present(headers, &property.name) { + return plan_err!( + "session property {} is required, but not found in headers", + property.name + ); + } Ok(true) } else { - Ok(false) + let exist = is_property_present(headers, &property.name); + if exist + || property + .default_expr + .as_ref() + .is_some_and(|expr| !expr.is_empty()) + { + Ok(true) + } else { + Ok(false) + } } - } - }).collect::>>()?; + }) + .collect::>>()?; Ok(exists.iter().all(|x| *x)) } +pub(crate) fn validate_clac_rule( + column: &Column, + properties: &SessionPropertiesRef, +) -> Result { + let Some(clac) = column.column_level_access_control() else { + return Ok(true); + }; + + if !validate_rule(&clac.required_properties, properties)? { + return Ok(true); + } + + if clac.required_properties.len() > 1 { + return plan_err!( + "Only support one required property for column access-control level rule: {}", + clac.name + ); + } + + let property = &clac.required_properties[0]; + let value_opt = properties.get(&property.name); + + match value_opt { + Some(Some(value)) => Ok(clac.eval(value)), + Some(None) | None => { + if let Some(default) = &property.default_expr { + Ok(clac.eval(default)) + } else { + Ok(true) + } + } + } +} + /// Check if the property is present in the headers and not empty /// If the property is present and not empty, return true. fn is_property_present( @@ -352,7 +394,7 @@ mod test { &build_headers(&[("session_id".to_string(), None)]), ) { Err(error) => { - assert_snapshot!(error.message(), @"Row level access control property session_id is required, but not found in headers"); + assert_snapshot!(error.message(), @"session property session_id is required, but not found in headers"); } _ => panic!("should be error"), } @@ -362,7 +404,7 @@ mod test { &build_headers(&[("session_id".to_string(), Some("".to_string()))]), ) { Err(error) => { - assert_snapshot!(error.message(), @"Row level access control property session_id is required, but not found in headers"); + assert_snapshot!(error.message(), @"session property session_id is required, but not found in headers"); } _ => panic!("should be error"), } @@ -372,7 +414,7 @@ mod test { &build_headers(&[]), ) { Err(error) => { - assert_snapshot!(error.message(), @"Row level access control property session_id is required, but not found in headers"); + assert_snapshot!(error.message(), @"session property session_id is required, but not found in headers"); } _ => panic!("should be error"), } @@ -487,7 +529,7 @@ mod test { ]), ) { Err(error) => { - assert_snapshot!(error.message(), @"Row level access control property session_id is required, but not found in headers"); + assert_snapshot!(error.message(), @"session property session_id is required, but not found in headers"); } _ => panic!("should be error"), } diff --git a/wren-core/core/src/logical_plan/analyze/plan.rs b/wren-core/core/src/logical_plan/analyze/plan.rs index c90a4c33e..a2ba09fb8 100644 --- a/wren-core/core/src/logical_plan/analyze/plan.rs +++ b/wren-core/core/src/logical_plan/analyze/plan.rs @@ -7,8 +7,8 @@ use std::sync::Arc; use datafusion::arrow::datatypes::Field; use datafusion::common::{ - internal_datafusion_err, internal_err, plan_err, Column, DFSchema, DFSchemaRef, - TableReference, + internal_datafusion_err, internal_err, plan_err, Column as DFColumn, DFSchema, + DFSchemaRef, TableReference, }; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::expr::WildcardOptions; @@ -234,7 +234,7 @@ impl ModelPlanNodeBuilder { .entry(model_ref.clone()) .or_default() .insert(OrdExpr::new(expr_plan.clone())); - let expr_plan = Expr::Column(Column::from_qualified_name(format!( + let expr_plan = Expr::Column(DFColumn::from_qualified_name(format!( "{}.{}", quoted(model_ref.table()), quoted(column.name()), @@ -293,7 +293,7 @@ impl ModelPlanNodeBuilder { self.model_required_fields .entry(model.clone()) .or_default() - .insert(OrdExpr::new(Expr::Column(Column::from_qualified_name( + .insert(OrdExpr::new(Expr::Column(DFColumn::from_qualified_name( format!("{}.{}", quoted(model.table()), quoted(pk_column.name()),), )))); } @@ -403,7 +403,7 @@ impl ModelPlanNodeBuilder { !find_aggregate_exprs(&[expr]).is_empty() } - fn is_contain_calculation_source(&self, qualified_column: &Column) -> bool { + fn is_contain_calculation_source(&self, qualified_column: &DFColumn) -> bool { self.analyzed_wren_mdl .lineage() .required_fields_map @@ -424,7 +424,7 @@ impl ModelPlanNodeBuilder { &mut self, model_ref: TableReference, column: Arc, - qualified_column: &Column, + qualified_column: &DFColumn, col_expr: Expr, ) -> Result { let Some(column_graph) = self @@ -518,7 +518,7 @@ fn is_required_column(expr: &Expr, name: &str) -> bool { fn collect_partial_model_plan_for_calculation( analyzed_wren_mdl: Arc, session_state_ref: SessionStateRef, - qualified_column: &Column, + qualified_column: &DFColumn, required_fields: &mut HashMap>, ) -> Result<()> { let Some(set) = analyzed_wren_mdl @@ -562,7 +562,7 @@ fn collect_partial_model_plan_for_calculation( fn collect_partial_model_required_fields( analyzed_wren_mdl: Arc, session_state_ref: SessionStateRef, - qualified_column: &Column, + qualified_column: &DFColumn, required_fields: &mut HashMap>, ) -> Result<()> { let Some(set) = analyzed_wren_mdl @@ -605,7 +605,7 @@ fn collect_partial_model_required_fields( fn collect_model_required_fields( analyzed_wren_mdl: Arc, session_state_ref: SessionStateRef, - qualified_column: &Column, + qualified_column: &DFColumn, required_fields: &mut HashMap>, ) -> Result<()> { let Some(set) = analyzed_wren_mdl diff --git a/wren-core/core/src/logical_plan/context_provider.rs b/wren-core/core/src/logical_plan/context_provider.rs deleted file mode 100644 index f5f0e7b8f..000000000 --- a/wren-core/core/src/logical_plan/context_provider.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use datafusion::arrow::datatypes::DataType; -use datafusion::datasource::DefaultTableSource; -use datafusion::{ - common::{plan_err, Result}, - config::ConfigOptions, - logical_expr::{AggregateUDF, ScalarUDF, TableSource, WindowUDF}, - sql::{planner::ContextProvider, TableReference}, -}; - -use crate::mdl::WrenMDL; - -use super::utils::create_table_source; - -#[deprecated(since = "0.8.0", note = "try to create plan by SessionContext instead")] -/// WrenContextProvider is a ContextProvider implementation that uses the WrenMDL -/// to provide table sources and other metadata. -pub struct WrenContextProvider { - options: ConfigOptions, - tables: HashMap>, -} - -#[allow(deprecated)] -impl WrenContextProvider { - pub fn new(mdl: &WrenMDL) -> Result { - let mut tables = HashMap::new(); - // register model table - for model in mdl.manifest.models.iter() { - tables.insert( - format!("{}.{}.{}", mdl.catalog(), mdl.schema(), model.name()), - create_table_source(model)?, - ); - } - // register physical table - for (name, table) in mdl.register_tables.iter() { - tables.insert( - name.clone(), - Arc::new(DefaultTableSource::new(table.clone())), - ); - } - Ok(Self { - tables, - options: Default::default(), - }) - } - - pub fn new_bare(mdl: &WrenMDL) -> Result { - let mut tables = HashMap::new(); - for model in mdl.manifest.models.iter() { - tables.insert(model.name().to_string(), create_table_source(model)?); - } - Ok(Self { - tables, - options: Default::default(), - }) - } -} - -#[allow(deprecated)] -impl ContextProvider for WrenContextProvider { - fn get_table_source(&self, name: TableReference) -> Result> { - let table_name = name.to_string(); - match self.tables.get(&table_name) { - Some(table) => Ok(table.clone()), - _ => plan_err!("Table not found: {}", &table_name), - } - } - - fn get_function_meta(&self, _name: &str) -> Option> { - None - } - - fn get_aggregate_meta(&self, _name: &str) -> Option> { - None - } - - fn get_variable_type(&self, _variable_names: &[String]) -> Option { - None - } - - fn get_window_meta(&self, _name: &str) -> Option> { - None - } - - fn options(&self) -> &ConfigOptions { - &self.options - } - - fn udf_names(&self) -> Vec { - Vec::new() - } - - fn udaf_names(&self) -> Vec { - Vec::new() - } - - fn udwf_names(&self) -> Vec { - Vec::new() - } -} diff --git a/wren-core/core/src/logical_plan/mod.rs b/wren-core/core/src/logical_plan/mod.rs index fc8a26e7a..21cd4f4c6 100644 --- a/wren-core/core/src/logical_plan/mod.rs +++ b/wren-core/core/src/logical_plan/mod.rs @@ -1,4 +1,3 @@ pub mod analyze; -pub mod context_provider; pub mod optimize; pub mod utils; diff --git a/wren-core/core/src/logical_plan/utils.rs b/wren-core/core/src/logical_plan/utils.rs index c51f6480e..b309f83a4 100644 --- a/wren-core/core/src/logical_plan/utils.rs +++ b/wren-core/core/src/logical_plan/utils.rs @@ -186,11 +186,6 @@ pub fn map_data_type(data_type: &str) -> Result { Ok(result) } -pub fn create_table_source(model: &Model) -> Result> { - let schema = create_schema(model.get_physical_columns())?; - Ok(Arc::new(LogicalTableSource::new(schema))) -} - pub fn create_schema(columns: Vec>) -> Result { let fields: Vec = columns .iter() diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index 9ea841e25..e3041184b 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; +use crate::logical_plan::analyze::access_control::validate_clac_rule; use crate::logical_plan::analyze::expand_view::ExpandWrenViewRule; use crate::logical_plan::analyze::model_anlayze::ModelAnalyzeRule; use crate::logical_plan::analyze::model_generation::ModelGenerationRule; @@ -83,7 +84,7 @@ pub async fn create_ctx_with_mdl( new_state.with_analyzer_rules(analyze_rule_for_local_runtime( Arc::clone(&analyzed_mdl), reset_default_catalog_schema.clone(), - properties, + Arc::clone(&properties), )) // The plan will be executed locally, so apply the default optimizer rules } else { @@ -98,7 +99,7 @@ pub async fn create_ctx_with_mdl( let new_state = new_state.with_config(config).build(); let ctx = SessionContext::new_with_state(new_state); - register_table_with_mdl(&ctx, analyzed_mdl.wren_mdl()).await?; + register_table_with_mdl(&ctx, analyzed_mdl.wren_mdl(), properties).await?; Ok(ctx) } @@ -221,6 +222,7 @@ fn optimize_rule_for_unparsing() -> Vec> { pub async fn register_table_with_mdl( ctx: &SessionContext, wren_mdl: Arc, + properties: SessionPropertiesRef, ) -> Result<()> { let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); @@ -229,7 +231,7 @@ pub async fn register_table_with_mdl( ctx.register_catalog(&wren_mdl.manifest.catalog, Arc::new(catalog)); for model in wren_mdl.manifest.models.iter() { - let table = WrenDataSource::new(Arc::clone(model))?; + let table = WrenDataSource::new(Arc::clone(model), &properties)?; ctx.register_table( TableReference::full(wren_mdl.catalog(), wren_mdl.schema(), model.name()), Arc::new(table), @@ -252,8 +254,22 @@ pub struct WrenDataSource { } impl WrenDataSource { - pub fn new(model: Arc) -> Result { - let schema = create_schema(model.get_physical_columns().clone())?; + pub fn new(model: Arc, properties: &SessionPropertiesRef) -> Result { + let available_columns = model + .get_physical_columns() + .iter() + .map(|column| { + if validate_clac_rule(column, properties)? { + Ok(Some(Arc::clone(column))) + } else { + Ok(None) + } + }) + .collect::>>()? + .into_iter() + .flatten() + .collect::>(); + let schema = create_schema(available_columns)?; Ok(Self { schema }) } diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 8230ba220..2e60ec2b0 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -1,3 +1,4 @@ +use crate::logical_plan::analyze::access_control::validate_clac_rule; use crate::logical_plan::utils::{from_qualified_name_str, try_map_data_type}; use crate::mdl::builder::ManifestBuilder; use crate::mdl::context::{create_ctx_with_mdl, WrenDataSource}; @@ -8,6 +9,7 @@ use crate::mdl::function::{ use crate::mdl::manifest::{Column, Manifest, Metric, Model, View}; use crate::mdl::utils::to_field; use crate::DataFusionError; +use context::SessionPropertiesRef; use datafusion::arrow::datatypes::Field; use datafusion::common::internal_datafusion_err; use datafusion::datasource::TableProvider; @@ -68,8 +70,10 @@ impl Default for AnalyzedWrenMDL { } impl AnalyzedWrenMDL { - pub fn analyze(manifest: Manifest) -> Result { - let wren_mdl = Arc::new(WrenMDL::infer_and_register_remote_table(manifest)?); + pub fn analyze(manifest: Manifest, properties: SessionPropertiesRef) -> Result { + let wren_mdl = Arc::new(WrenMDL::infer_and_register_remote_table( + manifest, properties, + )?); let lineage = Arc::new(lineage::Lineage::new(&wren_mdl)?); Ok(AnalyzedWrenMDL { wren_mdl, lineage }) } @@ -177,23 +181,39 @@ impl WrenMDL { /// Create a WrenMDL from a manifest and register the table reference of the model as a remote table. /// All the column without expression will be considered a column - pub fn infer_and_register_remote_table(manifest: Manifest) -> Result { + pub fn infer_and_register_remote_table( + manifest: Manifest, + properties: SessionPropertiesRef, + ) -> Result { let mut mdl = WrenMDL::new(manifest); let sources: Vec<_> = mdl .models() .iter() .map(|model| { let name = TableReference::from(model.table_reference()); - let fields: Vec<_> = model + let available_columns = model .columns .iter() - .filter_map(|column| Self::infer_source_column(column).ok().flatten()) + .map(|column| { + if validate_clac_rule(column, &properties)? { + Ok(Some(Arc::clone(column))) + } else { + Ok(None) + } + }) + .collect::>>()?; + let fields: Vec<_> = available_columns + .into_iter() + .filter(|c| c.is_some()) + .filter_map(|column| { + Self::infer_source_column(&column.unwrap()).ok().flatten() + }) .collect(); let schema = Arc::new(datafusion::arrow::datatypes::Schema::new(fields)); let datasource = WrenDataSource::new_with_schema(schema); - (name.to_quoted_string(), Arc::new(datasource)) + Ok((name.to_quoted_string(), Arc::new(datasource))) }) - .collect(); + .collect::>>()?; sources .into_iter() .for_each(|(name, ds_ref)| mdl.register_table(name, ds_ref)); @@ -339,7 +359,7 @@ pub fn transform_sql( &SessionContext::new(), analyzed_mdl, remote_functions, - properties, + Arc::new(properties), sql, )) } @@ -351,7 +371,7 @@ pub async fn transform_sql_with_ctx( ctx: &SessionContext, analyzed_mdl: Arc, remote_functions: &[RemoteFunction], - properties: HashMap>, + properties: SessionPropertiesRef, sql: &str, ) -> Result { info!("wren-core received SQL: {}", sql); @@ -360,9 +380,8 @@ pub async fn transform_sql_with_ctx( register_remote_function(ctx, remote_function)?; Ok::<_, DataFusionError>(()) })?; - let properties_ref = Arc::new(properties); - let ctx = create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl), properties_ref, false) - .await?; + let ctx = + create_ctx_with_mdl(ctx, Arc::clone(&analyzed_mdl), properties, false).await?; let plan = ctx.state().create_logical_plan(sql).await?; debug!("wren-core original plan:\n {plan}"); let analyzed = ctx.state().optimize(&plan)?; @@ -457,7 +476,7 @@ mod test { use datafusion::sql::unparser::plan_to_sql; use insta::assert_snapshot; use wren_core_base::mdl::{ - DataSource, JoinType, RelationshipBuilder, SessionProperty, + ColumnLevelOperator, DataSource, JoinType, RelationshipBuilder, SessionProperty, }; #[test] @@ -471,7 +490,8 @@ mod test { Ok(mdl) => mdl, Err(e) => return not_impl_err!("Failed to parse mdl json: {}", e), }; - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); let _ = mdl::transform_sql( Arc::clone(&analyzed_mdl), &[], @@ -492,7 +512,8 @@ mod test { Ok(mdl) => mdl, Err(e) => return not_impl_err!("Failed to parse mdl json: {}", e), }; - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); let tests: Vec<&str> = vec![ "select o_orderkey + o_orderkey from test.test.orders", @@ -513,7 +534,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -535,14 +556,15 @@ mod test { Ok(mdl) => mdl, Err(e) => return not_impl_err!("Failed to parse mdl json: {}", e), }; - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); let sql = "select * from test.test.customer_view"; println!("Original: {}", sql); let _ = transform_sql_with_ctx( &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -564,13 +586,14 @@ mod test { Ok(mdl) => mdl, Err(e) => return not_impl_err!("Failed to parse mdl json: {}", e), }; - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); let sql = "select totalcost from profile"; let result = transform_sql_with_ctx( &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -581,7 +604,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -605,13 +628,16 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = r#"select * from "CTest"."STest"."Customer""#; let actual = mdl::transform_sql_with_ctx( &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -646,12 +672,15 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let actual = transform_sql_with_ctx( &ctx, Arc::clone(&analyzed_mdl), &functions, - HashMap::new(), + Arc::new(HashMap::new()), r#"select add_two("Custkey") from "Customer""#, ) .await?; @@ -662,7 +691,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &functions, - HashMap::new(), + Arc::new(HashMap::new()), r#"select median("Custkey") from "CTest"."STest"."Customer" group by "Name""#, ) .await?; @@ -716,13 +745,16 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = r#"select * from wren.test.artist"#; let actual = transform_sql_with_ctx( &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -738,7 +770,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -751,7 +783,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -784,13 +816,16 @@ mod test { ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = r#"select name_append from wren.test.artist"#; let _ = transform_sql_with_ctx( &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await @@ -806,7 +841,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await @@ -839,13 +874,16 @@ mod test { ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = r#"select "串接名字" from wren.test.artist"#; let actual = transform_sql_with_ctx( &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -856,7 +894,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -868,7 +906,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await.map_err(|e| { @@ -887,7 +925,7 @@ mod test { &SessionContext::new(), Arc::new(AnalyzedWrenMDL::default()), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -911,13 +949,16 @@ mod test { ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = r#"select current_date > "出道時間" from wren.test.artist"#; let actual = transform_sql_with_ctx( &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -937,7 +978,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -977,7 +1018,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -988,7 +1029,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -999,7 +1040,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1014,19 +1055,37 @@ mod test { async fn test_unnest_as_table_factor() -> Result<()> { let ctx = SessionContext::new(); let manifest = ManifestBuilder::new().build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "select * from unnest([1, 2, 3])"; - let actual = - transform_sql_with_ctx(&ctx, analyzed_mdl, &[], HashMap::new(), sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + analyzed_mdl, + &[], + Arc::new(HashMap::new()), + sql, + ) + .await?; assert_snapshot!(actual, @"SELECT \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\" FROM (SELECT UNNEST([1, 2, 3]) AS \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\")"); let manifest = ManifestBuilder::new() .data_source(DataSource::BigQuery) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "select * from unnest([1, 2, 3])"; - let actual = - transform_sql_with_ctx(&ctx, analyzed_mdl, &[], HashMap::new(), sql).await?; + let actual = transform_sql_with_ctx( + &ctx, + analyzed_mdl, + &[], + Arc::new(HashMap::new()), + sql, + ) + .await?; assert_snapshot!(actual, @"SELECT \"UNNEST(make_array(Int64(1),Int64(2),Int64(3)))\" FROM UNNEST([1, 2, 3])"); Ok(()) } @@ -1040,7 +1099,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1051,7 +1110,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1071,7 +1130,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1083,7 +1142,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1101,7 +1160,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1113,7 +1172,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1145,13 +1204,16 @@ mod test { ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = r#"select count(*) from wren.test.artist where cast(cast_timestamptz as timestamp) > timestamp '2011-01-01 21:00:00'"#; let actual = transform_sql_with_ctx( &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1233,13 +1295,16 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = r#"select timestamp_col = timestamptz_col from wren.test.timestamp_table"#; let actual = transform_sql_with_ctx( &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1254,7 +1319,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1268,7 +1333,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1283,7 +1348,7 @@ mod test { &SessionContext::new(), Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1309,13 +1374,16 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "select list_col[1] from wren.test.list_table"; let actual = transform_sql_with_ctx( &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1350,13 +1418,16 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "select struct_col.float_field from wren.test.struct_table"; let actual = transform_sql_with_ctx( &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1372,7 +1443,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1386,7 +1457,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1402,9 +1473,12 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "select struct_col.float_field from wren.test.struct_table"; - let _ = transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) + let _ = transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), sql) .await .map_err(|e| { assert_snapshot!( @@ -1425,7 +1499,7 @@ mod test { &ctx, Arc::new(AnalyzedWrenMDL::default()), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1445,7 +1519,7 @@ mod test { &ctx, Arc::new(AnalyzedWrenMDL::default()), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1462,12 +1536,20 @@ mod test { #[tokio::test] async fn test_dialect_specific_function_rewrite() -> Result<()> { let manifest = ManifestBuilder::default().data_source(MySQL).build(); - let mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let ctx = SessionContext::new(); let sql = "SELECT trim(' abc')"; - let actual = - transform_sql_with_ctx(&ctx, Arc::clone(&mdl), &[], HashMap::new(), sql) - .await?; + let actual = transform_sql_with_ctx( + &ctx, + Arc::clone(&mdl), + &[], + Arc::new(HashMap::new()), + sql, + ) + .await?; assert_snapshot!(actual, @"SELECT trim(' abc')"); Ok(()) } @@ -1487,12 +1569,15 @@ mod test { ) .build(); let sql = r#"SELECT c_custkey, count(distinct c_name) FROM customer GROUP BY c_custkey"#; - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let result = transform_sql_with_ctx( &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1521,12 +1606,15 @@ mod test { ) .build(); let sql = r#"SELECT c_custkey, (SELECT c_name FROM customer WHERE c_custkey = 1) FROM customer"#; - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let result = transform_sql_with_ctx( &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1554,12 +1642,15 @@ mod test { ) .build(); let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let result = transform_sql_with_ctx( &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1603,12 +1694,15 @@ mod test { let manifest: Manifest = serde_json::from_str(mdl_json).unwrap(); let ctx = SessionContext::new(); let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let result = transform_sql_with_ctx( &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1652,12 +1746,15 @@ mod test { let manifest: Manifest = serde_json::from_str(mdl_json).unwrap(); let ctx = SessionContext::new(); let sql = r#"SELECT * FROM customer WHERE c_custkey = 1"#; - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let result = transform_sql_with_ctx( &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await?; @@ -1691,12 +1788,15 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "SELECT * FROM customer"; let headers = build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); assert_snapshot!( - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::new(headers), sql).await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1" ); @@ -1704,7 +1804,7 @@ mod test { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + Arc::new(HashMap::new()), sql, ) .await @@ -1715,7 +1815,7 @@ mod test { @r" ModelAnalyzeRule caused by - Error during planning: Row level access control property session_nation is required, but not found in headers + Error during planning: session property session_nation is required, but not found in headers " ) } @@ -1751,12 +1851,15 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "SELECT * FROM customer"; - let headers = build_headers(&[ + let headers = Arc::new(build_headers(&[ ("session_nation".to_string(), Some("1".to_string())), ("session_user".to_string(), Some("'Gura'".to_string())), - ]); + ])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers.clone(), sql,).await?, @"SELECT customer.c_custkey, customer.c_nationkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura'" @@ -1771,23 +1874,25 @@ mod test { // test other model won't be affected let sql = "SELECT o_orderkey FROM orders"; assert_snapshot!( - transform_sql_with_ctx(&ctx,Arc::clone(&analyzed_mdl),&[],HashMap::new(),sql).await?, + transform_sql_with_ctx(&ctx,Arc::clone(&analyzed_mdl),&[],Arc::new(HashMap::new()),sql).await?, @"SELECT orders.o_orderkey FROM (SELECT orders.o_orderkey FROM (SELECT __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders) AS orders" ); let sql = "SELECT o_orderkey FROM customer JOIN orders ON customer.c_custkey = orders.o_custkey"; - let headers = build_headers(&[ + let headers = Arc::new(build_headers(&[ ("session_nation".to_string(), Some("1".to_string())), ("session_user".to_string(), Some("'Gura'".to_string())), - ]); + ])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT orders.o_orderkey FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer JOIN (SELECT orders.o_custkey, orders.o_orderkey FROM (SELECT __source.o_custkey AS o_custkey, __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders) AS orders ON customer.c_custkey = orders.o_custkey WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura'" ); // test property is required - let headers = - build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + let headers = Arc::new(build_headers(&[( + "session_nation".to_string(), + Some("1".to_string()), + )])); let sql = "SELECT * FROM customer"; match transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await @@ -1798,7 +1903,7 @@ mod test { @r" ModelAnalyzeRule caused by - Error during planning: Row level access control property session_user is required, but not found in headers + Error during planning: session property session_user is required, but not found in headers " ) } @@ -1824,28 +1929,35 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "SELECT * FROM customer"; - let headers = build_headers(&[ + let headers = Arc::new(build_headers(&[ ("session_nation".to_string(), Some("1".to_string())), ("session_user".to_string(), Some("'Peko'".to_string())), - ]); + ])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Peko'" ); // expect ignore the rule because session_user is optional without default value - let headers = - build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + let headers = Arc::new(build_headers(&[( + "session_nation".to_string(), + Some("1".to_string()), + )])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" ); // expect error because session_user is required - let headers = - build_headers(&[("session_user".to_string(), Some("'Peko'".to_string()))]); + let headers = Arc::new(build_headers(&[( + "session_user".to_string(), + Some("'Peko'".to_string()), + )])); match transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await { @@ -1855,7 +1967,7 @@ mod test { @r" ModelAnalyzeRule caused by - Error during planning: Row level access control property session_nation is required, but not found in headers + Error during planning: session property session_nation is required, but not found in headers " ) } @@ -1888,17 +2000,22 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "SELECT * FROM customer"; - let headers = - build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + let headers = Arc::new(build_headers(&[( + "session_nation".to_string(), + Some("1".to_string()), + )])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1" ); assert_snapshot!( - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), sql) .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 3" ); @@ -1919,16 +2036,21 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let headers = - build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); + let headers = Arc::new(build_headers(&[( + "session_nation".to_string(), + Some("1".to_string()), + )])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1" ); assert_snapshot!( - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), sql) .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" ); @@ -1955,24 +2077,31 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let headers = - build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); + let headers = Arc::new(build_headers(&[( + "session_nation".to_string(), + Some("1".to_string()), + )])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura'" ); // the rule is expected to be skipped because the optional property is None without default value - let headers = - build_headers(&[("session_user".to_string(), Some("'Peko'".to_string()))]); + let headers = Arc::new(build_headers(&[( + "session_user".to_string(), + Some("'Peko'".to_string()), + )])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" ); assert_snapshot!( - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), sql) .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" ); @@ -2003,24 +2132,31 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let headers = - build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); + let headers = Arc::new(build_headers(&[( + "session_nation".to_string(), + Some("1".to_string()), + )])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1 AND customer.c_name = 'Gura'" ); // the rule is expected to be skipped because the optional property is None without default value - let headers = - build_headers(&[("session_user".to_string(), Some("'Peko'".to_string()))]); + let headers = Arc::new(build_headers(&[( + "session_user".to_string(), + Some("'Peko'".to_string()), + )])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" ); assert_snapshot!( - transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], HashMap::new(), sql) + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::new()), sql) .await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer" ); @@ -2076,9 +2212,14 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let headers = - build_headers(&[("session_user".to_string(), Some("'Gura'".to_string()))]); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); + let headers = Arc::new(build_headers(&[( + "session_user".to_string(), + Some("'Gura'".to_string()), + )])); let sql = "SELECT * FROM orders"; assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers.clone(), sql).await?, @@ -2091,9 +2232,6 @@ mod test { @"SELECT orders.o_orderkey, orders.o_custkey, orders.customer_name FROM (SELECT __relation__1.c_name AS customer_name, __relation__1.o_custkey, __relation__1.o_orderkey FROM (SELECT customer.c_custkey, customer.c_name, orders.o_custkey, orders.o_orderkey FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer) AS customer RIGHT JOIN (SELECT __source.o_custkey AS o_custkey, __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders ON customer.c_custkey = orders.o_custkey) AS __relation__1) AS orders WHERE orders.o_orderkey > 10 AND orders.customer_name = 'Gura'" ); - // TODO: the rlac rule should be applied for the model used by the calculated field - // both to_one or to_many relationship should be supported - // let manifest = ManifestBuilder::new() .catalog("wren") .schema("test") @@ -2160,17 +2298,24 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let headers = - build_headers(&[("session_nation".to_string(), Some("1".to_string()))]); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); + let headers = Arc::new(build_headers(&[( + "session_nation".to_string(), + Some("1".to_string()), + )])); let sql = "SELECT customer_name FROM orders"; // test custoer model used by customer_name should be filtered by nation rule. assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT orders.customer_name FROM (SELECT __relation__1.c_name AS customer_name FROM (SELECT customer.c_custkey, customer.c_name, orders.o_custkey, orders.o_orderkey FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name, customer.c_nationkey FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1) AS customer RIGHT JOIN (SELECT __source.o_custkey AS o_custkey, __source.o_orderkey AS o_orderkey FROM orders AS __source) AS orders ON customer.c_custkey = orders.o_custkey) AS __relation__1) AS orders" ); - let headers = - build_headers(&[("session_user".to_string(), Some("1".to_string()))]); + let headers = Arc::new(build_headers(&[( + "session_user".to_string(), + Some("1".to_string()), + )])); let sql = "SELECT totalprice FROM customer"; // test orders model used by totalprice should be filtered by user rule. assert_snapshot!( @@ -2180,6 +2325,231 @@ mod test { Ok(()) } + #[tokio::test] + async fn test_clac_with_required_properties() -> Result<()> { + let ctx = SessionContext::new(); + + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column( + ColumnBuilder::new("c_name", "string") + .column_level_access_control( + "cls rule", + vec![SessionProperty::new_required("session_level")], + ColumnLevelOperator::Equals, + "1", + ) + .build(), + ) + .build(), + ) + .build(); + let headers = Arc::new(build_headers(&[( + "session_level".to_string(), + Some("1".to_string()), + )])); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let sql = "SELECT * FROM customer"; + + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer" + ); + + let headers = Arc::new(build_headers(&[( + "session_level".to_string(), + Some("0".to_string()), + )])); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer) AS customer" + ); + + let headers = Arc::new(HashMap::default()); + match AnalyzedWrenMDL::analyze(manifest, headers.clone()) { + Err(e) => { + assert_snapshot!( + e.to_string(), + @"Error during planning: session property session_level is required, but not found in headers" + ) + } + _ => panic!("Expected error"), + } + + Ok(()) + } + + #[tokio::test] + async fn test_clac_with_optional_properties() -> Result<()> { + let ctx = SessionContext::new(); + + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column( + ColumnBuilder::new("c_name", "string") + .column_level_access_control( + "cls rule", + vec![SessionProperty::new_optional( + "session_level", + Some("2".to_string()), + )], + ColumnLevelOperator::Equals, + "1", + ) + .build(), + ) + .build(), + ) + .build(); + let headers = Arc::new(build_headers(&[( + "session_level".to_string(), + Some("1".to_string()), + )])); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let sql = "SELECT * FROM customer"; + + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer" + ); + + let headers = Arc::new(build_headers(&[( + "session_level".to_string(), + Some("0".to_string()), + )])); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer) AS customer" + ); + + // test the rule is applied the default value if the optional property is None + let headers = Arc::new(HashMap::default()); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_custkey FROM (SELECT customer.c_custkey FROM (SELECT __source.c_custkey AS c_custkey FROM customer AS __source) AS customer) AS customer" + ); + + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column( + ColumnBuilder::new("c_name", "string") + .column_level_access_control( + "cls rule", + vec![SessionProperty::new_optional( + "session_level", + None, + )], + ColumnLevelOperator::Equals, + "1", + ) + .build(), + ) + .build(), + ) + .build(); + let sql = "SELECT * FROM customer"; + + // test the rule is skipped when the optional property is None + let headers = Arc::new(HashMap::default()); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer" + ); + Ok(()) + } + + #[tokio::test] + async fn test_clac_on_calculated_field() -> Result<()> { + let ctx = SessionContext::new(); + + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("customer") + .table_reference("customer") + .column(ColumnBuilder::new("c_custkey", "int").build()) + .column( + ColumnBuilder::new("c_name", "string") + .column_level_access_control( + "cls rule", + vec![SessionProperty::new_required("session_level")], + ColumnLevelOperator::Equals, + "1", + ) + .build(), + ) + .column( + ColumnBuilder::new_calculated("c_name_upper", "string") + .expression("upper(c_name)") + .build(), + ) + .build(), + ) + .build(); + let headers = Arc::new(build_headers(&[( + "session_level".to_string(), + Some("1".to_string()), + )])); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + let sql = "SELECT c_name_upper FROM customer"; + + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, + @"SELECT customer.c_name_upper FROM (SELECT upper(customer.c_name) AS c_name_upper FROM (SELECT __source.c_name AS c_name FROM customer AS __source) AS customer) AS customer" + ); + + let headers = Arc::new(build_headers(&[( + "session_level".to_string(), + Some("0".to_string()), + )])); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(manifest.clone(), headers.clone())?); + + match transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql) + .await + { + Err(e) => { + assert_snapshot!( + e.to_string(), + @r" + ModelAnalyzeRule + caused by + Schema error: No field named c_name. Valid fields are customer.c_custkey. + " + ) + } + _ => panic!("Expected error"), + } + Ok(()) + } + #[tokio::test] async fn test_rlac_case_insensitive() -> Result<()> { let ctx = SessionContext::new(); @@ -2201,10 +2571,15 @@ mod test { .build(), ) .build(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "SELECT * FROM customer"; - let headers = - build_headers(&[("SESSION_NATION".to_string(), Some("1".to_string()))]); + let headers = Arc::new(build_headers(&[( + "SESSION_NATION".to_string(), + Some("1".to_string()), + )])); assert_snapshot!( transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], headers, sql).await?, @"SELECT customer.c_nationkey, customer.c_name FROM (SELECT customer.c_name, customer.c_nationkey FROM (SELECT __source.c_name AS c_name, __source.c_nationkey AS c_nationkey FROM customer AS __source) AS customer) AS customer WHERE customer.c_nationkey = 1" diff --git a/wren-core/core/src/mdl/utils.rs b/wren-core/core/src/mdl/utils.rs index 560eeca18..6b762e606 100644 --- a/wren-core/core/src/mdl/utils.rs +++ b/wren-core/core/src/mdl/utils.rs @@ -43,7 +43,7 @@ pub fn collect_identifiers(expr: &str) -> Result> { let statement = parsed[0].clone(); let mut visited: BTreeSet = BTreeSet::new(); - visit_expressions(&statement, |expr| { + let _ = visit_expressions(&statement, |expr| { match expr { Identifier(id) => { visited.insert(Column::from(quoted(&id.value))); @@ -114,7 +114,7 @@ pub fn create_wren_calculated_field_expr( &expr, session_state.config_options().sql_parser.dialect.as_str(), )?; - visit_expressions_mut(&mut expr, |e| { + let _ = visit_expressions_mut(&mut expr, |e| { if let CompoundIdentifier(ids) = e { let name_size = ids.len(); if name_size > 2 { @@ -181,7 +181,7 @@ fn qualified_expr( expr, session_state.config_options().sql_parser.dialect.as_str(), )?; - visit_expressions_mut(&mut expr, |e| { + let _ = visit_expressions_mut(&mut expr, |e| { if let Identifier(id) = e { if let Ok((Some(qualifier), _)) = schema.qualified_field_with_unqualified_name(&id.value) @@ -245,7 +245,7 @@ pub fn to_remote_field( fn collect_columns(expr: datafusion::logical_expr::sqlparser::ast::Expr) -> Vec { let mut visited = vec![]; - visit_expressions(&expr, |e| { + let _ = visit_expressions(&expr, |e| { if let CompoundIdentifier(ids) = e { ids.iter().cloned().for_each(|id| visited.push(id)); } else if let Identifier(id) = e { @@ -258,6 +258,7 @@ fn collect_columns(expr: datafusion::logical_expr::sqlparser::ast::Expr) -> Vec< #[cfg(test)] mod tests { + use std::collections::HashMap; use std::fs; use std::path::PathBuf; use std::sync::Arc; @@ -277,7 +278,8 @@ mod tests { .collect(); let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); let ctx = SessionContext::new(); let column_rf = analyzed_mdl .wren_mdl @@ -305,7 +307,8 @@ mod tests { .collect(); let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); let ctx = SessionContext::new(); let column_rf = analyzed_mdl .wren_mdl @@ -354,7 +357,8 @@ mod tests { .collect(); let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); let ctx = SessionContext::new(); let model = analyzed_mdl.wren_mdl().get_model("customer").unwrap(); let expr = super::create_wren_expr_for_model( @@ -374,7 +378,8 @@ mod tests { .collect(); let mdl_json = fs::read_to_string(test_data.as_path()).unwrap(); let mdl = serde_json::from_str::(&mdl_json).unwrap(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(mdl)?); + let analyzed_mdl = + Arc::new(AnalyzedWrenMDL::analyze(mdl, Arc::new(HashMap::default()))?); let ctx = SessionContext::new(); let model = analyzed_mdl.wren_mdl().get_model("customer").unwrap(); let expr = super::create_remote_expr_for_model( diff --git a/wren-core/wren-example/examples/calculation-invoke-calculation.rs b/wren-core/wren-example/examples/calculation-invoke-calculation.rs index f593fe75b..da12a7c23 100644 --- a/wren-core/wren-example/examples/calculation-invoke-calculation.rs +++ b/wren-core/wren-example/examples/calculation-invoke-calculation.rs @@ -81,7 +81,7 @@ async fn main() -> Result<()> { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + HashMap::new().into(), "select totalprice from wrenai.public.customers", ) .await @@ -107,7 +107,7 @@ async fn main() -> Result<()> { &ctx, Arc::clone(&analyzed_mdl), &[], - HashMap::new(), + HashMap::new().into(), "select customer_state_cf from wrenai.public.order_items", ) .await diff --git a/wren-core/wren-example/examples/datafusion-apply.rs b/wren-core/wren-example/examples/datafusion-apply.rs index c4fedc1b5..9f95833a3 100644 --- a/wren-core/wren-example/examples/datafusion-apply.rs +++ b/wren-core/wren-example/examples/datafusion-apply.rs @@ -78,8 +78,8 @@ async fn main() -> Result<()> { // TODO: there're some issue for optimize rules // let ctx = create_ctx_with_mdl(&ctx, analyzed_mdl).await?; let sql = "select * from wrenai.public.order_items"; - let sql = - transform_sql_with_ctx(&ctx, analyzed_mdl, &[], HashMap::new(), sql).await?; + let sql = transform_sql_with_ctx(&ctx, analyzed_mdl, &[], HashMap::new().into(), sql) + .await?; println!("Wren engine generated SQL: \n{}", sql); // create a plan to run a SQL query let df = match ctx.sql(&sql).await { diff --git a/wren-core/wren-example/examples/plan-sql.rs b/wren-core/wren-example/examples/plan-sql.rs index f11f80b34..71bf9a8a9 100644 --- a/wren-core/wren-example/examples/plan-sql.rs +++ b/wren-core/wren-example/examples/plan-sql.rs @@ -10,7 +10,10 @@ use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; #[tokio::main] async fn main() -> datafusion::common::Result<()> { let manifest = init_manifest(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "select customer_state from wrenai.public.orders_model"; println!("Original SQL: \n{}", sql); @@ -18,7 +21,7 @@ async fn main() -> datafusion::common::Result<()> { &SessionContext::new(), analyzed_mdl, &[], - HashMap::new(), + HashMap::new().into(), sql, ) .await?; diff --git a/wren-core/wren-example/examples/row-level-access-control.rs b/wren-core/wren-example/examples/row-level-access-control.rs index 523bf21b1..32d23f5e9 100644 --- a/wren-core/wren-example/examples/row-level-access-control.rs +++ b/wren-core/wren-example/examples/row-level-access-control.rs @@ -111,7 +111,8 @@ async fn main() -> datafusion::common::Result<()> { ); let sql = "select * from wren.test.documents"; - let sql = transform_sql_with_ctx(&ctx, analyzed_mdl, &[], properties, sql).await?; + let sql = + transform_sql_with_ctx(&ctx, analyzed_mdl, &[], properties.into(), sql).await?; let df = match ctx.sql(&sql).await { Ok(df) => df, Err(e) => { diff --git a/wren-core/wren-example/examples/view.rs b/wren-core/wren-example/examples/view.rs index 7b763bde7..bb649cfb9 100644 --- a/wren-core/wren-example/examples/view.rs +++ b/wren-core/wren-example/examples/view.rs @@ -12,7 +12,10 @@ use wren_core::mdl::{transform_sql_with_ctx, AnalyzedWrenMDL}; #[tokio::main] async fn main() -> datafusion::common::Result<()> { let manifest = init_manifest(); - let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + )?); let sql = "select * from wrenai.public.customers_view"; println!("Original SQL: \n{}", sql); @@ -20,7 +23,7 @@ async fn main() -> datafusion::common::Result<()> { &SessionContext::new(), analyzed_mdl, &[], - HashMap::new(), + HashMap::new().into(), sql, ) .await?;