Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions ibis-server/tests/routers/v3/connector/postgres/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@
"name": "c_name_access",
"requiredProperties": [
{
"name": "session_level",
# To test the name is case insensitive
"name": "Session_level",
"required": False,
}
],
Expand Down Expand Up @@ -672,14 +673,30 @@ async def test_clac_query(client, manifest_str, connection_info):
"sql": "SELECT * FROM customer limit 1",
},
headers={
X_WREN_VARIABLE_PREFIX + "session_level": "2",
X_WREN_VARIABLE_PREFIX + "session_level": "1",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1
assert len(result["data"][0]) == 2

response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": base64_manifest_with_required_properties,
"sql": "SELECT * FROM customer limit 1",
},
headers={
X_WREN_VARIABLE_PREFIX + "session_level": "2",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1
assert len(result["data"][0]) == 1


async def test_connection_timeout(
client, manifest_str, connection_info, connection_url
Expand Down
41 changes: 40 additions & 1 deletion wren-core-base/manifest-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -404,12 +404,51 @@ pub fn session_property(python_binding: proc_macro::TokenStream) -> proc_macro::
};
let expanded = quote! {
#python_binding
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)]
#[derive(Serialize, Debug, PartialEq, Eq, Hash, Clone)]
#[serde(rename_all = "camelCase")]
pub struct SessionProperty {
pub name: String,
pub required: bool,
pub default_expr: Option<String>,
// To avoid duplicate clone for normalized name(to_lowercase), we store it here
#[serde(skip_serializing, default = "String::new")]
pub normalized_name: String,
}

impl SessionProperty {
#[cfg(not(feature = "python-binding"))]
pub fn new(name: String, required: bool, default_expr: Option<String>) -> Self {
let normalized_name = name.to_lowercase();
Self {
name,
required,
default_expr,
normalized_name,
}
}
}

