diff --git a/ibis-server/app/dependencies.py b/ibis-server/app/dependencies.py index c2553226c..dae78dbc4 100644 --- a/ibis-server/app/dependencies.py +++ b/ibis-server/app/dependencies.py @@ -1,3 +1,4 @@ +import wren_core from fastapi import Request from starlette.datastructures import Headers @@ -48,9 +49,8 @@ def _filter_headers(header_string: str) -> bool: return False -def exist_wren_variables_header( - headers: Headers, -) -> bool: - if headers is None: +def is_backward_compatible(manifest_str: str) -> bool: + try: + return wren_core.is_backward_compatible(manifest_str) + except Exception: return False - return any(key.startswith(X_WREN_VARIABLE_PREFIX) for key in headers.keys()) diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 01f1acb25..c03b64597 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -14,8 +14,8 @@ X_CACHE_OVERRIDE, X_CACHE_OVERRIDE_AT, X_WREN_FALLBACK_DISABLE, - exist_wren_variables_header, get_wren_headers, + is_backward_compatible, verify_query_dto, ) from app.mdl.core import get_session_context @@ -191,7 +191,7 @@ async def query( if ( java_engine_connector.client is None or is_fallback_disable - or exist_wren_variables_header(headers) + or not is_backward_compatible(dto.manifest_str) ): raise e @@ -237,7 +237,7 @@ async def dry_plan( if ( java_engine_connector.client is None or is_fallback_disable - or exist_wren_variables_header(headers) + or not is_backward_compatible(dto.manifest_str) ): raise e @@ -285,7 +285,7 @@ async def dry_plan_for_data_source( if ( java_engine_connector.client is None or is_fallback_disable - or exist_wren_variables_header(headers) + or not is_backward_compatible(dto.manifest_str) ): raise e @@ -351,7 +351,7 @@ async def validate( if ( java_engine_connector.client is None or is_fallback_disable - or exist_wren_variables_header(headers) + or not is_backward_compatible(dto.manifest_str) ): raise e @@ -446,7 +446,7 @@ async def model_substitute( if ( java_engine_connector.client is None or is_fallback_disable - or exist_wren_variables_header(headers) + or not is_backward_compatible(dto.manifest_str) ): raise e diff --git a/ibis-server/resources/function_list/postgres.csv b/ibis-server/resources/function_list/postgres.csv index 86d433a40..e3df82327 100644 --- a/ibis-server/resources/function_list/postgres.csv +++ b/ibis-server/resources/function_list/postgres.csv @@ -33,4 +33,4 @@ scalar,sign,numeric,,numeric,"Sign of number" scalar,to_json,json,,boolean,"Convert to JSON" scalar,to_number,numeric,,"text,text","Convert string to number" scalar,unistr,varchar,,text,"Postgres: Evaluate escaped Unicode characters in the argument" -scalar,pg_sleep,,,"numeric","Sleep for specified time in seconds" +scalar,pg_sleep,,,"bigint","Sleep for specified time in seconds" diff --git a/ibis-server/tests/routers/v3/connector/postgres/conftest.py b/ibis-server/tests/routers/v3/connector/postgres/conftest.py index 36b003a0d..20388e7cf 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/conftest.py +++ b/ibis-server/tests/routers/v3/connector/postgres/conftest.py @@ -5,6 +5,7 @@ import sqlalchemy from testcontainers.postgres import PostgresContainer +from app.config import get_config from tests.conftest import file_path pytestmark = pytest.mark.postgres @@ -19,6 +20,9 @@ def pytest_collection_modifyitems(items): item.add_marker(pytestmark) +function_list_path = file_path("../resources/function_list") + + @pytest.fixture(scope="module") def postgres(request) -> PostgresContainer: pg = PostgresContainer("postgres:16-alpine").start() @@ -56,3 +60,11 @@ def connection_info(postgres: PostgresContainer) -> dict[str, str]: def connection_url(connection_info: dict[str, str]): info = connection_info return f"postgres://{info['user']}:{info['password']}@{info['host']}:{info['port']}/{info['database']}" + + +@pytest.fixture(autouse=True) +def set_remote_function_list_path(): + config = get_config() + config.set_remote_function_list_path(function_list_path) + yield + config.set_remote_function_list_path(None) diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py b/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py index 0c0d8c958..93ad3360a 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py @@ -380,4 +380,92 @@ async def test_query_rlac(client, manifest_str, connection_info): }, headers={X_WREN_VARIABLE_PREFIX + "session_user": "1"}, ) + assert response.status_code == 200 + + manifest_rlac = { + "catalog": "wren", + "schema": "public", + "models": [ + { + "name": "orders", + "tableReference": {"schema": "public", "table": "orders"}, + "columns": [ + { + "name": "orderkey", + "type": "varchar", + "expression": "cast(o_orderkey as varchar)", + } + ], + "rowLevelAccessControls": [ + { + "name": "rule", + "requiredProperties": [ + { + "name": "session_user", + "required": False, + } + ], + "condition": "orderkey = @session_user", + }, + ], + } + ], + } + + manifest_rlac_str = base64.b64encode(orjson.dumps(manifest_rlac)).decode("utf-8") + + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_rlac_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + ) + assert response.status_code == 422 + + manifest_clac = { + "catalog": "wren", + "schema": "public", + "models": [ + { + "name": "orders", + "tableReference": {"schema": "public", "table": "orders"}, + "columns": [ + { + "name": "orderkey", + "type": "varchar", + "expression": "cast(o_orderkey as varchar)", + }, + { + "name": "custkey", + "type": "varchar", + "columnLevelAccessControl": { + "name": "o_custkey_access", + "requiredProperties": [ + { + "name": "session_level", + "required": False, + "defaultExpr": "2", + } + ], + "operator": "GREATER_THAN", + "threshold": "3", + }, + "expression": "cast(o_custkey as varchar)", + }, + ], + } + ], + } + + manifest_clac_str = base64.b64encode(orjson.dumps(manifest_clac)).decode("utf-8") + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_clac_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + ) assert response.status_code == 422 diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_functions.py b/ibis-server/tests/routers/v3/connector/postgres/test_functions.py index 27d31c5fd..b9b68eb84 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_functions.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_functions.py @@ -4,8 +4,8 @@ import pytest from app.config import get_config -from tests.conftest import DATAFUSION_FUNCTION_COUNT, file_path -from tests.routers.v3.connector.postgres.conftest import base_url +from tests.conftest import DATAFUSION_FUNCTION_COUNT +from tests.routers.v3.connector.postgres.conftest import base_url, function_list_path pytestmark = pytest.mark.functions @@ -26,22 +26,12 @@ ], } -function_list_path = file_path("../resources/function_list") - @pytest.fixture(scope="module") def manifest_str(): return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") -@pytest.fixture(autouse=True) -def set_remote_function_list_path(): - config = get_config() - config.set_remote_function_list_path(function_list_path) - yield - config.set_remote_function_list_path(None) - - async def test_function_list(client): config = get_config() diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_query.py b/ibis-server/tests/routers/v3/connector/postgres/test_query.py index 95ce6e139..c1b31abd5 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_query.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_query.py @@ -694,11 +694,11 @@ async def test_connection_timeout( }, headers={X_WREN_DB_STATEMENT_TIMEOUT: "1"}, # Set timeout to 1 second ) - assert response.status_code == 504 assert ( "Query was cancelled: canceling statement due to statement timeout" in response.text ) + assert response.status_code == 504 # test connection_url way can also timeout response = await client.post( @@ -710,11 +710,11 @@ async def test_connection_timeout( }, headers={X_WREN_DB_STATEMENT_TIMEOUT: "1"}, # Set timeout to 1 second ) - assert response.status_code == 504 assert ( "Query was cancelled: canceling statement due to statement timeout" in response.text ) + assert response.status_code == 504 async def test_format_floating(client, manifest_str, connection_info): diff --git a/wren-core-py/src/context.rs b/wren-core-py/src/context.rs index 01bd852f5..5e00f1ef7 100644 --- a/wren-core-py/src/context.rs +++ b/wren-core-py/src/context.rs @@ -253,19 +253,23 @@ impl PySessionContext { let _ = visit_statements_mut(&mut statements, |stmt| { if let Statement::Query(q) = stmt { if let Some(limit) = &q.limit { - if let Expr::Value(ValueWithSpan { value: Value::Number(n, is), .. }) = limit { + if let Expr::Value(ValueWithSpan { + value: Value::Number(n, is), + .. + }) = limit + { if let Ok(curr) = n.parse::() { if curr > pushdown { - q.limit = Some(Expr::Value(Value::Number( - pushdown.to_string(), - *is, - ).into())); + q.limit = Some(Expr::Value( + Value::Number(pushdown.to_string(), *is).into(), + )); } } } } else { - q.limit = - Some(Expr::Value(Value::Number(pushdown.to_string(), false).into())); + q.limit = Some(Expr::Value( + Value::Number(pushdown.to_string(), false).into(), + )); } } ControlFlow::<()>::Continue(()) diff --git a/wren-core-py/src/lib.rs b/wren-core-py/src/lib.rs index 65687692e..c9425d0be 100644 --- a/wren-core-py/src/lib.rs +++ b/wren-core-py/src/lib.rs @@ -23,5 +23,6 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(manifest::to_json_base64, m)?)?; m.add_function(wrap_pyfunction!(manifest::to_manifest, m)?)?; m.add_function(wrap_pyfunction!(validation::validate_rlac_rule, m)?)?; + m.add_function(wrap_pyfunction!(manifest::is_backward_compatible, m)?)?; Ok(()) } diff --git a/wren-core-py/src/manifest.rs b/wren-core-py/src/manifest.rs index 771f3d3ab..286862ddb 100644 --- a/wren-core-py/src/manifest.rs +++ b/wren-core-py/src/manifest.rs @@ -22,6 +22,24 @@ pub fn to_manifest(mdl_base64: &str) -> Result { Ok(manifest) } +/// Check if the MDL can be used by the v2 wren core. If there are any access controls rules, +/// the MDL should be used by the v3 wren core only. +#[pyfunction] +pub fn is_backward_compatible(mdl_base64: &str) -> Result { + let manifest = to_manifest(mdl_base64)?; + let ralc_exist = manifest + .models + .iter() + .all(|model| model.row_level_access_controls().is_empty()); + let clac_exist = manifest.models.iter().all(|model| { + model + .columns + .iter() + .all(|column| column.column_level_access_control().is_none()) + }); + Ok(ralc_exist && clac_exist) +} + #[cfg(test)] mod tests { use crate::manifest::{to_json_base64, to_manifest, Manifest}; diff --git a/wren-core-py/tests/test_modeling_core.py b/wren-core-py/tests/test_modeling_core.py index fe732489c..46cc95100 100644 --- a/wren-core-py/tests/test_modeling_core.py +++ b/wren-core-py/tests/test_modeling_core.py @@ -8,6 +8,7 @@ RowLevelAccessControl, SessionContext, SessionProperty, + is_backward_compatible, to_json_base64, to_manifest, validate_rlac_rule, @@ -332,7 +333,21 @@ def test_validate_rlac_rule(): required=False, ) ], - condition="c_name = @session_user", + condition="customer.c_name = @session_user", + ) + + validate_rlac_rule(rlac, model) + + # Test case insensitivity + rlac = RowLevelAccessControl( + name="test", + required_properties=[ + SessionProperty( + name="session_usEr", + required=False, + ) + ], + condition="c_name = @SEssion_user", ) validate_rlac_rule(rlac, model) @@ -371,6 +386,164 @@ def test_clac(): session_context.transform_sql(sql) except Exception as e: assert ( - str(e) - == "Permission Denied: No permission to access \"customer\".\"c_name\"" + str(e) == 'Permission Denied: No permission to access "customer"."c_name"' + ) + + +def test_opt_clac(): + headers = {} + properties_hashable = frozenset(headers.items()) if headers else None + + manifest = { + "catalog": "my_catalog", + "schema": "my_schema", + "dataSource": "bigquery", + "models": [ + { + "name": "orders", + "tableReference": { + "schema": "main", + "table": "orders", + }, + "columns": [ + { + "name": "o_orderkey", + "type": "integer", + "columnLevelAccessControl": { + "name": "o_orderkey_access", + "requiredProperties": [ + { + "name": "session_level", + "required": False, + "defaultExpr": "2", + } + ], + "operator": "GREATER_THAN", + "threshold": "3", + }, + }, + {"name": "o_custkey", "type": "integer"}, + {"name": "o_orderdate", "type": "date"}, + ], + "primaryKey": "o_orderkey", + }, + ], + } + + manifest_str = base64.b64encode(json.dumps(manifest).encode("utf-8")).decode( + "utf-8" + ) + + session_context = SessionContext(manifest_str, None, properties_hashable) + sql = "SELECT o_orderkey FROM my_catalog.my_schema.orders" + try: + session_context.transform_sql(sql) + except Exception as e: + assert ( + str(e) == 'Permission Denied: No permission to access "orders"."o_orderkey"' ) + + +def test_backward_compatible_check(): + manifest_with_clac = { + "catalog": "my_catalog", + "schema": "my_schema", + "dataSource": "bigquery", + "models": [ + { + "name": "orders", + "tableReference": { + "schema": "main", + "table": "orders", + }, + "columns": [ + { + "name": "o_orderkey", + "type": "integer", + "columnLevelAccessControl": { + "name": "o_orderkey_access", + "requiredProperties": [ + { + "name": "session_level", + "required": False, + "defaultExpr": "2", + } + ], + "operator": "GREATER_THAN", + "threshold": "3", + }, + }, + {"name": "o_custkey", "type": "integer"}, + {"name": "o_orderdate", "type": "date"}, + ], + "primaryKey": "o_orderkey", + }, + ], + } + + manifest_with_clac_str = base64.b64encode( + json.dumps(manifest_with_clac).encode("utf-8") + ).decode("utf-8") + assert not is_backward_compatible(manifest_with_clac_str) + + manifest_with_rlac = { + "catalog": "my_catalog", + "schema": "my_schema", + "dataSource": "bigquery", + "models": [ + { + "name": "orders", + "tableReference": { + "schema": "main", + "table": "orders", + }, + "columns": [ + {"name": "o_orderkey", "type": "integer"}, + {"name": "o_custkey", "type": "integer"}, + {"name": "o_orderdate", "type": "date"}, + ], + "primaryKey": "o_orderkey", + "rowLevelAccessControls": [ + { + "name": "customer_access", + "requiredProperties": [ + { + "name": "session_user", + "required": False, + } + ], + "condition": "o_custkey = @session_user", + }, + ], + }, + ], + } + manifest_with_rlac_str = base64.b64encode( + json.dumps(manifest_with_rlac).encode("utf-8") + ).decode("utf-8") + assert not is_backward_compatible(manifest_with_rlac_str) + + manifest_backward = { + "catalog": "my_catalog", + "schema": "my_schema", + "dataSource": "bigquery", + "models": [ + { + "name": "orders", + "tableReference": { + "schema": "main", + "table": "orders", + }, + "columns": [ + {"name": "o_orderkey", "type": "integer"}, + {"name": "o_custkey", "type": "integer"}, + {"name": "o_orderdate", "type": "date"}, + ], + "primaryKey": "o_orderkey", + }, + ], + } + manifest_backward_str = base64.b64encode( + json.dumps(manifest_backward).encode("utf-8") + ).decode("utf-8") + assert is_backward_compatible(manifest_backward_str) diff --git a/wren-core/core/src/logical_plan/analyze/access_control.rs b/wren-core/core/src/logical_plan/analyze/access_control.rs index ac531c02b..2bc21afb4 100644 --- a/wren-core/core/src/logical_plan/analyze/access_control.rs +++ b/wren-core/core/src/logical_plan/analyze/access_control.rs @@ -56,7 +56,10 @@ pub fn collect_condition( spans: Spans::new(), })); } else { - let session_property = value.trim_start_matches("@").to_string(); + let session_property = value + .trim_start_matches("@") + .to_string() + .to_ascii_lowercase(); if !session_properties.contains(&session_property) { session_properties.insert(session_property); }