Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion ibis-server/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,34 @@
from app.model import QueryDTO
from app.model.data_source import DataSource

X_WREN_FALLBACK_DISABLE = "x-wren-fallback_disable"


# Rebuild model to validate the dto is correct via validation of the pydantic
def verify_query_dto(data_source: DataSource, dto: QueryDTO):
data_source.get_dto_type()(**dto.model_dump(by_alias=True))


def get_wren_headers(request: Request) -> Headers:
return request.headers
return Headers(
raw=list(
filter(
lambda t: _filter_headers(t[0].decode("latin-1")),
request.headers.raw,
)
)
)


def _filter_headers(header_string: str) -> bool:
if header_string.startswith("x-wren-"):
return True
elif header_string.startswith("x-user-"):
return True
elif header_string == "traceparent":
return True
elif header_string == "tracestate":
return True
elif header_string == "sentry-trace":
return True
return False
45 changes: 37 additions & 8 deletions ibis-server/app/routers/v2/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from app.model.metadata.factory import MetadataFactory
from app.model.validator import Validator
from app.query_cache import QueryCacheManager
from app.util import build_context, pushdown_limit, to_json
from app.util import build_context, get_fallback_message, pushdown_limit, to_json

router = APIRouter(prefix="/connector", tags=["connector"])
tracer = trace.get_tracer(__name__)
Expand All @@ -43,6 +43,7 @@ def get_query_cache_manager(request: Request) -> QueryCacheManager:
description="query the specified data source",
)
async def query(
headers: Annotated[Headers, Depends(get_wren_headers)],
data_source: DataSource,
dto: QueryDTO,
dry_run: Annotated[
Expand All @@ -58,7 +59,7 @@ async def query(
limit: int | None = Query(None, description="limit the number of rows returned"),
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager),
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
is_fallback: bool | None = None,
) -> Response:
span_name = f"v2_query_{data_source}"
if dry_run:
Expand Down Expand Up @@ -148,6 +149,11 @@ async def query(
response = ORJSONResponse(to_json(result))
response.headers["X-Cache-Hit"] = "false"

if is_fallback:
get_fallback_message(
logger, "query", data_source, dto.manifest_str, dto.sql
)

return response


Expand All @@ -157,11 +163,12 @@ async def query(
description="validate the specified rule",
)
async def validate(
headers: Annotated[Headers, Depends(get_wren_headers)],
data_source: DataSource,
rule_name: str,
dto: ValidateDTO,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
is_fallback: bool | None = None,
) -> Response:
span_name = f"v2_validate_{data_source}"
with tracer.start_as_current_span(
Expand All @@ -176,7 +183,12 @@ async def validate(
),
)
await validator.validate(rule_name, dto.parameters, dto.manifest_str)
return Response(status_code=204)
response = Response(status_code=204)
if is_fallback:
get_fallback_message(
logger, "validate", data_source, dto.manifest_str, None
)
return response


@router.post(
Expand Down Expand Up @@ -230,38 +242,50 @@ def get_db_version(data_source: DataSource, dto: MetadataDTO) -> str:

@router.post("/dry-plan", deprecated=True, description="get the planned WrenSQL")
async def dry_plan(
headers: Annotated[Headers, Depends(get_wren_headers)],
dto: DryPlanDTO,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
is_fallback: bool | None = None,
) -> str:
with tracer.start_as_current_span(
name="dry_plan", kind=trace.SpanKind.SERVER, context=build_context(headers)
):
return await Rewriter(
sql = await Rewriter(
dto.manifest_str, java_engine_connector=java_engine_connector
).rewrite(dto.sql)

if is_fallback:
get_fallback_message(logger, "dry_plan", None, dto.manifest_str, dto.sql)

return sql


@router.post(
"/{data_source}/dry-plan",
deprecated=True,
description="get the dialect SQL for the specified data source",
)
async def dry_plan_for_data_source(
headers: Annotated[Headers, Depends(get_wren_headers)],
data_source: DataSource,
dto: DryPlanDTO,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
is_fallback: bool | None = None,
) -> str:
span_name = f"v2_dry_plan_{data_source}"
with tracer.start_as_current_span(
name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
):
return await Rewriter(
sql = await Rewriter(
dto.manifest_str,
data_source=data_source,
java_engine_connector=java_engine_connector,
).rewrite(dto.sql)
if is_fallback:
get_fallback_message(
logger, "dry_plan", data_source, dto.manifest_str, dto.sql
)
return sql


@router.post(
Expand All @@ -274,6 +298,7 @@ async def model_substitute(
dto: TranspileDTO,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
is_fallback: bool | None = None,
) -> str:
span_name = f"v2_model_substitute_{data_source}"
with tracer.start_as_current_span(
Expand All @@ -289,4 +314,8 @@ async def model_substitute(
java_engine_connector=java_engine_connector,
).rewrite(sql)
)
if is_fallback:
get_fallback_message(
logger, "model_substitute", data_source, dto.manifest_str, dto.sql
)
return sql
Loading