Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ibis-server/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
X_CACHE_CREATE_AT = "X-Cache-Create-At"
X_CACHE_OVERRIDE = "X-Cache-Override"
X_CACHE_OVERRIDE_AT = "X-Cache-Override-At"
X_CORRELATION_ID = "X-Correlation-ID"


# Validate the dto by building the specific connection info from the data source
Expand Down
43 changes: 33 additions & 10 deletions ibis-server/app/main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from datetime import datetime
from typing import TypedDict
from uuid import uuid4

from asgi_correlation_id import CorrelationIdMiddleware
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from fastapi.responses import ORJSONResponse, RedirectResponse
from loguru import logger
from starlette.responses import PlainTextResponse

from app.config import get_config
from app.dependencies import X_CORRELATION_ID
from app.mdl.java_engine import JavaEngineConnector
from app.middleware import ProcessTimeMiddleware, RequestLogMiddleware
from app.model import ConfigModel, CustomHttpError
from app.model import ConfigModel
from app.model.error import ErrorCode, ErrorResponse, WrenError
from app.query_cache import QueryCacheManager
from app.routers import v2, v3

Expand Down Expand Up @@ -44,7 +46,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[State]:
app.add_middleware(ProcessTimeMiddleware)
app.add_middleware(
CorrelationIdMiddleware,
header_name="X-Correlation-ID",
header_name=X_CORRELATION_ID,
generator=lambda: str(uuid4()),
)

Expand Down Expand Up @@ -74,17 +76,38 @@ def update_config(config_model: ConfigModel):
# In Starlette, the Exception is special and is not included in normal exception handlers.
@app.exception_handler(Exception)
def exception_handler(request, exc: Exception):
return PlainTextResponse(str(exc), status_code=500)
return ORJSONResponse(
status_code=500,
content=ErrorResponse(
error_code=ErrorCode.GENERIC_INTERNAL_ERROR.name,
message=str(exc),
timestamp=datetime.now().isoformat(),
correlation_id=request.headers.get(X_CORRELATION_ID),
).model_dump(by_alias=True),
)


# In Starlette, the exceptions other than the Exception are not raised when call_next in the middleware.
@app.exception_handler(CustomHttpError)
def custom_http_error_handler(request, exc: CustomHttpError):
with logger.contextualize(correlation_id=request.headers.get("X-Correlation-ID")):
@app.exception_handler(WrenError)
def wren_error_handler(request, exc: WrenError):
with logger.contextualize(correlation_id=request.headers.get(X_CORRELATION_ID)):
logger.opt(exception=exc).error("Request failed")
return PlainTextResponse(str(exc), status_code=exc.status_code)
return ORJSONResponse(
status_code=exc.get_http_status_code(),
content=exc.get_response(
correlation_id=request.headers.get(X_CORRELATION_ID)
).model_dump(by_alias=True),
)


@app.exception_handler(NotImplementedError)
def not_implemented_error_handler(request, exc: NotImplementedError):
return PlainTextResponse(str(exc), status_code=501)
return ORJSONResponse(
status_code=501,
content=ErrorResponse(
error_code=ErrorCode.NOT_IMPLEMENTED.name,
message=str(exc),
timestamp=datetime.now().isoformat(),
correlation_id=request.headers.get(X_CORRELATION_ID),
).model_dump(by_alias=True),
)
19 changes: 9 additions & 10 deletions ibis-server/app/mdl/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import orjson

from app.config import get_config
from app.model import UnprocessableEntityError
from app.model.error import ErrorCode, WrenError

wren_engine_endpoint = get_config().wren_engine_endpoint

Expand All @@ -17,9 +17,11 @@ def analyze(manifest_str: str, sql: str) -> list[dict]:
)
return r.raise_for_status().json()
except httpx.ConnectError as e:
raise ConnectionError(f"Can not connect to Java Engine: {e}") from e
raise WrenError(
ErrorCode.LEGACY_ENGINE_ERROR, f"Can not connect to Java Engine: {e}"
) from e
except httpx.HTTPStatusError as e:
raise AnalyzeError(e.response.text)
raise WrenError(ErrorCode.GENERIC_USER_ERROR, e.response.text)


