Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
09ae1ea
add row level access control
goldmedal Apr 18, 2025
5c7e680
rename session variable to session properties
goldmedal Apr 18, 2025
1c380c8
use insta for testing
goldmedal Apr 22, 2025
d13e1b8
add seesion properties
goldmedal Apr 22, 2025
412ef6d
implement row level access control
goldmedal Apr 22, 2025
2ca197a
refactor test to use insta
goldmedal Apr 22, 2025
fc5b970
add test and fix filter position
goldmedal Apr 22, 2025
c3e2188
add optional property test
goldmedal Apr 22, 2025
f3517f0
support for calcaulted field as the condition
goldmedal Apr 23, 2025
5378278
fmt
goldmedal Apr 23, 2025
5978cac
expose to python binding
goldmedal Apr 24, 2025
c86c1bf
ensure the name is case insensitive
goldmedal Apr 24, 2025
13c74ab
expose to the wren engine api
goldmedal Apr 24, 2025
9685b28
enhance example
goldmedal Apr 24, 2025
4265c57
remove unused file
goldmedal Apr 24, 2025
8a4035d
remove unused file
goldmedal Apr 24, 2025
a1fb5e9
rename example
goldmedal Apr 24, 2025
a9747e2
fix fmt
goldmedal Apr 24, 2025
fae8664
fix file name
goldmedal Apr 24, 2025
8e9898a
move insta to dev
goldmedal Apr 25, 2025
cb563ef
prevent the invalid property value
goldmedal Apr 25, 2025
6ca5719
use reference for rlac
goldmedal Apr 25, 2025
6f86e14
refactor and update lock
goldmedal Apr 25, 2025
22bc479
fix example default
goldmedal Apr 25, 2025
6e627ae
add todo comment
goldmedal Apr 25, 2025
b587e3f
fix typo and check for default is empty
goldmedal Apr 25, 2025
59dcc79
fix compile
goldmedal Apr 30, 2025
e9bc623
use hashset to avoid duplicate result
goldmedal Apr 30, 2025
15e307b
fix header check
goldmedal May 5, 2025
1c2b05d
fix missing header
goldmedal May 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ibis-server/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from app.model.data_source import DataSource

X_WREN_FALLBACK_DISABLE = "x-wren-fallback_disable"
X_WREN_VARIABLE_PREFIX = "x-wren-variable-"


# Rebuild model to validate the dto is correct via validation of the pydantic
Expand Down Expand Up @@ -35,3 +36,11 @@ def _filter_headers(header_string: str) -> bool:
elif header_string == "sentry-trace":
return True
return False


def exist_wren_variables_header(
headers: Headers,
) -> bool:
if headers is None:
return False
return any(key.startswith(X_WREN_VARIABLE_PREFIX) for key in headers.keys())
31 changes: 27 additions & 4 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from opentelemetry import trace

from app.config import get_config
from app.dependencies import X_WREN_VARIABLE_PREFIX
from app.mdl.core import (
get_manifest_extractor,
get_session_context,
Expand All @@ -32,10 +33,12 @@ def __init__(
data_source: DataSource = None,
java_engine_connector: JavaEngineConnector = None,
experiment=False,
properties: dict | None = None,
):
self.manifest_str = manifest_str
self.data_source = data_source
self.experiment = experiment
self.properties = properties
if experiment:
function_path = get_config().get_remote_function_list_path(data_source)
self._rewriter = EmbeddedEngineRewriter(function_path)
Expand All @@ -54,7 +57,7 @@ async def rewrite(self, sql: str) -> str:
self._extract_manifest(self.manifest_str, sql) or self.manifest_str
)
logger.debug("Extracted manifest: {}", manifest_str)
planned_sql = await self._rewriter.rewrite(manifest_str, sql)
planned_sql = await self._rewriter.rewrite(manifest_str, sql, self.properties)
logger.debug("Planned SQL: {}", planned_sql)
dialect_sql = self._transpile(planned_sql) if self.data_source else planned_sql
logger.debug("Dialect SQL: {}", dialect_sql)
Expand Down Expand Up @@ -93,7 +96,9 @@ def __init__(self, java_engine_connector: JavaEngineConnector):
self.java_engine_connector = java_engine_connector

