From 9543b1bdcad5ae42b860f942b3169e1bc130d6f5 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Wed, 16 Apr 2025 13:54:59 +0800 Subject: [PATCH 1/7] pass all existing test --- ibis-server/app/mdl/substitute.py | 14 ++++-- ibis-server/app/routers/v2/connector.py | 7 +-- ibis-server/app/routers/v3/connector.py | 9 ++-- .../routers/v2/connector/test_postgres.py | 48 +++++++++++++++++-- .../postgres/test_model_substitute.py | 9 ++-- 5 files changed, 70 insertions(+), 17 deletions(-) diff --git a/ibis-server/app/mdl/substitute.py b/ibis-server/app/mdl/substitute.py index 11d436058..26a869426 100644 --- a/ibis-server/app/mdl/substitute.py +++ b/ibis-server/app/mdl/substitute.py @@ -10,10 +10,11 @@ class ModelSubstitute: - def __init__(self, data_source: DataSource, manifest_str: str): + def __init__(self, data_source: DataSource, manifest_str: str, headers=None): self.data_source = data_source self.manifest = base64_to_dict(manifest_str) self.model_dict = self._build_model_dict(self.manifest["models"]) + self.headers = dict(headers) if headers else None @tracer.start_as_current_span("substitute", kind=trace.SpanKind.INTERNAL) def substitute(self, sql: str, write: str | None = None) -> str: @@ -44,8 +45,15 @@ def key(model): return {key(model): model for model in models if "tableReference" in model} def _find_model(self, source: exp.Table) -> dict | None: - catalog = source.catalog or "" - schema = source.db or "" + if source.catalog and source.db: + catalog = source.catalog + schema = source.db + elif self.headers is not None: + catalog = self.headers.get("x-user-catalog") + schema = self.headers.get("x-user-schema") + else: + catalog = "" + schema = "" table = source.name return self.model_dict.get(f"{catalog}.{schema}.{table}", None) diff --git a/ibis-server/app/routers/v2/connector.py b/ibis-server/app/routers/v2/connector.py index edc234801..0a0237027 100644 --- a/ibis-server/app/routers/v2/connector.py +++ b/ibis-server/app/routers/v2/connector.py @@ -271,6 +271,7 @@ async def dry_plan_for_data_source( async def model_substitute( data_source: DataSource, dto: TranspileDTO, + request: Request, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), headers: Annotated[str | None, Header()] = None, ) -> str: @@ -278,9 +279,9 @@ async def model_substitute( with tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) ): - sql = ModelSubstitute(data_source, dto.manifest_str).substitute( - dto.sql, write="trino" - ) + sql = ModelSubstitute( + data_source, dto.manifest_str, dict(request.headers) + ).substitute(dto.sql, write="trino") Connector(data_source, dto.connection_info).dry_run( await Rewriter( dto.manifest_str, diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index f6fbec4cb..23c270ac7 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -1,6 +1,6 @@ from typing import Annotated -from fastapi import APIRouter, Depends, Header, Query, Response +from fastapi import APIRouter, Depends, Header, Query, Request, Response from fastapi.responses import ORJSONResponse from loguru import logger from opentelemetry import trace @@ -258,6 +258,7 @@ def functions( async def model_substitute( data_source: DataSource, dto: TranspileDTO, + request: Request, headers: Annotated[str | None, Header()] = None, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), ) -> str: @@ -266,7 +267,9 @@ async def model_substitute( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) ): try: - sql = ModelSubstitute(data_source, dto.manifest_str).substitute(dto.sql) + sql = ModelSubstitute( + data_source, dto.manifest_str, dict(request.headers) + ).substitute(dto.sql) Connector(data_source, dto.connection_info).dry_run( await Rewriter( dto.manifest_str, @@ -280,5 +283,5 @@ async def model_substitute( "Failed to execute v3 model-substitute, fallback to v2: {}", str(e) ) return await v2.connector.model_substitute( - data_source, dto, java_engine_connector, headers + data_source, dto, request, java_engine_connector, headers ) diff --git a/ibis-server/tests/routers/v2/connector/test_postgres.py b/ibis-server/tests/routers/v2/connector/test_postgres.py index 014862dce..4f43f5eea 100644 --- a/ibis-server/tests/routers/v2/connector/test_postgres.py +++ b/ibis-server/tests/routers/v2/connector/test_postgres.py @@ -23,6 +23,7 @@ { "name": "Orders", "tableReference": { + "catalog": "test", "schema": "public", "table": "orders", }, @@ -710,12 +711,29 @@ async def test_dry_plan(client, manifest_str): async def test_model_substitute(client, manifest_str, postgres: PostgresContainer): connection_info = _to_connection_info(postgres) + # Test with catalog and schema in SQL response = await client.post( url=f"{base_url}/model-substitute", json={ "connectionInfo": connection_info, "manifestStr": manifest_str, - "sql": 'SELECT * FROM "public"."orders"', + "sql": 'SELECT * FROM "test"."public"."orders"', + }, + ) + assert response.status_code == 200 + assert ( + response.text + == '"SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\""' + ) + + # Test without catalog and schema in SQL but in headers(x-user-xxx) + response = await client.post( + url=f"{base_url}/model-substitute", + headers={"x-user-catalog": "test", "x-user-schema": "public"}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "orders"', }, ) assert response.status_code == 200 @@ -729,14 +747,36 @@ async def test_model_substitute_with_cte( client, manifest_str, postgres: PostgresContainer ): connection_info = _to_connection_info(postgres) + # Test with catalog and schema in SQL + response = await client.post( + url=f"{base_url}/model-substitute", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": """ + WITH orders_cte AS ( + SELECT * FROM "test"."public"."orders" + ) + SELECT * FROM orders_cte; + """, + }, + ) + assert response.status_code == 200 + assert ( + response.text + == '"WITH orders_cte AS (SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\") SELECT * FROM orders_cte"' + ) + + # Test without catalog and schema in SQL but in headers(x-user-xxx) response = await client.post( url=f"{base_url}/model-substitute", + headers={"x-user-catalog": "test", "x-user-schema": "public"}, json={ "connectionInfo": connection_info, "manifestStr": manifest_str, "sql": """ WITH orders_cte AS ( - SELECT * FROM "public"."orders" + SELECT * FROM "orders" ) SELECT * FROM orders_cte; """, @@ -760,7 +800,7 @@ async def test_model_substitute_with_subquery( "manifestStr": manifest_str, "sql": """ SELECT * FROM ( - SELECT * FROM "public"."orders" + SELECT * FROM "test"."public"."orders" ) AS orders_subquery; """, }, @@ -797,7 +837,7 @@ async def test_model_substitute_non_existent_column( json={ "connectionInfo": connection_info, "manifestStr": manifest_str, - "sql": 'SELECT x FROM "public"."orders" LIMIT 1', + "sql": 'SELECT x FROM "test"."public"."orders" LIMIT 1', }, ) assert response.status_code == 422 diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py b/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py index 1e04b0082..c2366b0ea 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py @@ -12,6 +12,7 @@ { "name": "Orders", "tableReference": { + "catalog": "test", "schema": "public", "table": "orders", }, @@ -34,7 +35,7 @@ async def test_model_substitute(client, manifest_str, connection_info): json={ "connectionInfo": connection_info, "manifestStr": manifest_str, - "sql": 'SELECT * FROM "public"."orders"', + "sql": 'SELECT * FROM "test"."public"."orders"', }, ) assert response.status_code == 200 @@ -52,7 +53,7 @@ async def test_model_substitute_with_cte(client, manifest_str, connection_info): "manifestStr": manifest_str, "sql": """ WITH orders_cte AS ( - SELECT * FROM "public"."orders" + SELECT * FROM "test"."public"."orders" ) SELECT * FROM orders_cte; """, @@ -73,7 +74,7 @@ async def test_model_substitute_with_subquery(client, manifest_str, connection_i "manifestStr": manifest_str, "sql": """ SELECT * FROM ( - SELECT * FROM "public"."orders" + SELECT * FROM "test"."public"."orders" ) AS orders_subquery; """, }, @@ -106,7 +107,7 @@ async def test_model_substitute_non_existent_column( json={ "connectionInfo": connection_info, "manifestStr": manifest_str, - "sql": 'SELECT x FROM "public"."orders" LIMIT 1', + "sql": 'SELECT x FROM "test"."public"."orders" LIMIT 1', }, ) assert response.status_code == 422 From 3334f9a0a769edc993ab7e9d1ab5e2b183d89fd0 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Wed, 16 Apr 2025 14:08:56 +0800 Subject: [PATCH 2/7] add default header test cases --- .../routers/v2/connector/test_postgres.py | 49 +++++++++ .../postgres/test_model_substitute.py | 102 ++++++++++++++++++ 2 files changed, 151 insertions(+) diff --git a/ibis-server/tests/routers/v2/connector/test_postgres.py b/ibis-server/tests/routers/v2/connector/test_postgres.py index 4f43f5eea..fc9de183f 100644 --- a/ibis-server/tests/routers/v2/connector/test_postgres.py +++ b/ibis-server/tests/routers/v2/connector/test_postgres.py @@ -793,6 +793,7 @@ async def test_model_substitute_with_subquery( client, manifest_str, postgres: PostgresContainer ): connection_info = _to_connection_info(postgres) + # Test with catalog and schema in SQL response = await client.post( url=f"{base_url}/model-substitute", json={ @@ -811,13 +812,47 @@ async def test_model_substitute_with_subquery( == '"SELECT * FROM (SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\") AS orders_subquery"' ) + # Test without catalog and schema in SQL but in headers(x-user-xxx) + response = await client.post( + url=f"{base_url}/model-substitute", + headers={"x-user-catalog": "test", "x-user-schema": "public"}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": """ + SELECT * FROM ( + SELECT * FROM "orders" + ) AS orders_subquery; + """, + }, + ) + assert response.status_code == 200 + assert ( + response.text + == '"SELECT * FROM (SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\") AS orders_subquery"' + ) + async def test_model_substitute_out_of_scope( client, manifest_str, postgres: PostgresContainer ): connection_info = _to_connection_info(postgres) + # Test with catalog and schema in SQL + response = await client.post( + url=f"{base_url}/model-substitute", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Nation" LIMIT 1', + }, + ) + assert response.status_code == 422 + assert response.text == 'Model not found: "Nation"' + + # Test without catalog and schema in SQL but in headers(x-user-xxx) response = await client.post( url=f"{base_url}/model-substitute", + headers={"x-user-catalog": "test", "x-user-schema": "public"}, json={ "connectionInfo": connection_info, "manifestStr": manifest_str, @@ -832,6 +867,7 @@ async def test_model_substitute_non_existent_column( client, manifest_str, postgres: PostgresContainer ): connection_info = _to_connection_info(postgres) + # Test with catalog and schema in SQL response = await client.post( url=f"{base_url}/model-substitute", json={ @@ -843,6 +879,19 @@ async def test_model_substitute_non_existent_column( assert response.status_code == 422 assert 'column "x" does not exist' in response.text + # Test without catalog and schema in SQL but in headers(x-user-xxx) + response = await client.post( + url=f"{base_url}/model-substitute", + headers={"x-user-catalog": "test", "x-user-schema": "public"}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT x FROM "orders" LIMIT 1', + }, + ) + assert response.status_code == 422 + assert 'column "x" does not exist' in response.text + def _to_connection_info(pg: PostgresContainer): return { diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py b/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py index c2366b0ea..c566f3275 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py @@ -30,6 +30,7 @@ def manifest_str(): async def test_model_substitute(client, manifest_str, connection_info): + # Test with catalog and schema in SQL response = await client.post( url=f"{base_url}/model-substitute", json={ @@ -44,8 +45,28 @@ async def test_model_substitute(client, manifest_str, connection_info): == '"SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\""' ) + # Test without catalog and schema in SQL but in headers(x-user-xxx) + response = await client.post( + url=f"{base_url}/model-substitute", + headers={ + "x-user-catalog": "test", + "x-user-schema": "public", + }, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "orders"', + }, + ) + assert response.status_code == 200 + assert ( + response.text + == '"SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\""' + ) + async def test_model_substitute_with_cte(client, manifest_str, connection_info): + # Test with catalog and schema in SQL response = await client.post( url=f"{base_url}/model-substitute", json={ @@ -65,8 +86,33 @@ async def test_model_substitute_with_cte(client, manifest_str, connection_info): == '"WITH orders_cte AS (SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\") SELECT * FROM orders_cte"' ) + # Test without catalog and schema in SQL but in headers(x-user-xxx) + response = await client.post( + url=f"{base_url}/model-substitute", + headers={ + "x-user-catalog": "test", + "x-user-schema": "public", + }, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": """ + WITH orders_cte AS ( + SELECT * FROM "orders" + ) + SELECT * FROM orders_cte; + """, + }, + ) + assert response.status_code == 200 + assert ( + response.text + == '"WITH orders_cte AS (SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\") SELECT * FROM orders_cte"' + ) + async def test_model_substitute_with_subquery(client, manifest_str, connection_info): + # Test with catalog and schema in SQL response = await client.post( url=f"{base_url}/model-substitute", json={ @@ -85,10 +131,50 @@ async def test_model_substitute_with_subquery(client, manifest_str, connection_i == '"SELECT * FROM (SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\") AS orders_subquery"' ) + # Test without catalog and schema in SQL but in headers(x-user-xxx) + response = await client.post( + url=f"{base_url}/model-substitute", + headers={ + "x-user-catalog": "test", + "x-user-schema": "public", + }, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": """ + SELECT * FROM ( + SELECT * FROM "orders" + ) AS orders_subquery; + """, + }, + ) + assert response.status_code == 200 + assert ( + response.text + == '"SELECT * FROM (SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\") AS orders_subquery"' + ) + async def test_model_substitute_out_of_scope(client, manifest_str, connection_info): + # Test with catalog and schema in SQL + response = await client.post( + url=f"{base_url}/model-substitute", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Nation" LIMIT 1', + }, + ) + assert response.status_code == 422 + assert response.text == 'Model not found: "Nation"' + + # Test without catalog and schema in SQL but in headers(x-user-xxx) response = await client.post( url=f"{base_url}/model-substitute", + headers={ + "x-user-catalog": "test", + "x-user-schema": "public", + }, json={ "connectionInfo": connection_info, "manifestStr": manifest_str, @@ -102,6 +188,7 @@ async def test_model_substitute_out_of_scope(client, manifest_str, connection_in async def test_model_substitute_non_existent_column( client, manifest_str, connection_info ): + # Test with catalog and schema in SQL response = await client.post( url=f"{base_url}/model-substitute", json={ @@ -111,3 +198,18 @@ async def test_model_substitute_non_existent_column( }, ) assert response.status_code == 422 + + # Test without catalog and schema in SQL but in headers(x-user-xxx) + response = await client.post( + url=f"{base_url}/model-substitute", + headers={ + "x-user-catalog": "test", + "x-user-schema": "public", + }, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT x FROM "orders" LIMIT 1', + }, + ) + assert response.status_code == 422 From 1d3319916eeee5b7840898f5c5e93262a468fde2 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Wed, 16 Apr 2025 17:28:01 +0800 Subject: [PATCH 3/7] remove request param --- ibis-server/app/dependencies.py | 7 +++++++ ibis-server/app/routers/v2/connector.py | 12 ++++++------ ibis-server/app/routers/v3/connector.py | 12 ++++++------ 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/ibis-server/app/dependencies.py b/ibis-server/app/dependencies.py index 65c9d8adc..4c9592a19 100644 --- a/ibis-server/app/dependencies.py +++ b/ibis-server/app/dependencies.py @@ -1,3 +1,6 @@ +from fastapi import Request +from starlette.datastructures import Headers + from app.model import QueryDTO from app.model.data_source import DataSource @@ -5,3 +8,7 @@ # Rebuild model to validate the dto is correct via validation of the pydantic def verify_query_dto(data_source: DataSource, dto: QueryDTO): data_source.get_dto_type()(**dto.model_dump(by_alias=True)) + + +def get_wren_headers(request: Request) -> Headers: + return request.headers diff --git a/ibis-server/app/routers/v2/connector.py b/ibis-server/app/routers/v2/connector.py index 0a0237027..d985836be 100644 --- a/ibis-server/app/routers/v2/connector.py +++ b/ibis-server/app/routers/v2/connector.py @@ -4,8 +4,9 @@ from fastapi.responses import ORJSONResponse from loguru import logger from opentelemetry import trace +from starlette.datastructures import Headers -from app.dependencies import verify_query_dto +from app.dependencies import get_wren_headers, verify_query_dto from app.mdl.java_engine import JavaEngineConnector from app.mdl.rewriter import Rewriter from app.mdl.substitute import ModelSubstitute @@ -271,17 +272,16 @@ async def dry_plan_for_data_source( async def model_substitute( data_source: DataSource, dto: TranspileDTO, - request: Request, + headers: Annotated[Headers, Depends(get_wren_headers)], java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), - headers: Annotated[str | None, Header()] = None, ) -> str: span_name = f"v2_model_substitute_{data_source}" with tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) ): - sql = ModelSubstitute( - data_source, dto.manifest_str, dict(request.headers) - ).substitute(dto.sql, write="trino") + sql = ModelSubstitute(data_source, dto.manifest_str, dict(headers)).substitute( + dto.sql, write="trino" + ) Connector(data_source, dto.connection_info).dry_run( await Rewriter( dto.manifest_str, diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 23c270ac7..ded0abb37 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -1,12 +1,13 @@ from typing import Annotated -from fastapi import APIRouter, Depends, Header, Query, Request, Response +from fastapi import APIRouter, Depends, Header, Query, Response from fastapi.responses import ORJSONResponse from loguru import logger from opentelemetry import trace +from starlette.datastructures import Headers from app.config import get_config -from app.dependencies import verify_query_dto +from app.dependencies import get_wren_headers, verify_query_dto from app.mdl.core import get_session_context from app.mdl.java_engine import JavaEngineConnector from app.mdl.rewriter import Rewriter @@ -258,8 +259,7 @@ def functions( async def model_substitute( data_source: DataSource, dto: TranspileDTO, - request: Request, - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)], java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), ) -> str: span_name = f"v3_model-substitute_{data_source}" @@ -268,7 +268,7 @@ async def model_substitute( ): try: sql = ModelSubstitute( - data_source, dto.manifest_str, dict(request.headers) + data_source, dto.manifest_str, dict(headers) ).substitute(dto.sql) Connector(data_source, dto.connection_info).dry_run( await Rewriter( @@ -283,5 +283,5 @@ async def model_substitute( "Failed to execute v3 model-substitute, fallback to v2: {}", str(e) ) return await v2.connector.model_substitute( - data_source, dto, request, java_engine_connector, headers + data_source, dto, headers, java_engine_connector ) From dc16e7c5968263b33ea472e14b19050b4cb704cf Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Wed, 16 Apr 2025 17:48:43 +0800 Subject: [PATCH 4/7] add partial header test --- ibis-server/app/mdl/substitute.py | 7 ++++ .../routers/v2/connector/test_postgres.py | 36 ++++++++++++++++ .../postgres/test_model_substitute.py | 42 +++++++++++++++++++ 3 files changed, 85 insertions(+) diff --git a/ibis-server/app/mdl/substitute.py b/ibis-server/app/mdl/substitute.py index 26a869426..276c1354f 100644 --- a/ibis-server/app/mdl/substitute.py +++ b/ibis-server/app/mdl/substitute.py @@ -1,3 +1,4 @@ +from loguru import logger from opentelemetry import trace from sqlglot import exp, parse_one from sqlglot.optimizer.scope import build_scope @@ -40,6 +41,12 @@ def substitute(self, sql: str, write: str | None = None) -> str: def _build_model_dict(models) -> dict: def key(model): table_ref = model["tableReference"] + + if not table_ref.get("catalog") and not table_ref.get("schema"): + logger.debug( + "Try to substitute a tableReference has empty catalog and empty schema" + ) + return f"{table_ref.get('catalog', '')}.{table_ref.get('schema', '')}.{table_ref.get('table', '')}" return {key(model): model for model in models if "tableReference" in model} diff --git a/ibis-server/tests/routers/v2/connector/test_postgres.py b/ibis-server/tests/routers/v2/connector/test_postgres.py index fc9de183f..28a1465db 100644 --- a/ibis-server/tests/routers/v2/connector/test_postgres.py +++ b/ibis-server/tests/routers/v2/connector/test_postgres.py @@ -742,6 +742,42 @@ async def test_model_substitute(client, manifest_str, postgres: PostgresContaine == '"SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\""' ) + # Test only have x-user-catalog + response = await client.post( + url=f"{base_url}/model-substitute", + headers={"x-user-catalog": "test"}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "orders"', + }, + ) + assert response.status_code == 422 + + # Test only have x-user-catalog but have schema in SQL + response = await client.post( + url=f"{base_url}/model-substitute", + headers={"x-user-catalog": "test"}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "public"."orders"', + }, + ) + assert response.status_code == 422 + + # Test only have x-user-schema + response = await client.post( + url=f"{base_url}/model-substitute", + headers={"x-user-schema": "public"}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "orders"', + }, + ) + assert response.status_code == 422 + async def test_model_substitute_with_cte( client, manifest_str, postgres: PostgresContainer diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py b/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py index c566f3275..bc7801bdd 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py @@ -64,6 +64,48 @@ async def test_model_substitute(client, manifest_str, connection_info): == '"SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\""' ) + # Test only have x-user-catalog + response = await client.post( + url=f"{base_url}/model-substitute", + headers={ + "x-user-catalog": "test", + }, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "orders"', + }, + ) + assert response.status_code == 422 + + # Test only have x-user-catalog but have schema in SQL + response = await client.post( + url=f"{base_url}/model-substitute", + headers={ + "x-user-catalog": "test", + }, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "public"."orders"', + }, + ) + assert response.status_code == 422 + + # Test only have x-user-schema + response = await client.post( + url=f"{base_url}/model-substitute", + headers={ + "x-user-schema": "public", + }, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "orders"', + }, + ) + assert response.status_code == 422 + async def test_model_substitute_with_cte(client, manifest_str, connection_info): # Test with catalog and schema in SQL From 9d574f76aa944b17045514a6cf67169504b3040e Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Wed, 16 Apr 2025 17:54:28 +0800 Subject: [PATCH 5/7] replace fastapi header with starlette header --- ibis-server/app/routers/v2/connector.py | 16 ++++++++-------- ibis-server/app/routers/v3/connector.py | 12 ++++++------ 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/ibis-server/app/routers/v2/connector.py b/ibis-server/app/routers/v2/connector.py index d985836be..7122fd6a3 100644 --- a/ibis-server/app/routers/v2/connector.py +++ b/ibis-server/app/routers/v2/connector.py @@ -1,6 +1,6 @@ from typing import Annotated -from fastapi import APIRouter, Depends, Header, Query, Request, Response +from fastapi import APIRouter, Depends, Query, Request, Response from fastapi.responses import ORJSONResponse from loguru import logger from opentelemetry import trace @@ -58,7 +58,7 @@ async def query( limit: int | None = Query(None, description="limit the number of rows returned"), java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager), - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, ) -> Response: span_name = f"v2_query_{data_source}" if dry_run: @@ -161,7 +161,7 @@ async def validate( rule_name: str, dto: ValidateDTO, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, ) -> Response: span_name = f"v2_validate_{data_source}" with tracer.start_as_current_span( @@ -188,7 +188,7 @@ async def validate( def get_table_list( data_source: DataSource, dto: MetadataDTO, - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, ) -> list[Table]: span_name = f"v2_metadata_tables_{data_source}" with tracer.start_as_current_span( @@ -208,7 +208,7 @@ def get_table_list( def get_constraints( data_source: DataSource, dto: MetadataDTO, - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, ) -> list[Constraint]: span_name = f"v2_metadata_constraints_{data_source}" with tracer.start_as_current_span( @@ -232,7 +232,7 @@ def get_db_version(data_source: DataSource, dto: MetadataDTO) -> str: async def dry_plan( dto: DryPlanDTO, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, ) -> str: with tracer.start_as_current_span( name="dry_plan", kind=trace.SpanKind.SERVER, context=build_context(headers) @@ -251,7 +251,7 @@ async def dry_plan_for_data_source( data_source: DataSource, dto: DryPlanDTO, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, ) -> str: span_name = f"v2_dry_plan_{data_source}" with tracer.start_as_current_span( @@ -272,7 +272,7 @@ async def dry_plan_for_data_source( async def model_substitute( data_source: DataSource, dto: TranspileDTO, - headers: Annotated[Headers, Depends(get_wren_headers)], + headers: Annotated[Headers, Depends(get_wren_headers)] = None, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), ) -> str: span_name = f"v2_model_substitute_{data_source}" diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index ded0abb37..8608bda88 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -1,6 +1,6 @@ from typing import Annotated -from fastapi import APIRouter, Depends, Header, Query, Response +from fastapi import APIRouter, Depends, Query, Response from fastapi.responses import ORJSONResponse from loguru import logger from opentelemetry import trace @@ -52,7 +52,7 @@ async def query( bool, Query(alias="overrideCache", description="ovrride the exist cache") ] = False, limit: int | None = Query(None, description="limit the number of rows returned"), - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager), ) -> Response: @@ -156,7 +156,7 @@ async def query( @router.post("/dry-plan", description="get the planned WrenSQL") async def dry_plan( dto: DryPlanDTO, - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), ) -> str: with tracer.start_as_current_span( @@ -180,7 +180,7 @@ async def dry_plan( async def dry_plan_for_data_source( data_source: DataSource, dto: DryPlanDTO, - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), ) -> str: span_name = f"v3_dry_plan_{data_source}" @@ -209,7 +209,7 @@ async def validate( data_source: DataSource, rule_name: str, dto: ValidateDTO, - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), ) -> Response: span_name = f"v3_validate_{data_source}" @@ -240,7 +240,7 @@ async def validate( ) def functions( data_source: DataSource, - headers: Annotated[str | None, Header()] = None, + headers: Annotated[Headers, Depends(get_wren_headers)] = None, ) -> Response: span_name = f"v3_functions_{data_source}" with tracer.start_as_current_span( From 7fcd70fe263290d2994d82a55dc34e307a1fce4a Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Thu, 17 Apr 2025 00:46:30 +0800 Subject: [PATCH 6/7] allow only defined catalog header --- ibis-server/app/mdl/substitute.py | 31 +++++++++++-------- ibis-server/app/routers/v2/connector.py | 2 +- ibis-server/app/routers/v3/connector.py | 6 ++-- .../routers/v2/connector/test_postgres.py | 6 +++- .../postgres/test_model_substitute.py | 6 +++- 5 files changed, 32 insertions(+), 19 deletions(-) diff --git a/ibis-server/app/mdl/substitute.py b/ibis-server/app/mdl/substitute.py index 276c1354f..30b9ac78f 100644 --- a/ibis-server/app/mdl/substitute.py +++ b/ibis-server/app/mdl/substitute.py @@ -1,4 +1,3 @@ -from loguru import logger from opentelemetry import trace from sqlglot import exp, parse_one from sqlglot.optimizer.scope import build_scope @@ -42,25 +41,31 @@ def _build_model_dict(models) -> dict: def key(model): table_ref = model["tableReference"] - if not table_ref.get("catalog") and not table_ref.get("schema"): - logger.debug( - "Try to substitute a tableReference has empty catalog and empty schema" - ) - - return f"{table_ref.get('catalog', '')}.{table_ref.get('schema', '')}.{table_ref.get('table', '')}" + # fully qualified catalog.schema.table + if table_ref.get("catalog") and table_ref.get("schema"): + return f"{table_ref.get('catalog', '')}.{table_ref.get('schema', '')}.{table_ref.get('table', '')}" + # schema.table + elif table_ref.get("schema"): + return f"{table_ref.get('schema', '')}.{table_ref.get('table', '')}" + # table + else: + return table_ref.get("table", "") return {key(model): model for model in models if "tableReference" in model} def _find_model(self, source: exp.Table) -> dict | None: - if source.catalog and source.db: + # Determine catalog + if source.catalog: catalog = source.catalog + else: + catalog = self.headers.get("x-user-catalog", "") if self.headers else "" + + # Determine schema + if source.db: schema = source.db - elif self.headers is not None: - catalog = self.headers.get("x-user-catalog") - schema = self.headers.get("x-user-schema") else: - catalog = "" - schema = "" + schema = self.headers.get("x-user-schema", "") if self.headers else "" + table = source.name return self.model_dict.get(f"{catalog}.{schema}.{table}", None) diff --git a/ibis-server/app/routers/v2/connector.py b/ibis-server/app/routers/v2/connector.py index 7122fd6a3..bf62cd479 100644 --- a/ibis-server/app/routers/v2/connector.py +++ b/ibis-server/app/routers/v2/connector.py @@ -279,7 +279,7 @@ async def model_substitute( with tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) ): - sql = ModelSubstitute(data_source, dto.manifest_str, dict(headers)).substitute( + sql = ModelSubstitute(data_source, dto.manifest_str, headers).substitute( dto.sql, write="trino" ) Connector(data_source, dto.connection_info).dry_run( diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 8608bda88..aef5cc892 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -267,9 +267,9 @@ async def model_substitute( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) ): try: - sql = ModelSubstitute( - data_source, dto.manifest_str, dict(headers) - ).substitute(dto.sql) + sql = ModelSubstitute(data_source, dto.manifest_str, headers).substitute( + dto.sql + ) Connector(data_source, dto.connection_info).dry_run( await Rewriter( dto.manifest_str, diff --git a/ibis-server/tests/routers/v2/connector/test_postgres.py b/ibis-server/tests/routers/v2/connector/test_postgres.py index 28a1465db..1ef80f39e 100644 --- a/ibis-server/tests/routers/v2/connector/test_postgres.py +++ b/ibis-server/tests/routers/v2/connector/test_postgres.py @@ -764,7 +764,11 @@ async def test_model_substitute(client, manifest_str, postgres: PostgresContaine "sql": 'SELECT * FROM "public"."orders"', }, ) - assert response.status_code == 422 + assert response.status_code == 200 + assert ( + response.text + == '"SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\""' + ) # Test only have x-user-schema response = await client.post( diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py b/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py index bc7801bdd..b05af9f74 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py @@ -90,7 +90,11 @@ async def test_model_substitute(client, manifest_str, connection_info): "sql": 'SELECT * FROM "public"."orders"', }, ) - assert response.status_code == 422 + assert response.status_code == 200 + assert ( + response.text + == '"SELECT * FROM \\"my_catalog\\".\\"my_schema\\".\\"Orders\\" AS \\"orders\\""' + ) # Test only have x-user-schema response = await client.post( From 51a53383fc94691eba7300b8436c0cf94fbf4fd0 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Thu, 17 Apr 2025 00:56:43 +0800 Subject: [PATCH 7/7] chore: trigger CI