# Request header that lets a caller opt out of the v3 -> v2 fallback.
# NOTE(review): the name mixes "-" and "_"; some reverse proxies (e.g. nginx
# with its default `underscores_in_headers off`) ignore header names that
# contain underscores -- confirm deployments forward this header.
X_WREN_FALLBACK_DISABLE = "x-wren-fallback_disable"

# Header-name prefixes that are forwarded to the route handlers.
_FORWARDED_HEADER_PREFIXES = ("x-wren-", "x-user-")
# Exact header names (trace propagation) that are forwarded as well.
_FORWARDED_HEADER_NAMES = frozenset({"traceparent", "tracestate", "sentry-trace"})


def get_wren_headers(request: Request) -> Headers:
    """Return only the request headers accepted by ``_filter_headers``.

    The raw ``(name, value)`` byte pairs of the incoming request are kept
    untouched; pairs whose name is not accepted are dropped.
    """
    return Headers(
        raw=[
            pair
            for pair in request.headers.raw
            # Raw header names are bytes; decode them the same way the
            # original code did before matching.
            if _filter_headers(pair[0].decode("latin-1"))
        ]
    )


def _filter_headers(header_string: str) -> bool:
    """Tell whether a header name should be forwarded.

    A name is accepted when it starts with ``x-wren-`` or ``x-user-``, or is
    exactly ``traceparent``, ``tracestate`` or ``sentry-trace``.  Matching is
    case-insensitive because HTTP header names are case-insensitive.
    """
    name = header_string.lower()
    return name.startswith(_FORWARDED_HEADER_PREFIXES) or name in _FORWARDED_HEADER_NAMES
@router.post("/dry-plan", deprecated=True, description="get the planned WrenSQL")
async def dry_plan(
    headers: Annotated[Headers, Depends(get_wren_headers)],
    dto: DryPlanDTO,
    java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
    is_fallback: bool | None = None,
) -> str:
    """Rewrite ``dto.sql`` against ``dto.manifest_str`` and return the result.

    When ``is_fallback`` is truthy, ``get_fallback_message`` is called after
    the rewrite with the manifest and SQL (and no data source).

    NOTE(review): ``is_fallback`` is a plain-typed parameter of a route
    handler, so FastAPI will also accept it as a query parameter from HTTP
    clients -- confirm that is acceptable.
    """
    trace_context = build_context(headers)
    with tracer.start_as_current_span(
        name="dry_plan", kind=trace.SpanKind.SERVER, context=trace_context
    ):
        rewriter = Rewriter(
            dto.manifest_str, java_engine_connector=java_engine_connector
        )
        planned_sql = await rewriter.rewrite(dto.sql)

        if is_fallback:
            get_fallback_message(logger, "dry_plan", None, dto.manifest_str, dto.sql)

        return planned_sql
b/ibis-server/app/routers/v3/connector.py @@ -7,7 +7,7 @@ from starlette.datastructures import Headers from app.config import get_config -from app.dependencies import get_wren_headers, verify_query_dto +from app.dependencies import X_WREN_FALLBACK_DISABLE, get_wren_headers, verify_query_dto from app.mdl.core import get_session_context from app.mdl.java_engine import JavaEngineConnector from app.mdl.rewriter import Rewriter @@ -24,14 +24,17 @@ from app.query_cache import QueryCacheManager from app.routers import v2 from app.routers.v2.connector import get_java_engine_connector, get_query_cache_manager -from app.util import build_context, pushdown_limit, to_json +from app.util import ( + append_fallback_context, + build_context, + pushdown_limit, + safe_strtobool, + to_json, +) router = APIRouter(prefix="/connector", tags=["connector"]) tracer = trace.get_tracer(__name__) -MIGRATION_MESSAGE = "Wren engine is migrating to Rust version now. \ - Wren AI team are appreciate if you can provide the error messages and related logs for us." 
@router.post("/dry-plan", description="get the planned WrenSQL")
async def dry_plan(
    headers: Annotated[Headers, Depends(get_wren_headers)],
    dto: DryPlanDTO,
    java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
) -> str:
    """Plan ``dto.sql`` with the v3 rewriter, falling back to v2 on failure.

    The v3 path runs ``Rewriter(..., experiment=True)``.  If it raises and the
    ``X_WREN_FALLBACK_DISABLE`` header is present and truthy according to
    ``safe_strtobool``, the original exception is re-raised.  Otherwise a
    warning is logged, the current span context is appended to the headers,
    and the request is delegated to ``v2.connector.dry_plan`` with
    ``is_fallback=True``.
    """
    with tracer.start_as_current_span(
        name="dry_plan", kind=trace.SpanKind.SERVER, context=build_context(headers)
    ) as span:
        try:
            return await Rewriter(dto.manifest_str, experiment=True).rewrite(dto.sql)
        except Exception as e:
            # Single lookup: a missing or empty header means fallback stays
            # enabled, exactly as before.
            disable_flag = headers.get(X_WREN_FALLBACK_DISABLE)
            if disable_flag and safe_strtobool(disable_flag):
                # Bare raise re-raises the active exception without adding an
                # extra traceback frame for this line.
                raise

            logger.warning(
                "Failed to execute v3 dry-plan, try to fallback to v2: {}", str(e)
            )
            headers = append_fallback_context(headers, span)
            return await v2.connector.dry_plan(
                dto=dto,
                java_engine_connector=java_engine_connector,
                headers=headers,
                is_fallback=True,
            )
tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) - ): + ) as span: try: validator = Validator( Connector(data_source, dto.connection_info), @@ -224,13 +257,25 @@ async def validate( await validator.validate(rule_name, dto.parameters, dto.manifest_str) return Response(status_code=204) except Exception as e: + is_fallback_disable = bool( + headers.get(X_WREN_FALLBACK_DISABLE) + and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + ) + if is_fallback_disable: + raise e + logger.warning( - "Failed to execute v3 validate, fallback to v2: {}\n" - + MIGRATION_MESSAGE, + "Failed to execute v3 validate, try to fallback to v2: {}", str(e), ) + headers = append_fallback_context(headers, span) return await v2.connector.validate( - data_source, rule_name, dto, java_engine_connector, headers + data_source=data_source, + rule_name=rule_name, + dto=dto, + java_engine_connector=java_engine_connector, + headers=headers, + is_fallback=True, ) @@ -265,7 +310,7 @@ async def model_substitute( span_name = f"v3_model-substitute_{data_source}" with tracer.start_as_current_span( name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) - ): + ) as span: try: sql = ModelSubstitute(data_source, dto.manifest_str, headers).substitute( dto.sql @@ -279,9 +324,22 @@ async def model_substitute( ) return sql except Exception as e: + is_fallback_disable = bool( + headers.get(X_WREN_FALLBACK_DISABLE) + and safe_strtobool(headers.get(X_WREN_FALLBACK_DISABLE, "false")) + ) + if is_fallback_disable: + raise e + logger.warning( - "Failed to execute v3 model-substitute, fallback to v2: {}", str(e) + "Failed to execute v3 model-substitute, try to fallback to v2: {}", + str(e), ) + headers = append_fallback_context(headers, span) return await v2.connector.model_substitute( - data_source, dto, headers, java_engine_connector + data_source=data_source, + dto=dto, + headers=headers, + 
# Built with implicit string concatenation: a backslash continuation inside
# the literal would embed the next source line's indentation in the message.
MIGRATION_MESSAGE = (
    "Wren engine is migrating to Rust version now. "
    "Wren AI team would appreciate it if you can provide the error messages "
    "and related logs for us."
)

# Lower-case spellings treated as "true" by ``safe_strtobool``.
_TRUTHY_STRINGS = frozenset({"1", "true", "t", "yes", "y", "on"})


def append_fallback_context(headers: Headers | None, span: trace.Span) -> Headers:
    """Return a copy of ``headers`` carrying the trace context of ``span``.

    The headers are copied into a plain dict (an empty one when ``headers`` is
    ``None``), the span's context is wrapped in a ``NonRecordingSpan``, and the
    W3C baggage and trace-context propagators inject their entries into the
    copy.  The input ``headers`` object is not modified.
    """
    carrier = {} if headers is None else dict(headers)
    parent_span = NonRecordingSpan(span.get_span_context())
    context = set_span_in_context(parent_span)
    # https://opentelemetry.io/docs/languages/python/propagation/
    W3CBaggagePropagator().inject(carrier, context)
    TraceContextTextMapPropagator().inject(carrier, context)
    return Headers(carrier)


def get_fallback_message(
    logger,
    prefix: str,
    datasource: DataSource | None,
    mdl_base64: str,
    sql: str | None,
) -> None:
    """Log a warning describing a request that fell back to v2.

    Despite its name this function returns nothing; it only logs.  The
    warning contains ``prefix``, a JSON object with ``datasource``,
    ``mdl_base64`` and ``sql`` (newlines in ``sql`` replaced by spaces), and
    ``MIGRATION_MESSAGE``.  ``datasource`` and ``sql`` may be ``None``.
    """
    if sql is not None:
        # Keep the logged payload on a single line.
        sql = sql.replace("\n", " ")

    message = orjson.dumps(
        {"datasource": datasource, "mdl_base64": mdl_base64, "sql": sql}
    ).decode("utf-8")
    logger.warning("Fallback to v2 {} -- {}\n{}", prefix, message, MIGRATION_MESSAGE)  # noqa: PLE1205


def safe_strtobool(val: str | None) -> bool:
    """Interpret ``val`` as a boolean flag without ever raising.

    Returns ``True`` for ``1``, ``true``, ``t``, ``yes``, ``y`` and ``on``
    (case-insensitive, surrounding whitespace ignored).  Anything else,
    including ``None`` and the empty string, yields ``False``.
    """
    if val is None:
        return False
    return val.strip().lower() in _TRUTHY_STRINGS
assert response1.status_code == 422 + async def test_query_with_cache_override(client, manifest_str, connection_info): # First request - should miss cache then create cache @@ -107,6 +134,19 @@ async def test_query_with_cache_override(client, manifest_str, connection_info): response2.headers["X-Cache-Create-At"] ) + response1 = await client.post( + url=f"{base_url}/query?cacheEnable=true", # Enable cache + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response1.status_code == 422 + async def test_query_with_connection_url(client, manifest_str, connection_url): response = await client.post( @@ -119,6 +159,19 @@ async def test_query_with_connection_url(client, manifest_str, connection_url): ) assert response.status_code == 200 + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": {"connectionUrl": connection_url}, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422 + async def test_query_with_connection_url_and_cache_enable( client, manifest_str, connection_url @@ -153,6 +206,19 @@ async def test_query_with_connection_url_and_cache_enable( assert result1["columns"] == result2["columns"] assert result1["dtypes"] == result2["dtypes"] + response1 = await client.post( + url=f"{base_url}/query?cacheEnable=true", + json={ + "connectionInfo": {"connectionUrl": connection_url}, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response1.status_code == 422 + async def test_query_with_connection_url_and_cache_override( client, manifest_str, connection_url @@ -183,6 +249,19 @@ async def test_query_with_connection_url_and_cache_override( response2.headers["X-Cache-Create-At"] ) + 
response1 = await client.post( + url=f"{base_url}/query?cacheEnable=true", + json={ + "connectionInfo": {"connectionUrl": connection_url}, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response1.status_code == 422 + async def test_dry_run(client, manifest_str, connection_info): response = await client.post( @@ -196,6 +275,20 @@ async def test_dry_run(client, manifest_str, connection_info): ) assert response.status_code == 204 + response = await client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422 + async def test_dry_plan(client, manifest_str): response = await client.post( @@ -208,6 +301,18 @@ async def test_dry_plan(client, manifest_str): assert response.status_code == 200 assert response.text is not None + response = await client.post( + url="/v3/connector/dry-plan", + json={ + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422 + async def test_dry_plan_for_data_source(client, manifest_str): response = await client.post( @@ -220,6 +325,18 @@ async def test_dry_plan_for_data_source(client, manifest_str): assert response.status_code == 200 assert response.text is not None + response = await client.post( + url=f"{base_url}/dry-plan", + json={ + "manifestStr": manifest_str, + "sql": "SELECT orderkey FROM orders LIMIT 1", + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422 + async def test_validate(client, manifest_str, connection_info): response = await client.post( @@ -231,3 +348,16 @@ async def test_validate(client, manifest_str, connection_info): }, ) assert 
response.status_code == 204 + + response = await client.post( + url=f"{base_url}/validate/column_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"modelName": "orders", "columnName": "orderkey"}, + }, + headers={ + X_WREN_FALLBACK_DISABLE: "true", + }, + ) + assert response.status_code == 422