Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class GcsFileConnectionInfo(BaseConnectionInfo):

class ValidateDTO(BaseModel):
    manifest_str: str = manifest_str_field
    # Rule-specific arguments. Declared as a plain ``dict`` (not
    # ``dict[str, str]``) because values are not always strings: the
    # ``rlac_condition_syntax_is_valid`` rule receives ``requiredProperties``
    # as a list of objects.
    parameters: dict
    connection_info: ConnectionInfo = connection_info_field


Expand Down
50 changes: 48 additions & 2 deletions ibis-server/app/model/validator.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
from __future__ import annotations

from wren_core import (
RowLevelAccessControl,
SessionProperty,
to_manifest,
validate_rlac_rule,
)

from app.mdl.rewriter import Rewriter
from app.model import NotFoundError, UnprocessableEntityError
from app.model.connector import Connector
from app.util import base64_to_dict

# Names of the validation rules that `Validator.validate` accepts; any other
# rule name is rejected with RuleNotFoundError.
rules = [
    "column_is_valid",
    "relationship_is_valid",
    "rlac_condition_syntax_is_valid",
]


class Validator:
def __init__(self, connector: Connector, rewriter: Rewriter):
self.connector = connector
self.rewriter = rewriter

async def validate(self, rule: str, parameters: dict[str, str], manifest_str: str):
async def validate(self, rule: str, parameters: dict, manifest_str: str):
if rule not in rules:
raise RuleNotFoundError(rule)
try:
Expand Down Expand Up @@ -144,6 +151,45 @@ def format_result(result):
except Exception as e:
raise ValidationError(f"Exception: {type(e)}, message: {e!s}")

async def _validate_rlac_condition_syntax_is_valid(
    self, parameters: dict, manifest_str: str
):
    """Validate a row-level access control condition against a model.

    Expected keys in ``parameters``:

    - ``modelName``: name of the model to look up in the manifest.
    - ``requiredProperties``: list of mappings, each with ``name`` and
      optionally ``required`` and ``defaultExpr``.
    - ``condition``: the condition expression to validate.

    Raises:
        MissingRequiredParameterError: one of the three keys is absent or
            ``None``.
        ValueError: the manifest has no model named ``modelName``.
        ValidationError: ``validate_rlac_rule`` rejected the rule.
    """

    def _as_bool(value) -> bool:
        # JSON clients may send the flag as a string. A bare ``bool("false")``
        # is True, so textual false values are recognised explicitly; every
        # other value keeps ordinary truthiness.
        if isinstance(value, str):
            return value.strip().lower() not in ("", "false", "0", "no", "off")
        return bool(value)

    for key in ("modelName", "requiredProperties", "condition"):
        if parameters.get(key) is None:
            raise MissingRequiredParameterError(key)

    model_name = parameters["modelName"]
    condition = parameters["condition"]

    required_properties = [
        SessionProperty(
            name=prop["name"],
            # ``required`` is optional here, mirroring the binding's own
            # default of ``False``.
            required=_as_bool(prop.get("required", False)),
            default_expr=prop.get("defaultExpr"),
        )
        for prop in parameters["requiredProperties"]
    ]

    rlac = RowLevelAccessControl(
        name="rlac_validation",
        required_properties=required_properties,
        condition=condition,
    )

    manifest = to_manifest(manifest_str)
    model = manifest.get_model(model_name)
    if model is None:
        raise ValueError(f"Model {model_name} not found in manifest")

    try:
        validate_rlac_rule(rlac, model)
    except Exception as e:
        # Keep the original exception as the cause so the traceback is not lost.
        raise ValidationError(e) from e

def _get_model(self, manifest, model_name):
models = list(filter(lambda m: m["name"] == model_name, manifest["models"]))
if len(models) == 0:
Expand Down
57 changes: 57 additions & 0 deletions ibis-server/tests/routers/v3/connector/postgres/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,60 @@ async def test_validate_rule_column_is_valid_without_one_parameter(
)
assert response.status_code == 422
assert response.text == "Missing required parameter: `modelName`"


