diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_query.py b/ibis-server/tests/routers/v3/connector/postgres/test_query.py index 549da3484..5a9436469 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_query.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_query.py @@ -78,7 +78,8 @@ "name": "c_name_access", "requiredProperties": [ { - "name": "session_level", + # To test the name is case insensitive + "name": "Session_level", "required": False, } ], @@ -672,7 +673,7 @@ async def test_clac_query(client, manifest_str, connection_info): "sql": "SELECT * FROM customer limit 1", }, headers={ - X_WREN_VARIABLE_PREFIX + "session_level": "2", + X_WREN_VARIABLE_PREFIX + "session_level": "1", }, ) assert response.status_code == 200 @@ -680,6 +681,22 @@ async def test_clac_query(client, manifest_str, connection_info): assert len(result["data"]) == 1 assert len(result["data"][0]) == 2 + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": base64_manifest_with_required_properties, + "sql": "SELECT * FROM customer limit 1", + }, + headers={ + X_WREN_VARIABLE_PREFIX + "session_level": "2", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + assert len(result["data"][0]) == 1 + async def test_connection_timeout( client, manifest_str, connection_info, connection_url diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index ac966a110..cf0700e5c 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -404,12 +404,51 @@ pub fn session_property(python_binding: proc_macro::TokenStream) -> proc_macro:: }; let expanded = quote! { #python_binding - #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone)] + #[derive(Serialize, Debug, PartialEq, Eq, Hash, Clone)] #[serde(rename_all = "camelCase")] pub struct SessionProperty { pub name: String, pub required: bool, pub default_expr: Option, + // To avoid duplicate clone for normalized name(to_lowercase), we store it here + #[serde(skip_serializing, default = "String::new")] + pub normalized_name: String, + } + + impl SessionProperty { + #[cfg(not(feature = "python-binding"))] + pub fn new(name: String, required: bool, default_expr: Option) -> Self { + let normalized_name = name.to_lowercase(); + Self { + name, + required, + default_expr, + normalized_name, + } + } + } + + impl<'de> serde::Deserialize<'de> for SessionProperty { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + #[derive(Deserialize)] + #[serde(rename_all = "camelCase")] + struct SessionPropertyHelper { + name: String, + required: bool, + default_expr: Option, + } + + let helper = SessionPropertyHelper::deserialize(deserializer)?; + Ok(SessionProperty { + normalized_name: helper.name.to_lowercase(), + name: helper.name, + required: helper.required, + default_expr: helper.default_expr, + }) + } } }; proc_macro::TokenStream::from(expanded) diff --git a/wren-core-base/src/mdl/builder.rs b/wren-core-base/src/mdl/builder.rs index 9f8eaa29c..28b7f54cb 100644 --- a/wren-core-base/src/mdl/builder.rs +++ b/wren-core-base/src/mdl/builder.rs @@ -175,18 +175,10 @@ impl ModelBuilder { impl SessionProperty { pub fn new_required(name: &str) -> Self { - SessionProperty { - name: name.to_string(), - required: true, - default_expr: None, - } + SessionProperty::new(name.to_string(), true, None) } pub fn new_optional(name: &str, default_expr: Option) -> Self { - SessionProperty { - name: name.to_string(), - required: false, - default_expr, - } + SessionProperty::new(name.to_string(), false, default_expr) } } pub struct ColumnBuilder { @@ -845,4 +837,15 @@ mod test { .data_source(MySQL); assert_eq!(mdl, expected.build()); } + + #[test] + fn test_session_property_roundtrip() { + let expected = SessionProperty::new_optional("session_id", Some("1".to_string())); + + let json_str = serde_json::to_string(&expected).unwrap(); + assert!(!json_str.contains(r#"normalizedName"#)); + let actual: SessionProperty = serde_json::from_str(&json_str).unwrap(); + assert_eq!(actual.normalized_name(), actual.name.to_lowercase()); + assert_eq!(actual, expected) + } } diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index 09effc947..a8991476b 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -332,6 +332,12 @@ impl View { } } +impl SessionProperty { + pub fn normalized_name(&self) -> &str { + &self.normalized_name + } +} + #[cfg(test)] mod tests { use crate::mdl::manifest::table_reference; diff --git a/wren-core-base/src/mdl/py_method.rs b/wren-core-base/src/mdl/py_method.rs index a5e716fa3..38d91c367 100644 --- a/wren-core-base/src/mdl/py_method.rs +++ b/wren-core-base/src/mdl/py_method.rs @@ -73,8 +73,9 @@ mod manifest_python_impl { impl SessionProperty { #[new] #[pyo3(signature = (name, required = false, default_expr = None))] - fn new(name: String, required: bool, default_expr: Option) -> Self { + pub fn new(name: String, required: bool, default_expr: Option) -> Self { Self { + normalized_name: name.to_lowercase(), name, required, default_expr, diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs index 8c4865425..0fc16b9ce 100644 --- a/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -96,12 +96,12 @@ pub fn validate_rlac_rule(rule: &RowLevelAccessControl, model: &Model) -> Result let required_properties: Vec<_> = required_properties .iter() - .map(|property| property.name.to_lowercase()) + .map(|property| property.normalized_name()) .collect(); let missed_properties: Vec<_> = session_properties .iter() - .filter(|property| !required_properties.contains(property)) + .filter(|property| !required_properties.contains(&property.as_str())) .collect(); if !missed_properties.is_empty() { return plan_err!( @@ -144,9 +144,7 @@ pub fn build_filter_expression( let Some(property_value) = properties.get(&property_name).or_else(|| { required_properties .iter() - .filter(|r| { - !r.required && r.name.eq_ignore_ascii_case(&property_name) - }) + .filter(|r| !r.required && r.normalized_name().eq(&property_name)) .map(|r| &r.default_expr) .next() }) else { @@ -257,7 +255,7 @@ pub fn validate_rule( .iter() .map(|property| { if property.required { - if !is_property_present(headers, &property.name) { + if !is_property_present(headers, property) { return plan_err!( "session property {} is required for `{}` rule but not found in headers", property.name, @@ -266,7 +264,7 @@ pub fn validate_rule( } Ok(true) } else { - let exist = is_property_present(headers, &property.name); + let exist = is_property_present(headers, property); if exist || property .default_expr @@ -303,7 +301,7 @@ pub(crate) fn validate_clac_rule( } let property = &clac.required_properties[0]; - let value_opt = properties.get(&property.name); + let value_opt = properties.get(property.normalized_name()); match value_opt { Some(Some(value)) => (clac.eval(value), Some(clac.name.clone())), @@ -363,10 +361,10 @@ pub(crate) fn validate_clac_rule( /// If the property is present and not empty, return true. fn is_property_present( headers: &HashMap>, - property_name: &str, + property: &SessionProperty, ) -> bool { headers - .get(&property_name.to_lowercase()) + .get(property.normalized_name()) .map(|v| v.as_ref().is_some_and(|value| !value.is_empty())) .unwrap_or(false) } diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index dd1043c90..d5fcac426 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -2659,7 +2659,7 @@ mod test { ColumnBuilder::new("c_name", "string") .column_level_access_control( "cls rule", - vec![SessionProperty::new_required("session_level")], + vec![SessionProperty::new_required("Session_level")], ColumnLevelOperator::Equals, "1", ) @@ -2703,7 +2703,7 @@ mod test { Err(e) => { assert_snapshot!( e.to_string(), - @"Error during planning: session property session_level is required for `cls rule` rule but not found in headers" + @"Error during planning: session property Session_level is required for `cls rule` rule but not found in headers" ) } _ => panic!("Expected error"), @@ -3662,7 +3662,7 @@ mod test { ) -> HashMap> { let mut headers = HashMap::new(); for (key, value) in field { - headers.insert(key.clone(), value.clone()); + headers.insert(key.to_lowercase(), value.clone()); } headers }