diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index 47b3a4d1e..af4b46964 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -357,7 +357,7 @@ class GcsFileConnectionInfo(BaseConnectionInfo): class ValidateDTO(BaseModel): manifest_str: str = manifest_str_field - parameters: dict[str, str] + parameters: dict connection_info: ConnectionInfo = connection_info_field diff --git a/ibis-server/app/model/validator.py b/ibis-server/app/model/validator.py index 5a3c4e931..30ecbb228 100644 --- a/ibis-server/app/model/validator.py +++ b/ibis-server/app/model/validator.py @@ -1,11 +1,18 @@ from __future__ import annotations +from wren_core import ( + RowLevelAccessControl, + SessionProperty, + to_manifest, + validate_rlac_rule, +) + from app.mdl.rewriter import Rewriter from app.model import NotFoundError, UnprocessableEntityError from app.model.connector import Connector from app.util import base64_to_dict -rules = ["column_is_valid", "relationship_is_valid"] +rules = ["column_is_valid", "relationship_is_valid", "rlac_condition_syntax_is_valid"] class Validator: @@ -13,7 +20,7 @@ def __init__(self, connector: Connector, rewriter: Rewriter): self.connector = connector self.rewriter = rewriter - async def validate(self, rule: str, parameters: dict[str, str], manifest_str: str): + async def validate(self, rule: str, parameters: dict, manifest_str: str): if rule not in rules: raise RuleNotFoundError(rule) try: @@ -144,6 +151,45 @@ def format_result(result): except Exception as e: raise ValidationError(f"Exception: {type(e)}, message: {e!s}") + async def _validate_rlac_condition_syntax_is_valid( + self, parameters: dict, manifest_str: str + ): + if parameters.get("modelName") is None: + raise MissingRequiredParameterError("modelName") + if parameters.get("requiredProperties") is None: + raise MissingRequiredParameterError("requiredProperties") + if parameters.get("condition") is None: + raise MissingRequiredParameterError("condition") + + model_name = parameters.get("modelName") + required_properties = parameters.get("requiredProperties") + condition = parameters.get("condition") + + required_properties = [ + SessionProperty( + name=prop["name"], + required=bool(prop["required"]), + default_expr=prop.get("defaultExpr", None), + ) + for prop in required_properties + ] + + rlac = RowLevelAccessControl( + name="rlac_validation", + required_properties=required_properties, + condition=condition, + ) + + manifest = to_manifest(manifest_str) + model = manifest.get_model(model_name) + if model is None: + raise ValueError(f"Model {model_name} not found in manifest") + + try: + validate_rlac_rule(rlac, model) + except Exception as e: + raise ValidationError(e) + def _get_model(self, manifest, model_name): models = list(filter(lambda m: m["name"] == model_name, manifest["models"])) if len(models) == 0: diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_validate.py b/ibis-server/tests/routers/v3/connector/postgres/test_validate.py index 126095a7c..8cd426cc7 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_validate.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_validate.py @@ -119,3 +119,60 @@ async def test_validate_rule_column_is_valid_without_one_parameter( ) assert response.status_code == 422 assert response.text == "Missing required parameter: `modelName`" + + +async def test_validate_rlac_condition_syntax_is_valid( + client, manifest_str, connection_info +): + response = await client.post( + url=f"{base_url}/validate/rlac_condition_syntax_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": { + "modelName": "orders", + "requiredProperties": [ + {"name": "session_order", "required": "false"}, + ], + "condition": "@session_order = o_orderkey", + }, + }, + ) + assert response.status_code == 204 + + response = await client.post( + url=f"{base_url}/validate/rlac_condition_syntax_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": { + "modelName": "orders", + "requiredProperties": [ + {"name": "session_order", "required": False}, + ], + "condition": "@session_order = o_orderkey", + }, + }, + ) + assert response.status_code == 204 + + response = await client.post( + url=f"{base_url}/validate/rlac_condition_syntax_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": { + "modelName": "orders", + "requiredProperties": [ + {"name": "session_order", "required": "false"}, + ], + "condition": "@session_not_found = o_orderkey", + }, + }, + ) + + assert response.status_code == 422 + assert ( + response.text + == "Error during planning: The session property @session_not_found is used, but not found in the session properties" + ) diff --git a/wren-core-base/src/mdl/py_method.rs b/wren-core-base/src/mdl/py_method.rs index ef921aaee..a5e716fa3 100644 --- a/wren-core-base/src/mdl/py_method.rs +++ b/wren-core-base/src/mdl/py_method.rs @@ -19,7 +19,7 @@ #[cfg(feature = "python-binding")] mod manifest_python_impl { - use crate::mdl::manifest::{Manifest, Model}; + use crate::mdl::manifest::{Manifest, Model, RowLevelAccessControl, SessionProperty}; use crate::mdl::DataSource; use pyo3::{pymethods, PyResult}; use std::sync::Arc; @@ -49,6 +49,16 @@ mod manifest_python_impl { fn data_source(&self) -> PyResult> { Ok(self.data_source) } + + fn get_model(&self, name: &str) -> PyResult> { + let model = self + .models + .iter() + .find(|m| m.name == name) + .cloned() + .map(Arc::unwrap_or_clone); + Ok(model) + } } #[pymethods] @@ -58,4 +68,30 @@ mod manifest_python_impl { Ok(self.name.clone()) } } + + #[pymethods] + impl SessionProperty { + #[new] + #[pyo3(signature = (name, required = false, default_expr = None))] + fn new(name: String, required: bool, default_expr: Option) -> Self { + Self { + name, + required, + default_expr, + } + } + } + + #[pymethods] + impl RowLevelAccessControl { + #[new] + #[pyo3(signature = (name, condition, required_properties = vec![]))] + fn new(name: String, condition: String, required_properties: Vec) -> Self { + Self { + name, + condition, + required_properties, + } + } + } } diff --git a/wren-core-py/src/errors.rs b/wren-core-py/src/errors.rs index f5732c276..bd65d0360 100644 --- a/wren-core-py/src/errors.rs +++ b/wren-core-py/src/errors.rs @@ -51,7 +51,7 @@ impl From for CoreError { impl From for CoreError { fn from(err: wren_core::DataFusionError) -> Self { - CoreError::new(&format!("DataFusion error: {}", err)) + CoreError::new(err.to_string().as_str()) } } diff --git a/wren-core-py/src/extractor.rs b/wren-core-py/src/extractor.rs index ca66071ea..469b51bbd 100644 --- a/wren-core-py/src/extractor.rs +++ b/wren-core-py/src/extractor.rs @@ -56,8 +56,8 @@ fn resolve_used_table_names(mdl: &WrenMDL, sql: &str) -> Result, Cor tables .iter() .filter(|t| { - t.catalog().map_or(true, |catalog| catalog == mdl.catalog()) - && t.schema().map_or(true, |schema| schema == mdl.schema()) + t.catalog().is_none_or(|catalog| catalog == mdl.catalog()) + && t.schema().is_none_or(|schema| schema == mdl.schema()) }) .map(|t| t.table().to_string()) .collect() diff --git a/wren-core-py/src/lib.rs b/wren-core-py/src/lib.rs index af7f1fdf1..65687692e 100644 --- a/wren-core-py/src/lib.rs +++ b/wren-core-py/src/lib.rs @@ -7,6 +7,7 @@ mod errors; mod extractor; mod manifest; pub mod remote_functions; +mod validation; #[pymodule] #[pyo3(name = "wren_core")] @@ -15,7 +16,12 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(manifest::to_json_base64, m)?)?; + m.add_function(wrap_pyfunction!(manifest::to_manifest, m)?)?; + m.add_function(wrap_pyfunction!(validation::validate_rlac_rule, m)?)?; Ok(()) } diff --git a/wren-core-py/src/manifest.rs b/wren-core-py/src/manifest.rs index b1fe2e975..771f3d3ab 100644 --- a/wren-core-py/src/manifest.rs +++ b/wren-core-py/src/manifest.rs @@ -13,6 +13,7 @@ pub fn to_json_base64(mdl: Manifest) -> Result { Ok(mdl_base64) } +#[pyfunction] /// Convert a base64 encoded JSON string to a manifest object. pub fn to_manifest(mdl_base64: &str) -> Result { let decoded_bytes = BASE64_STANDARD.decode(mdl_base64)?; diff --git a/wren-core-py/src/validation.rs b/wren-core-py/src/validation.rs new file mode 100644 index 000000000..0a3d59185 --- /dev/null +++ b/wren-core-py/src/validation.rs @@ -0,0 +1,13 @@ +use pyo3::pyfunction; +use wren_core_base::mdl::{Model, RowLevelAccessControl}; + +use crate::errors::CoreError; + +#[pyfunction] +pub fn validate_rlac_rule( + rule: &RowLevelAccessControl, + model: &Model, +) -> Result<(), CoreError> { + wren_core::logical_plan::analyze::access_control::validate_rlac_rule(rule, model)?; + Ok(()) +} diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index c6f116367..1d7aeea9e 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -5,8 +5,12 @@ import pytest from wren_core import ( ManifestExtractor, + RowLevelAccessControl, SessionContext, + SessionProperty, to_json_base64, + to_manifest, + validate_rlac_rule, ) manifest = { @@ -298,3 +302,35 @@ def test_rlac(): rewritten_sql == "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM main.customer AS __source) AS customer) AS customer WHERE customer.c_name = 'test_user'" ) + + +def test_validate_rlac_rule(): + manifest = to_manifest(manifest_str) + model = manifest.get_model("customer") + if model is None: + raise ValueError("Model customer not found in manifest") + rlac = RowLevelAccessControl( + name="test", + required_properties=[ + SessionProperty( + name="session_user", + required=False, + ) + ], + condition="c_name = @session_user", + ) + + validate_rlac_rule(rlac, model) + + rlac = RowLevelAccessControl( + name="test", + required_properties=[], + condition="c_name = @session_user", + ) + + with pytest.raises(Exception) as e: + validate_rlac_rule(rlac, model) + assert ( + str(e.value) + == "Exception: DataFusion error: Error during planning: The session property @session_user is used, but not found in the session properties" + ) diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs index 64ecb13dd..16198924c 100644 --- a/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -75,6 +75,40 @@ pub fn collect_condition( )) } +/// Validate the definition of row level access control rules. +/// Check if the syntax of the condition is valid. +/// Check if the properties used in the condition are defined in the session properties. +#[allow(dead_code)] +pub fn validate_rlac_rule(rule: &RowLevelAccessControl, model: &Model) -> Result<()> { + let RowLevelAccessControl { + condition, + required_properties, + .. + } = rule; + let (_, session_properties) = collect_condition(model, condition)?; + + let required_properties: Vec<_> = required_properties + .iter() + .map(|property| property.name.to_lowercase()) + .collect(); + + let missed_properties: Vec<_> = session_properties + .iter() + .filter(|property| !required_properties.contains(property)) + .collect(); + if !missed_properties.is_empty() { + return plan_err!( + "The session property {} is used, but not found in the session properties", + missed_properties + .iter() + .map(|property| format!("@{}", property)) + .collect::>() + .join(", ") + ); + } + Ok(()) +} + /// Build the filter expression for the row level access control rule. pub fn build_filter_expression( session_state: &SessionStateRef, @@ -261,7 +295,7 @@ mod test { collect_condition, validate_rule, }; - use super::build_filter_expression; + use super::{build_filter_expression, validate_rlac_rule}; #[test] pub fn test_collect_condition() -> Result<()> { @@ -652,4 +686,72 @@ mod test { } Ok(()) } + + #[test] + pub fn test_validate_rlac_rule() -> Result<()> { + let model = ModelBuilder::new("m1") + .column(ColumnBuilder::new("id", "int").build()) + .column(ColumnBuilder::new("name", "varchar").build()) + .build(); + + let rule = RowLevelAccessControl { + condition: "id = @session_id".to_string(), + required_properties: vec![SessionProperty::new_required("SESSION_ID")], + name: "test".to_string(), + }; + + validate_rlac_rule(&rule, &model)?; + + let rule = RowLevelAccessControl { + condition: "id = @session_id AND name = @session_name".to_string(), + required_properties: vec![ + SessionProperty::new_required("SESSION_ID"), + SessionProperty::new_required("SESSION_NAME"), + ], + name: "test".to_string(), + }; + + validate_rlac_rule(&rule, &model)?; + + let rule = RowLevelAccessControl { + condition: "id = @session_id AND name = @session_name".to_string(), + required_properties: vec![SessionProperty::new_required("SESSION_ID")], + name: "test".to_string(), + }; + + match validate_rlac_rule(&rule, &model) { + Err(error) => { + assert_snapshot!(error.message(), @"The session property @session_name is used, but not found in the session properties"); + } + _ => panic!("should be error"), + } + + let rule = RowLevelAccessControl { + condition: ",invalid".to_string(), + required_properties: vec![], + name: "test".to_string(), + }; + + match validate_rlac_rule(&rule, &model) { + Err(error) => { + assert_snapshot!(error.message(), @r#"ParserError("Expected: an expression, found: , at Line: 1, Column: 1")"#); + } + _ => panic!("should be error"), + } + + let rule = RowLevelAccessControl { + condition: "not_found = @SESSION_ID".to_string(), + required_properties: vec![SessionProperty::new_required("SESSION_ID")], + name: "test".to_string(), + }; + + match validate_rlac_rule(&rule, &model) { + Err(error) => { + assert_snapshot!(error.message(), @"The column not_found is not in the model m1"); + } + _ => panic!("should be error"), + } + + Ok(()) + } } diff --git a/wren-core/core/src/logical_plan/analyze/mod.rs b/wren-core/core/src/logical_plan/analyze/mod.rs index 483123777..e9f2374f2 100644 --- a/wren-core/core/src/logical_plan/analyze/mod.rs +++ b/wren-core/core/src/logical_plan/analyze/mod.rs @@ -1,4 +1,4 @@ -mod access_control; +pub mod access_control; pub mod expand_view; pub mod model_anlayze; pub mod model_generation;