Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ibis-server/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from fastapi import Request
from starlette.datastructures import Headers

from app.model import QueryDTO
from app.model.data_source import DataSource


# Rebuild model to validate the dto is correct via validation of the pydantic
def verify_query_dto(data_source: DataSource, dto: QueryDTO):
data_source.get_dto_type()(**dto.model_dump(by_alias=True))


def get_wren_headers(request: Request) -> Headers:
return request.headers
28 changes: 24 additions & 4 deletions ibis-server/app/mdl/substitute.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@


class ModelSubstitute:
def __init__(self, data_source: DataSource, manifest_str: str):
def __init__(self, data_source: DataSource, manifest_str: str, headers=None):
self.data_source = data_source
self.manifest = base64_to_dict(manifest_str)
self.model_dict = self._build_model_dict(self.manifest["models"])
self.headers = dict(headers) if headers else None

@tracer.start_as_current_span("substitute", kind=trace.SpanKind.INTERNAL)
def substitute(self, sql: str, write: str | None = None) -> str:
Expand All @@ -39,13 +40,32 @@ def substitute(self, sql: str, write: str | None = None) -> str:
def _build_model_dict(models) -> dict:
def key(model):
table_ref = model["tableReference"]
return f"{table_ref.get('catalog', '')}.{table_ref.get('schema', '')}.{table_ref.get('table', '')}"

# fully qualified catalog.schema.table
if table_ref.get("catalog") and table_ref.get("schema"):
return f"{table_ref.get('catalog', '')}.{table_ref.get('schema', '')}.{table_ref.get('table', '')}"
# schema.table
elif table_ref.get("schema"):
return f"{table_ref.get('schema', '')}.{table_ref.get('table', '')}"
# table
else:
return table_ref.get("table", "")

return {key(model): model for model in models if "tableReference" in model}

def _find_model(self, source: exp.Table) -> dict | None:
catalog = source.catalog or ""
schema = source.db or ""
# Determine catalog
if source.catalog:
catalog = source.catalog
else:
catalog = self.headers.get("x-user-catalog", "") if self.headers else ""

# Determine schema
if source.db:
schema = source.db
else:
schema = self.headers.get("x-user-schema", "") if self.headers else ""

table = source.name
return self.model_dict.get(f"{catalog}.{schema}.{table}", None)

Expand Down
21 changes: 11 additions & 10 deletions ibis-server/app/routers/v2/connector.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Header, Query, Request, Response
from fastapi import APIRouter, Depends, Query, Request, Response
from fastapi.responses import ORJSONResponse
from loguru import logger
from opentelemetry import trace
from starlette.datastructures import Headers