impl<'de> serde::Deserialize<'de> for SessionProperty {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct SessionPropertyHelper {
name: String,
required: bool,
default_expr: Option<String>,
}

let helper = SessionPropertyHelper::deserialize(deserializer)?;
Ok(SessionProperty {
normalized_name: helper.name.to_lowercase(),
name: helper.name,
required: helper.required,
default_expr: helper.default_expr,
})
}
}
};
proc_macro::TokenStream::from(expanded)
Expand Down
23 changes: 13 additions & 10 deletions wren-core-base/src/mdl/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,18 +175,10 @@ impl ModelBuilder {

impl SessionProperty {
pub fn new_required(name: &str) -> Self {
SessionProperty {
name: name.to_string(),
required: true,
default_expr: None,
}
SessionProperty::new(name.to_string(), true, None)
}
pub fn new_optional(name: &str, default_expr: Option<String>) -> Self {
SessionProperty {
name: name.to_string(),
required: false,
default_expr,
}
SessionProperty::new(name.to_string(), false, default_expr)
}
}
pub struct ColumnBuilder {
Expand Down Expand Up @@ -845,4 +837,15 @@ mod test {
.data_source(MySQL);
assert_eq!(mdl, expected.build());
}

#[test]
fn test_session_property_roundtrip() {
let expected = SessionProperty::new_optional("session_id", Some("1".to_string()));

let json_str = serde_json::to_string(&expected).unwrap();
assert!(!json_str.contains(r#"normalizedName"#));
let actual: SessionProperty = serde_json::from_str(&json_str).unwrap();
assert_eq!(actual.normalized_name(), actual.name.to_lowercase());
assert_eq!(actual, expected)
}
}
6 changes: 6 additions & 0 deletions wren-core-base/src/mdl/manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,12 @@ impl View {
}
}

impl SessionProperty {
pub fn normalized_name(&self) -> &str {
&self.normalized_name
}
}

#[cfg(test)]
mod tests {
use crate::mdl::manifest::table_reference;
Expand Down
3 changes: 2 additions & 1 deletion wren-core-base/src/mdl/py_method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ mod manifest_python_impl {
impl SessionProperty {
#[new]
#[pyo3(signature = (name, required = false, default_expr = None))]
fn new(name: String, required: bool, default_expr: Option<String>) -> Self {
pub fn new(name: String, required: bool, default_expr: Option<String>) -> Self {
Self {
normalized_name: name.to_lowercase(),
name,
required,
default_expr,
Expand Down
18 changes: 8 additions & 10 deletions wren-core/core/src/logical_plan/analyze/access_control.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ pub fn validate_rlac_rule(rule: &RowLevelAccessControl, model: &Model) -> Result

let required_properties: Vec<_> = required_properties
.iter()
.map(|property| property.name.to_lowercase())
.map(|property| property.normalized_name())
.collect();

let missed_properties: Vec<_> = session_properties
.iter()
.filter(|property| !required_properties.contains(property))
.filter(|property| !required_properties.contains(&property.as_str()))
.collect();
if !missed_properties.is_empty() {
return plan_err!(
Expand Down Expand Up @@ -144,9 +144,7 @@ pub fn build_filter_expression(
let Some(property_value) = properties.get(&property_name).or_else(|| {
required_properties
.iter()
.filter(|r| {
!r.required && r.name.eq_ignore_ascii_case(&property_name)
})
.filter(|r| !r.required && r.normalized_name().eq(&property_name))
.map(|r| &r.default_expr)
.next()
}) else {
Expand Down Expand Up @@ -257,7 +255,7 @@ pub fn validate_rule(
.iter()
.map(|property| {
if property.required {
if !is_property_present(headers, &property.name) {
if !is_property_present(headers, property) {
return plan_err!(
"session property {} is required for `{}` rule but not found in headers",
property.name,
Expand All @@ -266,7 +264,7 @@ pub fn validate_rule(
}
Ok(true)
} else {
let exist = is_property_present(headers, &property.name);
let exist = is_property_present(headers, property);
if exist
|| property
.default_expr
Expand Down Expand Up @@ -303,7 +301,7 @@ pub(crate) fn validate_clac_rule(
}

let property = &clac.required_properties[0];
let value_opt = properties.get(&property.name);
let value_opt = properties.get(property.normalized_name());

match value_opt {
Some(Some(value)) => (clac.eval(value), Some(clac.name.clone())),
Expand Down Expand Up @@ -363,10 +361,10 @@ pub(crate) fn validate_clac_rule(
/// If the property is present and not empty, return true.
fn is_property_present(
headers: &HashMap<String, Option<String>>,
property_name: &str,
property: &SessionProperty,
) -> bool {
headers
.get(&property_name.to_lowercase())
.get(property.normalized_name())
.map(|v| v.as_ref().is_some_and(|value| !value.is_empty()))
.unwrap_or(false)
}
Expand Down
6 changes: 3 additions & 3 deletions wren-core/core/src/mdl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2659,7 +2659,7 @@ mod test {
ColumnBuilder::new("c_name", "string")
.column_level_access_control(
"cls rule",
vec![SessionProperty::new_required("session_level")],
vec![SessionProperty::new_required("Session_level")],
ColumnLevelOperator::Equals,
"1",
)
Expand Down Expand Up @@ -2703,7 +2703,7 @@ mod test {
Err(e) => {
assert_snapshot!(
e.to_string(),
@"Error during planning: session property session_level is required for `cls rule` rule but not found in headers"
@"Error during planning: session property Session_level is required for `cls rule` rule but not found in headers"
)
}
_ => panic!("Expected error"),
Expand Down Expand Up @@ -3662,7 +3662,7 @@ mod test {
) -> HashMap<String, Option<String>> {
let mut headers = HashMap::new();
for (key, value) in field {
headers.insert(key.clone(), value.clone());
headers.insert(key.to_lowercase(), value.clone());
}
headers
}
Expand Down