From 92367942e4c64b4e16b4f1dd0c2d193ac4e83bac Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 29 Apr 2025 16:32:16 +0800 Subject: [PATCH 1/7] refine code style and migrate message --- ibis-server/app/routers/v2/connector.py | 45 ++++++++++++++--- ibis-server/app/routers/v3/connector.py | 66 ++++++++++++++----------- ibis-server/app/util.py | 18 +++++++ 3 files changed, 93 insertions(+), 36 deletions(-) diff --git a/ibis-server/app/routers/v2/connector.py b/ibis-server/app/routers/v2/connector.py index bf62cd479..49d1ffff0 100644 --- a/ibis-server/app/routers/v2/connector.py +++ b/ibis-server/app/routers/v2/connector.py @@ -22,7 +22,7 @@ from app.model.metadata.factory import MetadataFactory from app.model.validator import Validator from app.query_cache import QueryCacheManager -from app.util import build_context, pushdown_limit, to_json +from app.util import build_context, get_fallback_message, pushdown_limit, to_json router = APIRouter(prefix="/connector", tags=["connector"]) tracer = trace.get_tracer(__name__) @@ -43,6 +43,7 @@ def get_query_cache_manager(request: Request) -> QueryCacheManager: description="query the specified data source", ) async def query( + headers: Annotated[Headers, Depends(get_wren_headers)], data_source: DataSource, dto: QueryDTO, dry_run: Annotated[ @@ -58,7 +59,7 @@ async def query( limit: int | None = Query(None, description="limit the number of rows returned"), java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager), - headers: Annotated[Headers, Depends(get_wren_headers)] = None, + is_fallback: bool | None = None, ) -> Response: span_name = f"v2_query_{data_source}" if dry_run: @@ -148,6 +149,11 @@ async def query( response = ORJSONResponse(to_json(result)) response.headers["X-Cache-Hit"] = "false" + if is_fallback: + get_fallback_message( + logger, "query", data_source, dto.manifest_str, dto.sql + ) + return response @@ -157,11 
+163,12 @@ async def query( description="validate the specified rule", ) async def validate( + headers: Annotated[Headers, Depends(get_wren_headers)], data_source: DataSource, rule_name: str, dto: ValidateDTO, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), - headers: Annotated[Headers, Depends(get_wren_headers)] = None, + is_fallback: bool | None = None, ) -> Response: span_name = f"v2_validate_{data_source}" with tracer.start_as_current_span( @@ -176,7 +183,12 @@ async def validate( ), ) await validator.validate(rule_name, dto.parameters, dto.manifest_str) - return Response(status_code=204) + response = Response(status_code=204) + if is_fallback: + get_fallback_message( + logger, "validate", data_source, dto.manifest_str, None + ) + return response @router.post( @@ -230,17 +242,23 @@ def get_db_version(data_source: DataSource, dto: MetadataDTO) -> str: @router.post("/dry-plan", deprecated=True, description="get the planned WrenSQL") async def dry_plan( + headers: Annotated[Headers, Depends(get_wren_headers)], dto: DryPlanDTO, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), - headers: Annotated[Headers, Depends(get_wren_headers)] = None, + is_fallback: bool | None = None, ) -> str: with tracer.start_as_current_span( name="dry_plan", kind=trace.SpanKind.SERVER, context=build_context(headers) ): - return await Rewriter( + sql = await Rewriter( dto.manifest_str, java_engine_connector=java_engine_connector ).rewrite(dto.sql) + if is_fallback: + get_fallback_message(logger, "dry_plan", None, dto.manifest_str, dto.sql) + + return sql + @router.post( "/{data_source}/dry-plan", @@ -248,20 +266,26 @@ async def dry_plan( description="get the dialect SQL for the specified data source", ) async def dry_plan_for_data_source( + headers: Annotated[Headers, Depends(get_wren_headers)], data_source: DataSource, dto: DryPlanDTO, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), - 
headers: Annotated[Headers, Depends(get_wren_headers)] = None, + is_fallback: bool | None = None, ) -> str: span_name = f"v2_dry_plan_{data_source}" with tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) ): - return await Rewriter( + sql = await Rewriter( dto.manifest_str, data_source=data_source, java_engine_connector=java_engine_connector, ).rewrite(dto.sql) + if is_fallback: + get_fallback_message( + logger, "dry_plan", data_source, dto.manifest_str, dto.sql + ) + return sql @router.post( @@ -274,6 +298,7 @@ async def model_substitute( dto: TranspileDTO, headers: Annotated[Headers, Depends(get_wren_headers)] = None, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), + is_fallback: bool | None = None, ) -> str: span_name = f"v2_model_substitute_{data_source}" with tracer.start_as_current_span( @@ -289,4 +314,8 @@ async def model_substitute( java_engine_connector=java_engine_connector, ).rewrite(sql) ) + if is_fallback: + get_fallback_message( + logger, "model_substitute", data_source, dto.manifest_str, dto.sql + ) return sql diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index aef5cc892..4ed6e37eb 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -29,9 +29,6 @@ router = APIRouter(prefix="/connector", tags=["connector"]) tracer = trace.get_tracer(__name__) -MIGRATION_MESSAGE = "Wren engine is migrating to Rust version now. \ - Wren AI team are appreciate if you can provide the error messages and related logs for us." 
- @router.post( "/{data_source}/query", @@ -137,26 +134,25 @@ async def query( return response except Exception as e: logger.warning( - "Failed to execute v3 query, fallback to v2: {}\n" + MIGRATION_MESSAGE, - str(e), + "Failed to execute v3 query, try to fallback to v2: {}\n", str(e) ) return await v2.connector.query( - data_source, - dto, - dry_run, - cache_enable, - override_cache, - limit, - java_engine_connector, - query_cache_manager, - headers, + data_source=data_source, + dto=dto, + dry_run=dry_run, + cache_enable=cache_enable, + override_cache=override_cache, + limit=limit, + java_engine_connector=java_engine_connector, + query_cache_manager=query_cache_manager, + headers=headers, ) @router.post("/dry-plan", description="get the planned WrenSQL") async def dry_plan( + headers: Annotated[Headers, Depends(get_wren_headers)], dto: DryPlanDTO, - headers: Annotated[Headers, Depends(get_wren_headers)] = None, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), ) -> str: with tracer.start_as_current_span( @@ -166,11 +162,14 @@ async def dry_plan( return await Rewriter(dto.manifest_str, experiment=True).rewrite(dto.sql) except Exception as e: logger.warning( - "Failed to execute v3 dry-plan, fallback to v2: {}\n" - + MIGRATION_MESSAGE, - str(e), + "Failed to execute v3 dry-plan, try to fallback to v2: {}", str(e) + ) + return await v2.connector.dry_plan( + dto=dto, + java_engine_connector=java_engine_connector, + headers=headers, + is_fallback=True, ) - return await v2.connector.dry_plan(dto, java_engine_connector, headers) @router.post( @@ -178,9 +177,9 @@ async def dry_plan( description="get the dialect SQL for the specified data source", ) async def dry_plan_for_data_source( + headers: Annotated[Headers, Depends(get_wren_headers)], data_source: DataSource, dto: DryPlanDTO, - headers: Annotated[Headers, Depends(get_wren_headers)] = None, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), ) -> str: 
span_name = f"v3_dry_plan_{data_source}" @@ -193,12 +192,15 @@ async def dry_plan_for_data_source( ).rewrite(dto.sql) except Exception as e: logger.warning( - "Failed to execute v3 dry-plan, fallback to v2: {}\n" - + MIGRATION_MESSAGE, + "Failed to execute v3 dry-plan, try to fallback to v2: {}", str(e), ) return await v2.connector.dry_plan_for_data_source( - data_source, dto, java_engine_connector, headers + data_source=data_source, + dto=dto, + java_engine_connector=java_engine_connector, + headers=headers, + is_fallback=True, ) @@ -206,10 +208,10 @@ async def dry_plan_for_data_source( "/{data_source}/validate/{rule_name}", description="validate the specified rule" ) async def validate( + headers: Annotated[Headers, Depends(get_wren_headers)], data_source: DataSource, rule_name: str, dto: ValidateDTO, - headers: Annotated[Headers, Depends(get_wren_headers)] = None, java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector), ) -> Response: span_name = f"v3_validate_{data_source}" @@ -225,12 +227,16 @@ async def validate( return Response(status_code=204) except Exception as e: logger.warning( - "Failed to execute v3 validate, fallback to v2: {}\n" - + MIGRATION_MESSAGE, + "Failed to execute v3 validate, try to fallback to v2: {}", str(e), ) return await v2.connector.validate( - data_source, rule_name, dto, java_engine_connector, headers + data_source=data_source, + rule_name=rule_name, + dto=dto, + java_engine_connector=java_engine_connector, + headers=headers, + is_fallback=True, ) @@ -283,5 +289,9 @@ async def model_substitute( "Failed to execute v3 model-substitute, fallback to v2: {}", str(e) ) return await v2.connector.model_substitute( - data_source, dto, headers, java_engine_connector + data_source=data_source, + dto=dto, + headers=headers, + java_engine_connector=java_engine_connector, + is_fallback=True, ) diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index 9addf0759..fd2272c7a 100644 --- a/ibis-server/app/util.py +++ 
b/ibis-server/app/util.py @@ -11,9 +11,15 @@ from opentelemetry.propagate import extract from pandas.core.dtypes.common import is_datetime64_any_dtype +from app.model.data_source import DataSource + tracer = trace.get_tracer(__name__) +MIGRATION_MESSAGE = "Wren engine is migrating to Rust version now. \ + Wren AI team are appreciate if you can provide the error messages and related logs for us." + + @tracer.start_as_current_span("base64_to_dict", kind=trace.SpanKind.INTERNAL) def base64_to_dict(base64_str: str) -> dict: return orjson.loads(base64.b64decode(base64_str).decode("utf-8")) @@ -105,3 +111,15 @@ def build_context(headers: Header) -> Context: def pushdown_limit(sql: str, limit: int | None) -> str: ctx = wren_core.SessionContext() return ctx.pushdown_limit(sql, limit) + + +def get_fallback_message( + logger, prefix: str, datasource: DataSource, mdl_hash: str, sql: str +) -> str: + if sql is not None: + sql = sql.replace("\n", " ") + + message = orjson.dumps( + {"datasource": datasource, "mdl_hash": mdl_hash, "sql": sql} + ).decode("utf-8") + logger.warning("Fallback to v2 {} -- {}\n{}", prefix, message, MIGRATION_MESSAGE) # noqa: PLE1205 From 5ac8fed707218f8da9edf63ab752c72f9c4ea315 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 29 Apr 2025 17:00:37 +0800 Subject: [PATCH 2/7] add tracing header for fallback --- ibis-server/app/dependencies.py | 21 ++++++++++++++++++++- ibis-server/app/routers/v3/connector.py | 15 ++++++++++----- ibis-server/app/util.py | 20 ++++++++++++++++++++ 3 files changed, 50 insertions(+), 6 deletions(-) diff --git a/ibis-server/app/dependencies.py b/ibis-server/app/dependencies.py index 4c9592a19..8c5fa6aa7 100644 --- a/ibis-server/app/dependencies.py +++ b/ibis-server/app/dependencies.py @@ -11,4 +11,23 @@ def verify_query_dto(data_source: DataSource, dto: QueryDTO): def get_wren_headers(request: Request) -> Headers: - return request.headers + return Headers( + raw=list( + filter( + lambda t: 
_filter_headers(t[0].decode("latin-1")), + request.headers.raw, + ) + ) + ) + + +def _filter_headers(header_string: str) -> bool: + if header_string.startswith("x-wren-"): + return True + elif header_string == "traceparent": + return True + elif header_string == "tracestate": + return True + elif header_string == "sentry-trace": + return True + return False diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 4ed6e37eb..68be44cac 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -24,7 +24,7 @@ from app.query_cache import QueryCacheManager from app.routers import v2 from app.routers.v2.connector import get_java_engine_connector, get_query_cache_manager -from app.util import build_context, pushdown_limit, to_json +from app.util import append_fallback_context, build_context, pushdown_limit, to_json router = APIRouter(prefix="/connector", tags=["connector"]) tracer = trace.get_tracer(__name__) @@ -136,6 +136,7 @@ async def query( logger.warning( "Failed to execute v3 query, try to fallback to v2: {}\n", str(e) ) + headers = append_fallback_context(headers, span) return await v2.connector.query( data_source=data_source, dto=dto, @@ -157,13 +158,14 @@ async def dry_plan( ) -> str: with tracer.start_as_current_span( name="dry_plan", kind=trace.SpanKind.SERVER, context=build_context(headers) - ): + ) as span: try: return await Rewriter(dto.manifest_str, experiment=True).rewrite(dto.sql) except Exception as e: logger.warning( "Failed to execute v3 dry-plan, try to fallback to v2: {}", str(e) ) + headers = append_fallback_context(headers, span) return await v2.connector.dry_plan( dto=dto, java_engine_connector=java_engine_connector, @@ -185,7 +187,7 @@ async def dry_plan_for_data_source( span_name = f"v3_dry_plan_{data_source}" with tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) - ): + ) as span: try: return await 
Rewriter( dto.manifest_str, data_source=data_source, experiment=True @@ -195,6 +197,7 @@ async def dry_plan_for_data_source( "Failed to execute v3 dry-plan, try to fallback to v2: {}", str(e), ) + headers = append_fallback_context(headers, span) return await v2.connector.dry_plan_for_data_source( data_source=data_source, dto=dto, @@ -217,7 +220,7 @@ async def validate( span_name = f"v3_validate_{data_source}" with tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) - ): + ) as span: try: validator = Validator( Connector(data_source, dto.connection_info), @@ -230,6 +233,7 @@ async def validate( "Failed to execute v3 validate, try to fallback to v2: {}", str(e), ) + headers = append_fallback_context(headers, span) return await v2.connector.validate( data_source=data_source, rule_name=rule_name, @@ -271,7 +275,7 @@ async def model_substitute( span_name = f"v3_model-substitute_{data_source}" with tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) - ): + ) as span: try: sql = ModelSubstitute(data_source, dto.manifest_str, headers).substitute( dto.sql @@ -288,6 +292,7 @@ async def model_substitute( logger.warning( "Failed to execute v3 model-substitute, fallback to v2: {}", str(e) ) + headers = append_fallback_context(headers, span) return await v2.connector.model_substitute( data_source=data_source, dto=dto, diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index fd2272c7a..5ff8bad2a 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -7,9 +7,16 @@ import wren_core from fastapi import Header from opentelemetry import trace +from opentelemetry.baggage.propagation import W3CBaggagePropagator from opentelemetry.context import Context from opentelemetry.propagate import extract +from opentelemetry.trace import ( + NonRecordingSpan, + set_span_in_context, +) +from opentelemetry.trace.propagation.tracecontext import 
TraceContextTextMapPropagator from pandas.core.dtypes.common import is_datetime64_any_dtype +from starlette.datastructures import Headers from app.model.data_source import DataSource @@ -107,6 +114,19 @@ def build_context(headers: Header) -> Context: return extract(headers) +def append_fallback_context(headers: Header, span: trace.Span) -> Headers: + if headers is None: + headers = {} + else: + headers = dict(headers) + span = NonRecordingSpan(span.get_span_context()) + context = set_span_in_context(span) + # https://opentelemetry.io/docs/languages/python/propagation/ + W3CBaggagePropagator().inject(headers, context) + TraceContextTextMapPropagator().inject(headers, context) + return Headers(headers) + + @tracer.start_as_current_span("pushdown_limit", kind=trace.SpanKind.INTERNAL) def pushdown_limit(sql: str, limit: int | None) -> str: ctx = wren_core.SessionContext() From 7a44fc09c79754b0ceb48dd009453b3cf5b952ca Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 29 Apr 2025 17:34:00 +0800 Subject: [PATCH 3/7] add fallback disable header --- ibis-server/app/dependencies.py | 2 + ibis-server/app/routers/v3/connector.py | 41 +++++- .../v3/connector/postgres/test_fallback_v2.py | 130 ++++++++++++++++++ 3 files changed, 171 insertions(+), 2 deletions(-) diff --git a/ibis-server/app/dependencies.py b/ibis-server/app/dependencies.py index 8c5fa6aa7..e24cbc96d 100644 --- a/ibis-server/app/dependencies.py +++ b/ibis-server/app/dependencies.py @@ -4,6 +4,8 @@ from app.model import QueryDTO from app.model.data_source import DataSource +X_WREN_FALLBACK_DISABLE = "x-wren-fallback_disable" + # Rebuild model to validate the dto is correct via validation of the pydantic def verify_query_dto(data_source: DataSource, dto: QueryDTO): diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 68be44cac..52cc0b68d 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -1,3 +1,4 @@ +from 
distutils.util import strtobool from typing import Annotated from fastapi import APIRouter, Depends, Query, Response @@ -7,7 +8,7 @@ from starlette.datastructures import Headers from app.config import get_config -from app.dependencies import get_wren_headers, verify_query_dto +from app.dependencies import X_WREN_FALLBACK_DISABLE, get_wren_headers, verify_query_dto from app.mdl.core import get_session_context from app.mdl.java_engine import JavaEngineConnector from app.mdl.rewriter import Rewriter @@ -133,6 +134,13 @@ async def query( return response except Exception as e: + is_fallback_disable = bool( + headers.get(X_WREN_FALLBACK_DISABLE) + and strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + ) + if is_fallback_disable: + raise e + logger.warning( "Failed to execute v3 query, try to fallback to v2: {}\n", str(e) ) @@ -162,6 +170,13 @@ async def dry_plan( try: return await Rewriter(dto.manifest_str, experiment=True).rewrite(dto.sql) except Exception as e: + is_fallback_disable = bool( + headers.get(X_WREN_FALLBACK_DISABLE) + and strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + ) + if is_fallback_disable: + raise e + logger.warning( "Failed to execute v3 dry-plan, try to fallback to v2: {}", str(e) ) @@ -193,6 +208,13 @@ async def dry_plan_for_data_source( dto.manifest_str, data_source=data_source, experiment=True ).rewrite(dto.sql) except Exception as e: + is_fallback_disable = bool( + headers.get(X_WREN_FALLBACK_DISABLE) + and strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + ) + if is_fallback_disable: + raise e + logger.warning( "Failed to execute v3 dry-plan, try to fallback to v2: {}", str(e), @@ -229,6 +251,13 @@ async def validate( await validator.validate(rule_name, dto.parameters, dto.manifest_str) return Response(status_code=204) except Exception as e: + is_fallback_disable = bool( + headers.get(X_WREN_FALLBACK_DISABLE) + and strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + ) + if is_fallback_disable: + raise e + 
logger.warning( "Failed to execute v3 validate, try to fallback to v2: {}", str(e), @@ -289,8 +318,16 @@ async def model_substitute( ) return sql except Exception as e: + is_fallback_disable = bool( + headers.get(X_WREN_FALLBACK_DISABLE) + and strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + ) + if is_fallback_disable: + raise e + logger.warning( - "Failed to execute v3 model-substitute, fallback to v2: {}", str(e) + "Failed to execute v3 model-substitute, try to fallback to v2: {}", + str(e), ) headers = append_fallback_context(headers, span) return await v2.connector.model_substitute( diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py b/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py index 2ef62e955..c5bd3517b 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_fallback_v2.py @@ -3,6 +3,7 @@ import orjson import pytest +from app.dependencies import X_WREN_FALLBACK_DISABLE from tests.routers.v3.connector.postgres.conftest import base_url # It's not a valid manifest for v3. We expect the query to fail and fallback to v2. 
@@ -44,6 +45,19 @@ async def test_query(client, manifest_str, connection_info): assert len(result["columns"]) == 1 assert len(result["data"]) == 1 + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422 + async def test_query_with_cache(client, manifest_str, connection_info): # First request - should miss cache @@ -79,6 +93,19 @@ async def test_query_with_cache(client, manifest_str, connection_info): assert result1["columns"] == result2["columns"] assert result1["dtypes"] == result2["dtypes"] + response1 = await client.post( + url=f"{base_url}/query?cacheEnable=true", # Enable cache + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response1.status_code == 422 + async def test_query_with_cache_override(client, manifest_str, connection_info): # First request - should miss cache then create cache @@ -107,6 +134,19 @@ async def test_query_with_cache_override(client, manifest_str, connection_info): response2.headers["X-Cache-Create-At"] ) + response1 = await client.post( + url=f"{base_url}/query?cacheEnable=true", # Enable cache + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response1.status_code == 422 + async def test_query_with_connection_url(client, manifest_str, connection_url): response = await client.post( @@ -119,6 +159,19 @@ async def test_query_with_connection_url(client, manifest_str, connection_url): ) assert response.status_code == 200 + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": {"connectionUrl": 
connection_url}, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422 + async def test_query_with_connection_url_and_cache_enable( client, manifest_str, connection_url @@ -153,6 +206,19 @@ async def test_query_with_connection_url_and_cache_enable( assert result1["columns"] == result2["columns"] assert result1["dtypes"] == result2["dtypes"] + response1 = await client.post( + url=f"{base_url}/query?cacheEnable=true", + json={ + "connectionInfo": {"connectionUrl": connection_url}, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response1.status_code == 422 + async def test_query_with_connection_url_and_cache_override( client, manifest_str, connection_url @@ -183,6 +249,19 @@ async def test_query_with_connection_url_and_cache_override( response2.headers["X-Cache-Create-At"] ) + response1 = await client.post( + url=f"{base_url}/query?cacheEnable=true", + json={ + "connectionInfo": {"connectionUrl": connection_url}, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response1.status_code == 422 + async def test_dry_run(client, manifest_str, connection_info): response = await client.post( @@ -196,6 +275,20 @@ async def test_dry_run(client, manifest_str, connection_info): ) assert response.status_code == 204 + response = await client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422 + async def test_dry_plan(client, manifest_str): response = await client.post( @@ -208,6 +301,18 @@ async def test_dry_plan(client, manifest_str): 
assert response.status_code == 200 assert response.text is not None + response = await client.post( + url="/v3/connector/dry-plan", + json={ + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422 + async def test_dry_plan_for_data_source(client, manifest_str): response = await client.post( @@ -220,6 +325,18 @@ async def test_dry_plan_for_data_source(client, manifest_str): assert response.status_code == 200 assert response.text is not None + response = await client.post( + url=f"{base_url}/dry-plan", + json={ + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422 + async def test_validate(client, manifest_str, connection_info): response = await client.post( @@ -231,3 +348,16 @@ async def test_validate(client, manifest_str, connection_info): }, ) assert response.status_code == 204 + + response = await client.post( + url=f"{base_url}/validate/column_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"modelName": "orders", "columnName": "orderkey"}, + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422 From bed5fdc92129b41b08518bf300e83f9d94fd46ba Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 29 Apr 2025 17:57:36 +0800 Subject: [PATCH 4/7] avoid using deprecated strtobool --- ibis-server/app/routers/v3/connector.py | 19 ++++++++++++------- ibis-server/app/util.py | 4 ++++ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 52cc0b68d..17a07d6f2 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -1,4 +1,3 @@ -from distutils.util import strtobool from typing import Annotated from
fastapi import APIRouter, Depends, Query, Response @@ -25,7 +24,13 @@ from app.query_cache import QueryCacheManager from app.routers import v2 from app.routers.v2.connector import get_java_engine_connector, get_query_cache_manager -from app.util import append_fallback_context, build_context, pushdown_limit, to_json +from app.util import ( + append_fallback_context, + build_context, + pushdown_limit, + safe_strtobool, + to_json, +) router = APIRouter(prefix="/connector", tags=["connector"]) tracer = trace.get_tracer(__name__) @@ -136,7 +141,7 @@ async def query( except Exception as e: is_fallback_disable = bool( headers.get(X_WREN_FALLBACK_DISABLE) - and strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) ) if is_fallback_disable: raise e @@ -172,7 +177,7 @@ async def dry_plan( except Exception as e: is_fallback_disable = bool( headers.get(X_WREN_FALLBACK_DISABLE) - and strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) ) if is_fallback_disable: raise e @@ -210,7 +215,7 @@ async def dry_plan_for_data_source( except Exception as e: is_fallback_disable = bool( headers.get(X_WREN_FALLBACK_DISABLE) - and strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) ) if is_fallback_disable: raise e @@ -253,7 +258,7 @@ async def validate( except Exception as e: is_fallback_disable = bool( headers.get(X_WREN_FALLBACK_DISABLE) - and strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) ) if is_fallback_disable: raise e @@ -320,7 +325,7 @@ async def model_substitute( except Exception as e: is_fallback_disable = bool( headers.get(X_WREN_FALLBACK_DISABLE) - and strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) ) if 
is_fallback_disable: raise e diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index 5ff8bad2a..85c1b498a 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -143,3 +143,7 @@ def get_fallback_message( {"datasource": datasource, "mdl_hash": mdl_hash, "sql": sql} ).decode("utf-8") logger.warning("Fallback to v2 {} -- {}\n{}", prefix, message, MIGRATION_MESSAGE) # noqa: PLE1205 + + +def safe_strtobool(val: str) -> bool: + return val.lower() in {"1", "true", "yes", "y"} From 7d970ee82e8aa3c50b7c4396f0bf5a103afb1603 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 29 Apr 2025 18:01:40 +0800 Subject: [PATCH 5/7] add missing fallback argument --- ibis-server/app/routers/v3/connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index 17a07d6f2..41d4d86a5 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -160,6 +160,7 @@ async def query( java_engine_connector=java_engine_connector, query_cache_manager=query_cache_manager, headers=headers, + is_fallback=True, ) From 51ce2c0db810465e5dc3d727e5776b95262a8f31 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Tue, 29 Apr 2025 18:05:06 +0800 Subject: [PATCH 6/7] add missing header --- ibis-server/app/dependencies.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ibis-server/app/dependencies.py b/ibis-server/app/dependencies.py index e24cbc96d..84268e163 100644 --- a/ibis-server/app/dependencies.py +++ b/ibis-server/app/dependencies.py @@ -26,6 +26,8 @@ def get_wren_headers(request: Request) -> Headers: def _filter_headers(header_string: str) -> bool: if header_string.startswith("x-wren-"): return True + elif header_string.startswith("x-user-"): + return True elif header_string == "traceparent": return True elif header_string == "tracestate": From 707fea14d0d7a11b3082ccc7b46e4ba5aa0d6071 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 2 May 2025 21:53:15 
+0800 Subject: [PATCH 7/7] address comment --- ibis-server/app/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index 85c1b498a..b5e2ccb55 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -134,13 +134,13 @@ def pushdown_limit(sql: str, limit: int | None) -> str: def get_fallback_message( - logger, prefix: str, datasource: DataSource, mdl_hash: str, sql: str + logger, prefix: str, datasource: DataSource, mdl_base64: str, sql: str ) -> str: if sql is not None: sql = sql.replace("\n", " ") message = orjson.dumps( - {"datasource": datasource, "mdl_hash": mdl_hash, "sql": sql} + {"datasource": datasource, "mdl_base64": mdl_base64, "sql": sql} ).decode("utf-8") logger.warning("Fallback to v2 {} -- {}\n{}", prefix, message, MIGRATION_MESSAGE) # noqa: PLE1205