Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ibis-server/app/mdl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

@cache
def get_session_context(
manifest_str: str | None, function_path: str
manifest_str: str | None, function_path: str, properties: frozenset | None = None
) -> wren_core.SessionContext:
return wren_core.SessionContext(manifest_str, function_path)
return wren_core.SessionContext(manifest_str, function_path, properties)


def get_manifest_extractor(manifest_str: str) -> wren_core.ManifestExtractor:
Expand Down
12 changes: 8 additions & 4 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,27 +122,31 @@ async def rewrite(
self, manifest_str: str, sql: str, properties: dict | None = None
) -> str:
try:
session_context = get_session_context(manifest_str, self.function_path)
processed_properties = self.get_session_properties(properties)
session_context = get_session_context(
manifest_str, self.function_path, processed_properties
)
return await to_thread.run_sync(
session_context.transform_sql,
sql,
self.get_session_properties(properties),
)
except Exception as e:
raise RewriteError(str(e))

def get_session_properties(self, properties: dict) -> dict | None:
def get_session_properties(self, properties: dict) -> frozenset | None:
if properties is None:
return None
# filter the properties which name starts with "x-wren-variable-"
# and remove the prefix "x-wren-variable-"

return {
processed_properties = {
k.replace(X_WREN_VARIABLE_PREFIX, ""): v
for k, v in properties.items()
if k.startswith(X_WREN_VARIABLE_PREFIX)
}

return frozenset(processed_properties.items())

@staticmethod
def handle_extract_exception(e: Exception):
raise RewriteError(str(e))
Expand Down
63 changes: 62 additions & 1 deletion ibis-server/tests/routers/v3/connector/postgres/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,21 @@
},
"columns": [
{"name": "c_custkey", "type": "integer"},
{"name": "c_name", "type": "varchar"},
{
"name": "c_name",
"type": "varchar",
"columnLevelAccessControl": {
"name": "c_name_access",
"requiredProperties": [
{
"name": "session_level",
"required": False,
}
],
"operator": "EQUALS",
"threshold": "1",
},
},
{"name": "orders", "type": "orders", "relationship": "orders_customer"},
{
"name": "sum_totalprice",
Expand Down Expand Up @@ -525,3 +539,50 @@ async def test_rlac_query(client, manifest_str, connection_info):
result = response.json()
assert len(result["data"]) == 1
assert result["data"][0][0] == "Customer#000000001"


async def test_clac_query(client, manifest_str, connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT * FROM customer limit 1",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1
assert len(result["data"][0]) == 3

response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT * FROM customer limit 1",
},
headers={
X_WREN_VARIABLE_PREFIX + "session_level": "1",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1
assert len(result["data"][0]) == 3

response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT * FROM customer limit 1",
},
headers={
X_WREN_VARIABLE_PREFIX + "session_level": "2",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1
assert len(result["data"][0]) == 2
2 changes: 1 addition & 1 deletion wren-core-base/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ python-binding = ["dep:pyo3"]
default = []

[dependencies]
pyo3 = { version = "0.24.1", features = ["extension-module"], optional = true }
pyo3 = { version = "0.25.0", features = ["extension-module"], optional = true }
serde = { version = "1.0.201", features = ["derive", "rc"] }
wren-manifest-macro = { path = "manifest-macro" }
serde_json = { version = "1.0.117" }
Expand Down
29 changes: 29 additions & 0 deletions wren-core-base/manifest-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ pub fn column(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStrea
pub is_hidden: bool,
#[deprecated]
pub rls: Option<RowLevelSecurity>,
#[deprecated]
pub cls: Option<ColumnLevelSecurity>,
pub column_level_access_control: Option<Arc<ColumnLevelAccessControl>>,
}
};
proc_macro::TokenStream::from(expanded)
Expand Down Expand Up @@ -466,6 +468,7 @@ pub fn row_level_operator(python_binding: proc_macro::TokenStream) -> proc_macro
}

#[proc_macro]
#[deprecated]
pub fn column_level_security(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(python_binding as LitBool);
let python_binding = if input.value {
Expand All @@ -487,6 +490,32 @@ pub fn column_level_security(python_binding: proc_macro::TokenStream) -> proc_ma
proc_macro::TokenStream::from(expanded)
}

#[proc_macro]
pub fn column_level_access_control(
python_binding: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let input = parse_macro_input!(python_binding as LitBool);
let python_binding = if input.value {
quote! {
#[pyclass]
}
} else {
quote! {}
};
let expanded = quote! {
#python_binding
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "camelCase")]
pub struct ColumnLevelAccessControl {
pub name: String,
pub required_properties: Vec<SessionProperty>,
pub operator: ColumnLevelOperator,
pub threshold: NormalizedExpr,
}
};
proc_macro::TokenStream::from(expanded)
}

#[proc_macro]
pub fn column_level_operator(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(python_binding as LitBool);
Expand Down
26 changes: 26 additions & 0 deletions wren-core-base/src/mdl/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ use crate::mdl::{
};
use std::sync::Arc;

use super::ColumnLevelAccessControl;

/// A builder for creating a Manifest
pub struct ManifestBuilder {
pub manifest: Manifest,
Expand Down Expand Up @@ -205,6 +207,7 @@ impl ColumnBuilder {
expression: None,
rls: None,
cls: None,
column_level_access_control: None,
},
}
}
Expand Down Expand Up @@ -264,6 +267,23 @@ impl ColumnBuilder {
});
self
}