def analyze_batch(manifest_str: str, sqls: list[str]) -> list[list[dict]]:
Expand All @@ -32,11 +34,8 @@ def analyze_batch(manifest_str: str, sqls: list[str]) -> list[list[dict]]:
)
return r.raise_for_status().json()
except httpx.ConnectError as e:
raise ConnectionError(f"Can not connect to Java Engine: {e}") from e
raise WrenError(
ErrorCode.LEGACY_ENGINE_ERROR, f"Can not connect to Java Engine: {e}"
) from e
except httpx.HTTPStatusError as e:
raise AnalyzeError(e.response.text)


class AnalyzeError(UnprocessableEntityError):
def __init__(self, message: str):
super().__init__(message)
raise WrenError(ErrorCode.GENERIC_USER_ERROR, e.response.text)
7 changes: 5 additions & 2 deletions ibis-server/app/mdl/java_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from orjson import orjson

from app.config import get_config
from app.model.error import ErrorCode, ErrorPhase, WrenError

wren_engine_endpoint = get_config().wren_engine_endpoint

Expand All @@ -27,8 +28,10 @@ def __init__(self, end_point: str | None = None):

async def dry_plan(self, manifest_str: str, sql: str):
if self.client is None:
raise ValueError(
"WREN_ENGINE_ENDPOINT is not set. Cannot call dry_plan without a valid endpoint."
raise WrenError(
ErrorCode.GENERIC_INTERNAL_ERROR,
"WREN_ENGINE_ENDPOINT is not set. Cannot call dry_plan without a valid endpoint.",
phase=ErrorPhase.SQL_PLANNING,
)

r = await self.client.request(
Expand Down
45 changes: 25 additions & 20 deletions ibis-server/app/mdl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
to_json_base64,
)
from app.mdl.java_engine import JavaEngineConnector
from app.model import InternalServerError, UnprocessableEntityError
from app.model.data_source import DataSource
from app.model.error import PLANNED_SQL, ErrorCode, ErrorPhase, WrenError

# To register custom dialects from ibis library for sqlglot
importlib.import_module("ibis.backends.sql.dialects")
Expand Down Expand Up @@ -47,9 +47,17 @@ def __init__(

@tracer.start_as_current_span("transpile", kind=trace.SpanKind.INTERNAL)
def _transpile(self, planned_sql: str) -> str:
read = self._get_read_dialect(self.experiment)
write = self._get_write_dialect(self.data_source)
return sqlglot.transpile(planned_sql, read=read, write=write)[0]
try:
read = self._get_read_dialect(self.experiment)
write = self._get_write_dialect(self.data_source)
return sqlglot.transpile(planned_sql, read=read, write=write)[0]
except Exception as e:
raise WrenError(
ErrorCode.SQLGLOT_ERROR,
str(e),
phase=ErrorPhase.SQL_TRANSPILE,
metadata={PLANNED_SQL: planned_sql},
)

@tracer.start_as_current_span("rewrite", kind=trace.SpanKind.INTERNAL)
async def rewrite(self, sql: str) -> str:
Expand Down Expand Up @@ -102,11 +110,18 @@ async def rewrite(
try:
return await self.java_engine_connector.dry_plan(manifest_str, sql)
except httpx.ConnectError as e:
raise WrenEngineError(f"Can not connect to Java Engine: {e}")
raise WrenError(
ErrorCode.LEGACY_ENGINE_ERROR, f"Can not connect to Java Engine: {e}"
)
except httpx.TimeoutException as e:
raise WrenEngineError(f"Timeout when connecting to Java Engine: {e}")
raise WrenError(
ErrorCode.LEGACY_ENGINE_ERROR,
f"Timeout when connecting to Java Engine: {e}",
)
except httpx.HTTPStatusError as e:
raise RewriteError(e.response.text)
raise WrenError(
ErrorCode.INVALID_SQL, e.response.text, ErrorPhase.SQL_PLANNING
)

@staticmethod
def handle_extract_exception(e: Exception):
Expand All @@ -131,7 +146,7 @@ async def rewrite(
sql,
)
except Exception as e:
raise RewriteError(str(e))
raise WrenError(ErrorCode.INVALID_SQL, str(e), ErrorPhase.SQL_PLANNING)

@tracer.start_as_current_span("embedded_rewrite", kind=trace.SpanKind.INTERNAL)
def rewrite_sync(
Expand All @@ -144,7 +159,7 @@ def rewrite_sync(
)
return session_context.transform_sql(sql)
except Exception as e:
raise RewriteError(str(e))
raise WrenError(ErrorCode.INVALID_SQL, str(e), ErrorPhase.SQL_PLANNING)

def get_session_properties(self, properties: dict) -> frozenset | None:
if properties is None:
Expand All @@ -162,14 +177,4 @@ def get_session_properties(self, properties: dict) -> frozenset | None:

@staticmethod
def handle_extract_exception(e: Exception):
raise RewriteError(str(e))


class RewriteError(UnprocessableEntityError):
def __init__(self, message: str):
super().__init__(message)


class WrenEngineError(InternalServerError):
def __init__(self, message: str):
super().__init__(message)
raise WrenError(ErrorCode.INVALID_MDL, str(e), ErrorPhase.MDL_EXTRACTION)
19 changes: 10 additions & 9 deletions ibis-server/app/mdl/substitute.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from sqlglot import exp, parse_one
from sqlglot.optimizer.scope import build_scope

from app.model import UnprocessableEntityError
from app.model.data_source import DataSource
from app.model.error import ErrorCode, ErrorPhase, WrenError
from app.util import base64_to_dict

tracer = trace.get_tracer(__name__)
Expand Down Expand Up @@ -40,12 +40,18 @@ def substitute(self, sql: str, write: str | None = None) -> str:
# if model name is ambiguous, raise an error
duplicate_keys = get_case_insensitive_duplicate_keys(self.model_dict)
if model is not None and key.lower() in duplicate_keys:
raise SubstituteError(
f"Ambiguous model: found multiple matches for {source}"
raise WrenError(
ErrorCode.GENERIC_USER_ERROR,
f"Ambiguous model: found multiple matches for {source}",
phase=ErrorPhase.SQL_SUBSTITUTE,
)

if model is None:
raise SubstituteError(f"Model not found: {source}")
raise WrenError(
ErrorCode.NOT_FOUND,
f"Model not found: {source}",
phase=ErrorPhase.SQL_SUBSTITUTE,
)

source.replace(
exp.Table(
Expand Down Expand Up @@ -129,8 +135,3 @@ def get_case_insensitive_duplicate_keys(d):

duplicates = [key for keys in key_map.values() if len(keys) > 1 for key in keys]
return duplicates


class SubstituteError(UnprocessableEntityError):
def __init__(self, message: str):
super().__init__(message)
35 changes: 0 additions & 35 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
from __future__ import annotations

from abc import ABC
from enum import Enum
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field, SecretStr
from starlette.status import (
HTTP_404_NOT_FOUND,
HTTP_422_UNPROCESSABLE_ENTITY,
HTTP_500_INTERNAL_SERVER_ERROR,
)

manifest_str_field = Field(alias="manifestStr", description="Base64 manifest")
connection_info_field = Field(alias="connectionInfo")
Expand Down Expand Up @@ -526,35 +520,6 @@ class ConfigModel(BaseModel):
diagnose: bool


class UnknownIbisError(Exception):
def __init__(self, message):
self.message = f"Unknown ibis error: {message!s}"


class CustomHttpError(ABC, Exception):
status_code: int


class InternalServerError(CustomHttpError):
status_code = HTTP_500_INTERNAL_SERVER_ERROR


class UnprocessableEntityError(CustomHttpError):
status_code = HTTP_422_UNPROCESSABLE_ENTITY


class DatabaseTimeoutError(CustomHttpError):
status_code = 504

def __init__(self, message: str):
super().__init__(message)
self.message = f"Database timeout error: {message!s}.\nIt seems your database is not responding or the query is taking too long to execute. Please check your database status and query performance."


class NotFoundError(CustomHttpError):
status_code = HTTP_404_NOT_FOUND


class SSLMode(str, Enum):
DISABLED = "disabled"
ENABLED = "enabled"
Expand Down
Loading
Loading