@tracer.start_as_current_span("external_rewrite", kind=trace.SpanKind.CLIENT)
async def rewrite(self, manifest_str: str, sql: str) -> str:
async def rewrite(
self, manifest_str: str, sql: str, properties: dict | None = None
) -> str:
try:
return await self.java_engine_connector.dry_plan(manifest_str, sql)
except httpx.ConnectError as e:
Expand All @@ -113,13 +118,31 @@ def __init__(self, function_path: str):
self.function_path = function_path

@tracer.start_as_current_span("embedded_rewrite", kind=trace.SpanKind.INTERNAL)
async def rewrite(self, manifest_str: str, sql: str) -> str:
async def rewrite(
self, manifest_str: str, sql: str, properties: dict | None = None
) -> str:
try:
session_context = get_session_context(manifest_str, self.function_path)
return await to_thread.run_sync(session_context.transform_sql, sql)
return await to_thread.run_sync(
session_context.transform_sql,
sql,
self.get_session_properties(properties),
)
except Exception as e:
raise RewriteError(str(e))

def get_session_properties(self, properties: dict) -> dict | None:
if properties is None:
return None
# filter the properties which name starts with "x-wren-variable-"
# and remove the prefix "x-wren-variable-"

return {
k.replace(X_WREN_VARIABLE_PREFIX, ""): v
for k, v in properties.items()
if k.startswith(X_WREN_VARIABLE_PREFIX)
}

@staticmethod
def handle_extract_exception(e: Exception):
raise RewriteError(str(e))
Expand Down
49 changes: 39 additions & 10 deletions ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from starlette.datastructures import Headers

