Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ibis-server/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import wren_core
from fastapi import Request
from starlette.datastructures import Headers

Expand Down Expand Up @@ -48,9 +49,8 @@ def _filter_headers(header_string: str) -> bool:
return False


def exist_wren_variables_header(
headers: Headers,
) -> bool:
if headers is None:
def is_backward_compatible(manifest_str: str) -> bool:
try:
return wren_core.is_backward_compatible(manifest_str)
except Exception:
return False
return any(key.startswith(X_WREN_VARIABLE_PREFIX) for key in headers.keys())
12 changes: 6 additions & 6 deletions ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
X_CACHE_OVERRIDE,
X_CACHE_OVERRIDE_AT,
X_WREN_FALLBACK_DISABLE,
exist_wren_variables_header,
get_wren_headers,
is_backward_compatible,
verify_query_dto,
)
from app.mdl.core import get_session_context
Expand Down Expand Up @@ -191,7 +191,7 @@ async def query(
if (
java_engine_connector.client is None
or is_fallback_disable
or exist_wren_variables_header(headers)
or not is_backward_compatible(dto.manifest_str)
):
raise e

Expand Down Expand Up @@ -237,7 +237,7 @@ async def dry_plan(
if (
java_engine_connector.client is None
or is_fallback_disable
or exist_wren_variables_header(headers)
or not is_backward_compatible(dto.manifest_str)
):
raise e

Expand Down Expand Up @@ -285,7 +285,7 @@ async def dry_plan_for_data_source(
if (
java_engine_connector.client is None
or is_fallback_disable
or exist_wren_variables_header(headers)
or not is_backward_compatible(dto.manifest_str)
):
raise e

Expand Down Expand Up @@ -351,7 +351,7 @@ async def validate(
if (
java_engine_connector.client is None
or is_fallback_disable
or exist_wren_variables_header(headers)
or not is_backward_compatible(dto.manifest_str)
):
raise e

Expand Down Expand Up @@ -446,7 +446,7 @@ async def model_substitute(
if (
java_engine_connector.client is None
or is_fallback_disable
or exist_wren_variables_header(headers)
or not is_backward_compatible(dto.manifest_str)
):
raise e

Expand Down
2 changes: 1 addition & 1 deletion ibis-server/resources/function_list/postgres.csv
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ scalar,sign,numeric,,numeric,"Sign of number"
scalar,to_json,json,,boolean,"Convert to JSON"
scalar,to_number,numeric,,"text,text","Convert string to number"
scalar,unistr,varchar,,text,"Postgres: Evaluate escaped Unicode characters in the argument"
scalar,pg_sleep,,,"numeric","Sleep for specified time in seconds"
scalar,pg_sleep,,,"bigint","Sleep for specified time in seconds"
12 changes: 12 additions & 0 deletions ibis-server/tests/routers/v3/connector/postgres/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sqlalchemy
from testcontainers.postgres import PostgresContainer

from app.config import get_config
from tests.conftest import file_path

pytestmark = pytest.mark.postgres
Expand All @@ -19,6 +20,9 @@ def pytest_collection_modifyitems(items):
item.add_marker(pytestmark)


function_list_path = file_path("../resources/function_list")


@pytest.fixture(scope="module")
def postgres(request) -> PostgresContainer:
pg = PostgresContainer("postgres:16-alpine").start()
Expand Down Expand Up @@ -56,3 +60,11 @@ def connection_info(postgres: PostgresContainer) -> dict[str, str]:
def connection_url(connection_info: dict[str, str]):
info = connection_info
return f"postgres://{info['user']}:{info['password']}@{info['host']}:{info['port']}/{info['database']}"


@pytest.fixture(autouse=True)
def set_remote_function_list_path():
config = get_config()
config.set_remote_function_list_path(function_list_path)
yield
config.set_remote_function_list_path(None)
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,92 @@ async def test_query_rlac(client, manifest_str, connection_info):
},
headers={X_WREN_VARIABLE_PREFIX + "session_user": "1"},
)
assert response.status_code == 200

manifest_rlac = {
"catalog": "wren",
"schema": "public",
"models": [
{
"name": "orders",
"tableReference": {"schema": "public", "table": "orders"},
"columns": [
{
"name": "orderkey",
"type": "varchar",
"expression": "cast(o_orderkey as varchar)",
}
],
"rowLevelAccessControls": [
{
"name": "rule",
"requiredProperties": [
{
"name": "session_user",
"required": False,
}
],
"condition": "orderkey = @session_user",
},
],
}
],
}

manifest_rlac_str = base64.b64encode(orjson.dumps(manifest_rlac)).decode("utf-8")

response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_rlac_str,
"sql": "SELECT orderkey FROM orders LIMIT 1",
},
)
assert response.status_code == 422

manifest_clac = {
"catalog": "wren",
"schema": "public",
"models": [
{
"name": "orders",
"tableReference": {"schema": "public", "table": "orders"},
"columns": [
{
"name": "orderkey",
"type": "varchar",
"expression": "cast(o_orderkey as varchar)",
},
{
"name": "custkey",
"type": "varchar",
"columnLevelAccessControl": {
"name": "o_custkey_access",
"requiredProperties": [
{
"name": "session_level",
"required": False,
"defaultExpr": "2",
}
],
"operator": "GREATER_THAN",
"threshold": "3",
},
"expression": "cast(o_custkey as varchar)",
},
],
}
],
}

manifest_clac_str = base64.b64encode(orjson.dumps(manifest_clac)).decode("utf-8")
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_clac_str,
"sql": "SELECT orderkey FROM orders LIMIT 1",
},
)
assert response.status_code == 422
14 changes: 2 additions & 12 deletions ibis-server/tests/routers/v3/connector/postgres/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pytest