pub fn column_level_access_control(
mut self,
name: &str,
required_properties: Vec<SessionProperty>,
operator: ColumnLevelOperator,
threshold: &str,
) -> Self {
self.column.column_level_access_control = Some(Arc::new(ColumnLevelAccessControl {
name: name.to_string(),
required_properties,
operator,
threshold: NormalizedExpr::new(threshold),
}));
self
}

pub fn build(self) -> Arc<Column> {
Arc::new(self.column)
}
Expand Down Expand Up @@ -437,6 +457,12 @@ mod test {
.expression("test")
.row_level_security("SESSION_STATUS", RowLevelOperator::Equals)
.column_level_security("SESSION_LEVEL", ColumnLevelOperator::Equals, "'NORMAL'")
.column_level_access_control(
"rlac",
vec![SessionProperty::new_required("session_id")],
ColumnLevelOperator::Equals,
"'NORMAL'",
)
.build();

let json_str = serde_json::to_string(&expected).unwrap();
Expand Down
20 changes: 19 additions & 1 deletion wren-core-base/src/mdl/cls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
use crate::mdl::manifest::{ColumnLevelSecurity, NormalizedExpr, NormalizedExprType};
use crate::mdl::manifest::{
ColumnLevelAccessControl, ColumnLevelSecurity, NormalizedExpr, NormalizedExprType,
};
use crate::mdl::ColumnLevelOperator;
use std::fmt::{Display, Formatter};
use std::str::FromStr;
Expand All @@ -37,6 +39,22 @@ impl ColumnLevelSecurity {
}
}

impl ColumnLevelAccessControl {
/// Evaluate the input against the column level access control.
/// If the type of the input is different from the type of the value, the result is always false except for NOT_EQUALS.
pub fn eval(&self, input: &str) -> bool {
let input_expr = NormalizedExpr::new(input);
match self.operator {
ColumnLevelOperator::Equals => input_expr.eq(&self.threshold),
ColumnLevelOperator::NotEquals => input_expr.neq(&self.threshold),
ColumnLevelOperator::GreaterThan => input_expr.gt(&self.threshold),
ColumnLevelOperator::LessThan => input_expr.lt(&self.threshold),
ColumnLevelOperator::GreaterThanOrEquals => input_expr.gte(&self.threshold),
ColumnLevelOperator::LessThanOrEquals => input_expr.lte(&self.threshold),
}
}
}

impl NormalizedExpr {
pub fn new(expr: &str) -> Self {
assert!(!expr.is_empty(), "expr is null or empty");
Expand Down
26 changes: 18 additions & 8 deletions wren-core-base/src/mdl/manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ mod manifest_impl {
use crate::mdl::manifest::bool_from_int;
use crate::mdl::manifest::table_reference;
use manifest_macro::{
column, column_level_operator, column_level_security, data_source, join_type, manifest,
metric, model, normalized_expr, normalized_expr_type, relationship,
row_level_access_control, row_level_operator, row_level_security, session_property,
time_grain, time_unit, view,
column, column_level_access_control, column_level_operator, column_level_security,
data_source, join_type, manifest, metric, model, normalized_expr, normalized_expr_type,
relationship, row_level_access_control, row_level_operator, row_level_security,
session_property, time_grain, time_unit, view,
};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
Expand All @@ -47,6 +47,7 @@ mod manifest_impl {
time_grain!(false);
time_unit!(false);
row_level_access_control!(false);
column_level_access_control!(false);
session_property!(false);
row_level_security!(false);
row_level_operator!(false);
Expand All @@ -62,10 +63,10 @@ mod manifest_impl {
use crate::mdl::manifest::bool_from_int;
use crate::mdl::manifest::table_reference;
use manifest_macro::{
column, column_level_operator, column_level_security, data_source, join_type, manifest,
metric, model, normalized_expr, normalized_expr_type, relationship,
row_level_access_control, row_level_operator, row_level_security, session_property,
time_grain, time_unit, view,
column, column_level_access_control, column_level_operator, column_level_security,
data_source, join_type, manifest, metric, model, normalized_expr, normalized_expr_type,
relationship, row_level_access_control, row_level_operator, row_level_security,
session_property, time_grain, time_unit, view,
};
use pyo3::pyclass;
use serde::{Deserialize, Serialize};
Expand All @@ -86,6 +87,7 @@ mod manifest_impl {
time_unit!(true);
manifest!(true);
row_level_access_control!(true);
column_level_access_control!(true);
session_property!(true);
row_level_security!(true);
row_level_operator!(true);
Expand Down Expand Up @@ -294,6 +296,14 @@ impl Column {
pub fn expression(&self) -> Option<&str> {
self.expression.as_deref()
}

pub fn column_level_access_control(&self) -> Option<Arc<ColumnLevelAccessControl>> {
if let Some(ref cla) = &self.column_level_access_control {
Some(Arc::clone(cla))
} else {
None
}
}
}

impl Metric {
Expand Down
Loading