async def test_validate_rlac_condition_syntax_is_valid(
    client, manifest_str, connection_info
):
    """Exercise the `rlac_condition_syntax_is_valid` rule over HTTP.

    Two requests with a declared session property must succeed (204), once
    with the `required` flag sent as a string and once as a boolean; a
    condition that references an undeclared session property must fail (422).
    """
    endpoint = f"{base_url}/validate/rlac_condition_syntax_is_valid"

    async def post_rule(required_flag, condition):
        # Only the `required` flag and the condition vary between requests.
        payload = {
            "connectionInfo": connection_info,
            "manifestStr": manifest_str,
            "parameters": {
                "modelName": "orders",
                "requiredProperties": [
                    {"name": "session_order", "required": required_flag},
                ],
                "condition": condition,
            },
        }
        return await client.post(url=endpoint, json=payload)

    # `required` given as a string.
    accepted_str_flag = await post_rule("false", "@session_order = o_orderkey")
    assert accepted_str_flag.status_code == 204

    # `required` given as a real boolean.
    accepted_bool_flag = await post_rule(False, "@session_order = o_orderkey")
    assert accepted_bool_flag.status_code == 204

    # The condition uses a session property that was never declared.
    rejected = await post_rule("false", "@session_not_found = o_orderkey")

    assert rejected.status_code == 422
    assert (
        rejected.text
        == "Error during planning: The session property @session_not_found is used, but not found in the session properties"
    )
38 changes: 37 additions & 1 deletion wren-core-base/src/mdl/py_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

#[cfg(feature = "python-binding")]
mod manifest_python_impl {
use crate::mdl::manifest::{Manifest, Model};
use crate::mdl::manifest::{Manifest, Model, RowLevelAccessControl, SessionProperty};
use crate::mdl::DataSource;
use pyo3::{pymethods, PyResult};
use std::sync::Arc;
Expand Down Expand Up @@ -49,6 +49,16 @@ mod manifest_python_impl {
fn data_source(&self) -> PyResult<Option<DataSource>> {
Ok(self.data_source)
}

// Return an owned copy of the model called `name`, or `None` when this
// manifest has no model with that name.
fn get_model(&self, name: &str) -> PyResult<Option<Model>> {
    for candidate in self.models.iter() {
        if candidate.name == name {
            // Clone the shared handle first, then turn it into an owned `Model`.
            let shared = Arc::clone(candidate);
            return Ok(Some(Arc::unwrap_or_clone(shared)));
        }
    }
    Ok(None)
}
}

#[pymethods]
Expand All @@ -58,4 +68,30 @@ mod manifest_python_impl {
Ok(self.name.clone())
}
}

#[pymethods]
impl SessionProperty {
#[new]
#[pyo3(signature = (name, required = false, default_expr = None))]
fn new(name: String, required: bool, default_expr: Option<String>) -> Self {
Self {
name,
required,
default_expr,
}
}
}