from app.config import get_config
from app.dependencies import X_WREN_FALLBACK_DISABLE, get_wren_headers, verify_query_dto
from app.dependencies import (
X_WREN_FALLBACK_DISABLE,
exist_wren_variables_header,
get_wren_headers,
verify_query_dto,
)
from app.mdl.core import get_session_context
from app.mdl.java_engine import JavaEngineConnector
from app.mdl.rewriter import Rewriter
Expand Down Expand Up @@ -72,7 +77,10 @@ async def query(
if dry_run:
sql = pushdown_limit(dto.sql, limit)
rewritten_sql = await Rewriter(
dto.manifest_str, data_source=data_source, experiment=True
dto.manifest_str,
data_source=data_source,
experiment=True,
properties=dict(headers),
).rewrite(sql)
connector = Connector(data_source, dto.connection_info)
connector.dry_run(rewritten_sql)
Expand Down Expand Up @@ -102,7 +110,10 @@ async def query(
else:
sql = pushdown_limit(dto.sql, limit)
rewritten_sql = await Rewriter(
dto.manifest_str, data_source=data_source, experiment=True
dto.manifest_str,
data_source=data_source,
experiment=True,
properties=dict(headers),
).rewrite(sql)
connector = Connector(data_source, dto.connection_info)
result = connector.query(rewritten_sql, limit=limit)
Expand Down Expand Up @@ -145,7 +156,9 @@ async def query(
headers.get(X_WREN_FALLBACK_DISABLE)
and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false"))
)
if is_fallback_disable:
# because the v2 API doesn't support row-level access control,
# we don't fallback to v2 if the header include row-level access control properties.
if is_fallback_disable or exist_wren_variables_header(headers):
raise e

logger.warning(
Expand Down Expand Up @@ -176,13 +189,17 @@ async def dry_plan(
name="dry_plan", kind=trace.SpanKind.SERVER, context=build_context(headers)
) as span:
try:
return await Rewriter(dto.manifest_str, experiment=True).rewrite(dto.sql)
return await Rewriter(
dto.manifest_str, experiment=True, properties=dict(headers)
).rewrite(dto.sql)
except Exception as e:
is_fallback_disable = bool(
headers.get(X_WREN_FALLBACK_DISABLE)
and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false"))
)
if is_fallback_disable:
# because the v2 API doesn't support row-level access control,
# we don't fallback to v2 if the header include row-level access control properties.
if is_fallback_disable or exist_wren_variables_header(headers):
raise e

logger.warning(
Expand Down Expand Up @@ -213,14 +230,19 @@ async def dry_plan_for_data_source(
) as span:
try:
return await Rewriter(
dto.manifest_str, data_source=data_source, experiment=True
dto.manifest_str,
data_source=data_source,
experiment=True,
properties=dict(headers),
).rewrite(dto.sql)
except Exception as e:
is_fallback_disable = bool(
headers.get(X_WREN_FALLBACK_DISABLE)
and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false"))
)
if is_fallback_disable:
# because the v2 API doesn't support row-level access control,
# we don't fallback to v2 if the header include row-level access control properties.
if is_fallback_disable or exist_wren_variables_header(headers):
raise e

logger.warning(
Expand Down Expand Up @@ -254,7 +276,12 @@ async def validate(
try:
validator = Validator(
Connector(data_source, dto.connection_info),
Rewriter(dto.manifest_str, data_source=data_source, experiment=True),
Rewriter(
dto.manifest_str,
data_source=data_source,
experiment=True,
properties=dict(headers),
),
)
await validator.validate(rule_name, dto.parameters, dto.manifest_str)
return Response(status_code=204)
Expand All @@ -263,7 +290,9 @@ async def validate(
headers.get(X_WREN_FALLBACK_DISABLE)
and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false"))
)
if is_fallback_disable:
# because the v2 API doesn't support row-level access control,
# we don't fallback to v2 if the header include row-level access control properties.
if is_fallback_disable or exist_wren_variables_header(headers):
raise e

logger.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import orjson
import pytest

from app.dependencies import X_WREN_FALLBACK_DISABLE
from app.dependencies import X_WREN_FALLBACK_DISABLE, X_WREN_VARIABLE_PREFIX
from tests.routers.v3.connector.postgres.conftest import base_url

# It's not a valid manifest for v3. We expect the query to fail and fallback to v2.
Expand Down Expand Up @@ -383,3 +383,16 @@ async def test_validate(client, manifest_str, connection_info):
},
)
assert response.status_code == 422


async def test_query_rlac(client, manifest_str, connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT orderkey FROM orders LIMIT 1",
},
headers={X_WREN_VARIABLE_PREFIX + "session_user": "1"},
)
assert response.status_code == 422
47 changes: 47 additions & 0 deletions ibis-server/tests/routers/v3/connector/postgres/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import orjson
import pytest

from app.dependencies import X_WREN_VARIABLE_PREFIX
from tests.routers.v3.connector.postgres.conftest import base_url

manifest = {
Expand Down Expand Up @@ -78,6 +79,18 @@
"expression": "sum(orders.o_totalprice_double)",
},
],
"rowLevelAccessControls": [
{
"name": "customer_access",
"requiredProperties": [
{
"name": "session_user",
"required": False,
}
],
"condition": "c_name = @session_user",
},
],
"primaryKey": "c_custkey",
},
],
Expand Down Expand Up @@ -478,3 +491,37 @@ async def test_limit_pushdown(client, manifest_str, connection_info):
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 10


async def test_rlac_query(client, manifest_str, connection_info):
response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT c_name FROM customer",
},
headers={
X_WREN_VARIABLE_PREFIX + "session_user": "'Customer#000000001'",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1
assert result["data"][0][0] == "Customer#000000001"

response = await client.post(
url=f"{base_url}/query",
json={
"connectionInfo": connection_info,
"manifestStr": manifest_str,
"sql": "SELECT c_name FROM customer",
},
headers={
X_WREN_VARIABLE_PREFIX + "SESSION_USER": "'Customer#000000001'",
},
)
assert response.status_code == 200
result = response.json()
assert len(result["data"]) == 1
assert result["data"][0][0] == "Customer#000000001"
Loading