from app.dependencies import verify_query_dto
from app.dependencies import get_wren_headers, verify_query_dto
from app.mdl.java_engine import JavaEngineConnector
from app.mdl.rewriter import Rewriter
from app.mdl.substitute import ModelSubstitute
Expand Down Expand Up @@ -57,7 +58,7 @@ async def query(
limit: int | None = Query(None, description="limit the number of rows returned"),
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager),
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
) -> Response:
span_name = f"v2_query_{data_source}"
if dry_run:
Expand Down Expand Up @@ -160,7 +161,7 @@ async def validate(
rule_name: str,
dto: ValidateDTO,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
) -> Response:
span_name = f"v2_validate_{data_source}"
with tracer.start_as_current_span(
Expand All @@ -187,7 +188,7 @@ async def validate(
def get_table_list(
data_source: DataSource,
dto: MetadataDTO,
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
) -> list[Table]:
span_name = f"v2_metadata_tables_{data_source}"
with tracer.start_as_current_span(
Expand All @@ -207,7 +208,7 @@ def get_table_list(
def get_constraints(
data_source: DataSource,
dto: MetadataDTO,
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
) -> list[Constraint]:
span_name = f"v2_metadata_constraints_{data_source}"
with tracer.start_as_current_span(
Expand All @@ -231,7 +232,7 @@ def get_db_version(data_source: DataSource, dto: MetadataDTO) -> str:
async def dry_plan(
dto: DryPlanDTO,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
) -> str:
with tracer.start_as_current_span(
name="dry_plan", kind=trace.SpanKind.SERVER, context=build_context(headers)
Expand All @@ -250,7 +251,7 @@ async def dry_plan_for_data_source(
data_source: DataSource,
dto: DryPlanDTO,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
) -> str:
span_name = f"v2_dry_plan_{data_source}"
with tracer.start_as_current_span(
Expand All @@ -271,14 +272,14 @@ async def dry_plan_for_data_source(
async def model_substitute(
data_source: DataSource,
dto: TranspileDTO,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
headers: Annotated[str | None, Header()] = None,
) -> str:
span_name = f"v2_model_substitute_{data_source}"
with tracer.start_as_current_span(
name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
):
sql = ModelSubstitute(data_source, dto.manifest_str).substitute(
sql = ModelSubstitute(data_source, dto.manifest_str, headers).substitute(
dto.sql, write="trino"
)
Connector(data_source, dto.connection_info).dry_run(
Expand Down
23 changes: 13 additions & 10 deletions ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Annotated

from fastapi import APIRouter, Depends, Header, Query, Response
from fastapi import APIRouter, Depends, Query, Response
from fastapi.responses import ORJSONResponse
from loguru import logger
from opentelemetry import trace
from starlette.datastructures import Headers

from app.config import get_config
from app.dependencies import verify_query_dto
from app.dependencies import get_wren_headers, verify_query_dto
from app.mdl.core import get_session_context
from app.mdl.java_engine import JavaEngineConnector
from app.mdl.rewriter import Rewriter
Expand Down Expand Up @@ -51,7 +52,7 @@ async def query(
bool, Query(alias="overrideCache", description="ovrride the exist cache")
] = False,
limit: int | None = Query(None, description="limit the number of rows returned"),
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
query_cache_manager: QueryCacheManager = Depends(get_query_cache_manager),
) -> Response:
Expand Down Expand Up @@ -155,7 +156,7 @@ async def query(
@router.post("/dry-plan", description="get the planned WrenSQL")
async def dry_plan(
dto: DryPlanDTO,
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
) -> str:
with tracer.start_as_current_span(
Expand All @@ -179,7 +180,7 @@ async def dry_plan(
async def dry_plan_for_data_source(
data_source: DataSource,
dto: DryPlanDTO,
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
) -> str:
span_name = f"v3_dry_plan_{data_source}"
Expand Down Expand Up @@ -208,7 +209,7 @@ async def validate(
data_source: DataSource,
rule_name: str,
dto: ValidateDTO,
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
) -> Response:
span_name = f"v3_validate_{data_source}"
Expand Down Expand Up @@ -239,7 +240,7 @@ async def validate(
)
def functions(
data_source: DataSource,
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
) -> Response:
span_name = f"v3_functions_{data_source}"
with tracer.start_as_current_span(
Expand All @@ -258,15 +259,17 @@ def functions(
async def model_substitute(
data_source: DataSource,
dto: TranspileDTO,
headers: Annotated[str | None, Header()] = None,
headers: Annotated[Headers, Depends(get_wren_headers)],
java_engine_connector: JavaEngineConnector = Depends(get_java_engine_connector),
) -> str:
span_name = f"v3_model-substitute_{data_source}"
with tracer.start_as_current_span(
name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers)
):
try:
sql = ModelSubstitute(data_source, dto.manifest_str).substitute(dto.sql)
sql = ModelSubstitute(data_source, dto.manifest_str, headers).substitute(
dto.sql
)
Connector(data_source, dto.connection_info).dry_run(
await Rewriter(
dto.manifest_str,
Expand All @@ -280,5 +283,5 @@ async def model_substitute(
"Failed to execute v3 model-substitute, fallback to v2: {}", str(e)
)
return await v2.connector.model_substitute(
data_source, dto, java_engine_connector, headers
data_source, dto, headers, java_engine_connector
)
Loading