from app.config import get_config
from tests.conftest import DATAFUSION_FUNCTION_COUNT, file_path
from tests.routers.v3.connector.postgres.conftest import base_url
from tests.conftest import DATAFUSION_FUNCTION_COUNT
from tests.routers.v3.connector.postgres.conftest import base_url, function_list_path

pytestmark = pytest.mark.functions

Expand All @@ -26,22 +26,12 @@
],
}

function_list_path = file_path("../resources/function_list")


@pytest.fixture(scope="module")
def manifest_str():
return base64.b64encode(orjson.dumps(manifest)).decode("utf-8")


@pytest.fixture(autouse=True)
def set_remote_function_list_path():
config = get_config()
config.set_remote_function_list_path(function_list_path)
yield
config.set_remote_function_list_path(None)


async def test_function_list(client):
config = get_config()

Expand Down
4 changes: 2 additions & 2 deletions ibis-server/tests/routers/v3/connector/postgres/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,11 +694,11 @@ async def test_connection_timeout(
},
headers={X_WREN_DB_STATEMENT_TIMEOUT: "1"}, # Set timeout to 1 second
)
assert response.status_code == 504
assert (
"Query was cancelled: canceling statement due to statement timeout"
in response.text
)
assert response.status_code == 504

# test connection_url way can also timeout
response = await client.post(
Expand All @@ -710,11 +710,11 @@ async def test_connection_timeout(
},
headers={X_WREN_DB_STATEMENT_TIMEOUT: "1"}, # Set timeout to 1 second
)
assert response.status_code == 504
assert (
"Query was cancelled: canceling statement due to statement timeout"
in response.text
)
assert response.status_code == 504


async def test_format_floating(client, manifest_str, connection_info):
Expand Down
18 changes: 11 additions & 7 deletions wren-core-py/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,19 +253,23 @@ impl PySessionContext {
let _ = visit_statements_mut(&mut statements, |stmt| {
if let Statement::Query(q) = stmt {
if let Some(limit) = &q.limit {
if let Expr::Value(ValueWithSpan { value: Value::Number(n, is), .. }) = limit {
if let Expr::Value(ValueWithSpan {
value: Value::Number(n, is),
..
}) = limit
{
if let Ok(curr) = n.parse::<usize>() {
if curr > pushdown {
q.limit = Some(Expr::Value(Value::Number(
pushdown.to_string(),
*is,
).into()));
q.limit = Some(Expr::Value(
Value::Number(pushdown.to_string(), *is).into(),
));
}
}
}
} else {
q.limit =
Some(Expr::Value(Value::Number(pushdown.to_string(), false).into()));
q.limit = Some(Expr::Value(
Value::Number(pushdown.to_string(), false).into(),
));
}
}
ControlFlow::<()>::Continue(())
Expand Down
1 change: 1 addition & 0 deletions wren-core-py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(manifest::to_json_base64, m)?)?;
m.add_function(wrap_pyfunction!(manifest::to_manifest, m)?)?;
m.add_function(wrap_pyfunction!(validation::validate_rlac_rule, m)?)?;
m.add_function(wrap_pyfunction!(manifest::is_backward_compatible, m)?)?;
Ok(())
}
18 changes: 18 additions & 0 deletions wren-core-py/src/manifest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ pub fn to_manifest(mdl_base64: &str) -> Result<Manifest, CoreError> {
Ok(manifest)
}

/// Check if the MDL can be used by the v2 wren core. If there are any access controls rules,
/// the MDL should be used by the v3 wren core only.
#[pyfunction]
pub fn is_backward_compatible(mdl_base64: &str) -> Result<bool, CoreError> {
let manifest = to_manifest(mdl_base64)?;
let ralc_exist = manifest
.models
.iter()
.all(|model| model.row_level_access_controls().is_empty());
let clac_exist = manifest.models.iter().all(|model| {
model
.columns
.iter()
.all(|column| column.column_level_access_control().is_none())
});
Ok(ralc_exist && clac_exist)
}

#[cfg(test)]
mod tests {
use crate::manifest::{to_json_base64, to_manifest, Manifest};
Expand Down
Loading
Loading