#[pymethods]
impl RowLevelAccessControl {
#[new]
#[pyo3(signature = (name, condition, required_properties = vec![]))]
fn new(name: String, condition: String, required_properties: Vec<SessionProperty>) -> Self {
Self {
name,
condition,
required_properties,
}
}
}
}
2 changes: 1 addition & 1 deletion wren-core-py/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl From<serde_json::Error> for CoreError {

impl From<wren_core::DataFusionError> for CoreError {
    /// Wrap a DataFusion error, passing its message through unchanged.
    ///
    /// No "DataFusion error: " prefix is added, so callers see exactly the
    /// text DataFusion produced (e.g. "Error during planning: ...").
    fn from(err: wren_core::DataFusionError) -> Self {
        CoreError::new(err.to_string().as_str())
    }
}

Expand Down
4 changes: 2 additions & 2 deletions wren-core-py/src/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ fn resolve_used_table_names(mdl: &WrenMDL, sql: &str) -> Result<Vec<String>, Cor
tables
.iter()
.filter(|t| {
t.catalog().map_or(true, |catalog| catalog == mdl.catalog())
&& t.schema().map_or(true, |schema| schema == mdl.schema())
t.catalog().is_none_or(|catalog| catalog == mdl.catalog())
&& t.schema().is_none_or(|schema| schema == mdl.schema())
})
.map(|t| t.table().to_string())
.collect()
Expand Down
6 changes: 6 additions & 0 deletions wren-core-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod errors;
mod extractor;
mod manifest;
pub mod remote_functions;
mod validation;

#[pymodule]
#[pyo3(name = "wren_core")]
Expand All @@ -15,7 +16,12 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<context::PySessionContext>()?;
m.add_class::<PyRemoteFunction>()?;
m.add_class::<manifest::Manifest>()?;
m.add_class::<manifest::Model>()?;
m.add_class::<manifest::RowLevelAccessControl>()?;
m.add_class::<manifest::SessionProperty>()?;
m.add_class::<extractor::PyManifestExtractor>()?;
m.add_function(wrap_pyfunction!(manifest::to_json_base64, m)?)?;
m.add_function(wrap_pyfunction!(manifest::to_manifest, m)?)?;
m.add_function(wrap_pyfunction!(validation::validate_rlac_rule, m)?)?;
Ok(())
}
1 change: 1 addition & 0 deletions wren-core-py/src/manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub fn to_json_base64(mdl: Manifest) -> Result<String, CoreError> {
Ok(mdl_base64)
}

#[pyfunction]
/// Convert a base64 encoded JSON string to a manifest object.
pub fn to_manifest(mdl_base64: &str) -> Result<Manifest, CoreError> {
let decoded_bytes = BASE64_STANDARD.decode(mdl_base64)?;
Expand Down
13 changes: 13 additions & 0 deletions wren-core-py/src/validation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use pyo3::pyfunction;
use wren_core_base::mdl::{Model, RowLevelAccessControl};

use crate::errors::CoreError;

#[pyfunction]
pub fn validate_rlac_rule(
rule: &RowLevelAccessControl,
model: &Model,
) -> Result<(), CoreError> {
wren_core::logical_plan::analyze::access_control::validate_rlac_rule(rule, model)?;
Ok(())
}
36 changes: 36 additions & 0 deletions wren-core-py/tests/test_modeling_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@
import pytest
from wren_core import (
ManifestExtractor,
RowLevelAccessControl,
SessionContext,
SessionProperty,
to_json_base64,
to_manifest,
validate_rlac_rule,
)

manifest = {
Expand Down Expand Up @@ -298,3 +302,35 @@ def test_rlac():
rewritten_sql
== "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM main.customer AS __source) AS customer) AS customer WHERE customer.c_name = 'test_user'"
)


def test_validate_rlac_rule():
    """`validate_rlac_rule` accepts a rule whose session property is declared
    and rejects one that uses an undeclared session property."""
    parsed_manifest = to_manifest(manifest_str)
    model = parsed_manifest.get_model("customer")
    if model is None:
        raise ValueError("Model customer not found in manifest")

    # The condition references a declared session property: must not raise.
    rlac = RowLevelAccessControl(
        name="test",
        required_properties=[
            SessionProperty(
                name="session_user",
                required=False,
            )
        ],
        condition="c_name = @session_user",
    )
    validate_rlac_rule(rlac, model)

    # Same condition, but no session property declared: must raise.
    rlac = RowLevelAccessControl(
        name="test",
        required_properties=[],
        condition="c_name = @session_user",
    )

    # The message is checked through `match`, so it is actually evaluated; an
    # assert placed after the raising call inside the `with` block never runs.
    # Matching the core text keeps the check independent of any error prefix.
    with pytest.raises(
        Exception,
        match="The session property @session_user is used, but not found in the session properties",
    ):
        validate_rlac_rule(rlac, model)
Loading