diff --git a/ibis-server/app/dependencies.py b/ibis-server/app/dependencies.py index dae78dbc4..5ce90b97f 100644 --- a/ibis-server/app/dependencies.py +++ b/ibis-server/app/dependencies.py @@ -13,6 +13,7 @@ X_CACHE_CREATE_AT = "X-Cache-Create-At" X_CACHE_OVERRIDE = "X-Cache-Override" X_CACHE_OVERRIDE_AT = "X-Cache-Override-At" +X_CORRELATION_ID = "X-Correlation-ID" # Validate the dto by building the specific connection info from the data source diff --git a/ibis-server/app/main.py b/ibis-server/app/main.py index 856c17633..7c2cae406 100644 --- a/ibis-server/app/main.py +++ b/ibis-server/app/main.py @@ -1,18 +1,20 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager +from datetime import datetime from typing import TypedDict from uuid import uuid4 from asgi_correlation_id import CorrelationIdMiddleware from fastapi import FastAPI -from fastapi.responses import RedirectResponse +from fastapi.responses import ORJSONResponse, RedirectResponse from loguru import logger -from starlette.responses import PlainTextResponse from app.config import get_config +from app.dependencies import X_CORRELATION_ID from app.mdl.java_engine import JavaEngineConnector from app.middleware import ProcessTimeMiddleware, RequestLogMiddleware -from app.model import ConfigModel, CustomHttpError +from app.model import ConfigModel +from app.model.error import ErrorCode, ErrorResponse, WrenError from app.query_cache import QueryCacheManager from app.routers import v2, v3 @@ -44,7 +46,7 @@ async def lifespan(app: FastAPI) -> AsyncIterator[State]: app.add_middleware(ProcessTimeMiddleware) app.add_middleware( CorrelationIdMiddleware, - header_name="X-Correlation-ID", + header_name=X_CORRELATION_ID, generator=lambda: str(uuid4()), ) @@ -74,17 +76,38 @@ def update_config(config_model: ConfigModel): # In Starlette, the Exception is special and is not included in normal exception handlers. @app.exception_handler(Exception) def exception_handler(request, exc: Exception): - return PlainTextResponse(str(exc), status_code=500) + return ORJSONResponse( + status_code=500, + content=ErrorResponse( + error_code=ErrorCode.GENERIC_INTERNAL_ERROR.name, + message=str(exc), + timestamp=datetime.now().isoformat(), + correlation_id=request.headers.get(X_CORRELATION_ID), + ).model_dump(by_alias=True), + ) # In Starlette, the exceptions other than the Exception are not raised when call_next in the middleware. -@app.exception_handler(CustomHttpError) -def custom_http_error_handler(request, exc: CustomHttpError): - with logger.contextualize(correlation_id=request.headers.get("X-Correlation-ID")): +@app.exception_handler(WrenError) +def wren_error_handler(request, exc: WrenError): + with logger.contextualize(correlation_id=request.headers.get(X_CORRELATION_ID)): logger.opt(exception=exc).error("Request failed") - return PlainTextResponse(str(exc), status_code=exc.status_code) + return ORJSONResponse( + status_code=exc.get_http_status_code(), + content=exc.get_response( + correlation_id=request.headers.get(X_CORRELATION_ID) + ).model_dump(by_alias=True), + ) @app.exception_handler(NotImplementedError) def not_implemented_error_handler(request, exc: NotImplementedError): - return PlainTextResponse(str(exc), status_code=501) + return ORJSONResponse( + status_code=501, + content=ErrorResponse( + error_code=ErrorCode.NOT_IMPLEMENTED.name, + message=str(exc), + timestamp=datetime.now().isoformat(), + correlation_id=request.headers.get(X_CORRELATION_ID), + ).model_dump(by_alias=True), + ) diff --git a/ibis-server/app/mdl/analyzer.py b/ibis-server/app/mdl/analyzer.py index 13f144103..ea68a3bf9 100644 --- a/ibis-server/app/mdl/analyzer.py +++ b/ibis-server/app/mdl/analyzer.py @@ -2,7 +2,7 @@ import orjson from app.config import get_config -from app.model import UnprocessableEntityError +from app.model.error import ErrorCode, WrenError wren_engine_endpoint = get_config().wren_engine_endpoint @@ -17,9 +17,11 @@ def analyze(manifest_str: str, sql: str) -> list[dict]: ) return r.raise_for_status().json() except httpx.ConnectError as e: - raise ConnectionError(f"Can not connect to Java Engine: {e}") from e + raise WrenError( + ErrorCode.LEGACY_ENGINE_ERROR, f"Can not connect to Java Engine: {e}" + ) from e except httpx.HTTPStatusError as e: - raise AnalyzeError(e.response.text) + raise WrenError(ErrorCode.GENERIC_USER_ERROR, e.response.text) def analyze_batch(manifest_str: str, sqls: list[str]) -> list[list[dict]]: @@ -32,11 +34,8 @@ def analyze_batch(manifest_str: str, sqls: list[str]) -> list[list[dict]]: ) return r.raise_for_status().json() except httpx.ConnectError as e: - raise ConnectionError(f"Can not connect to Java Engine: {e}") from e + raise WrenError( + ErrorCode.LEGACY_ENGINE_ERROR, f"Can not connect to Java Engine: {e}" + ) from e except httpx.HTTPStatusError as e: - raise AnalyzeError(e.response.text) - - -class AnalyzeError(UnprocessableEntityError): - def __init__(self, message: str): - super().__init__(message) + raise WrenError(ErrorCode.GENERIC_USER_ERROR, e.response.text) diff --git a/ibis-server/app/mdl/java_engine.py b/ibis-server/app/mdl/java_engine.py index 4225a7b46..e58e38983 100644 --- a/ibis-server/app/mdl/java_engine.py +++ b/ibis-server/app/mdl/java_engine.py @@ -5,6 +5,7 @@ from orjson import orjson from app.config import get_config +from app.model.error import ErrorCode, ErrorPhase, WrenError wren_engine_endpoint = get_config().wren_engine_endpoint @@ -27,8 +28,10 @@ def __init__(self, end_point: str | None = None): async def dry_plan(self, manifest_str: str, sql: str): if self.client is None: - raise ValueError( - "WREN_ENGINE_ENDPOINT is not set. Cannot call dry_plan without a valid endpoint." + raise WrenError( + ErrorCode.GENERIC_INTERNAL_ERROR, + "WREN_ENGINE_ENDPOINT is not set. Cannot call dry_plan without a valid endpoint.", + phase=ErrorPhase.SQL_PLANNING, ) r = await self.client.request( diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index 6a12eac7d..f23dc006d 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -14,8 +14,8 @@ to_json_base64, ) from app.mdl.java_engine import JavaEngineConnector -from app.model import InternalServerError, UnprocessableEntityError from app.model.data_source import DataSource +from app.model.error import PLANNED_SQL, ErrorCode, ErrorPhase, WrenError # To register custom dialects from ibis library for sqlglot importlib.import_module("ibis.backends.sql.dialects") @@ -47,9 +47,17 @@ def __init__( @tracer.start_as_current_span("transpile", kind=trace.SpanKind.INTERNAL) def _transpile(self, planned_sql: str) -> str: - read = self._get_read_dialect(self.experiment) - write = self._get_write_dialect(self.data_source) - return sqlglot.transpile(planned_sql, read=read, write=write)[0] + try: + read = self._get_read_dialect(self.experiment) + write = self._get_write_dialect(self.data_source) + return sqlglot.transpile(planned_sql, read=read, write=write)[0] + except Exception as e: + raise WrenError( + ErrorCode.SQLGLOT_ERROR, + str(e), + phase=ErrorPhase.SQL_TRANSPILE, + metadata={PLANNED_SQL: planned_sql}, + ) @tracer.start_as_current_span("rewrite", kind=trace.SpanKind.INTERNAL) async def rewrite(self, sql: str) -> str: @@ -102,11 +110,18 @@ async def rewrite( try: return await self.java_engine_connector.dry_plan(manifest_str, sql) except httpx.ConnectError as e: - raise WrenEngineError(f"Can not connect to Java Engine: {e}") + raise WrenError( + ErrorCode.LEGACY_ENGINE_ERROR, f"Can not connect to Java Engine: {e}" + ) except httpx.TimeoutException as e: - raise WrenEngineError(f"Timeout when connecting to Java Engine: {e}") + raise WrenError( + ErrorCode.LEGACY_ENGINE_ERROR, + f"Timeout when connecting to Java Engine: {e}", + ) except httpx.HTTPStatusError as e: - raise RewriteError(e.response.text) + raise WrenError( + ErrorCode.INVALID_SQL, e.response.text, ErrorPhase.SQL_PLANNING + ) @staticmethod def handle_extract_exception(e: Exception): @@ -131,7 +146,7 @@ async def rewrite( sql, ) except Exception as e: - raise RewriteError(str(e)) + raise WrenError(ErrorCode.INVALID_SQL, str(e), ErrorPhase.SQL_PLANNING) @tracer.start_as_current_span("embedded_rewrite", kind=trace.SpanKind.INTERNAL) def rewrite_sync( @@ -144,7 +159,7 @@ def rewrite_sync( ) return session_context.transform_sql(sql) except Exception as e: - raise RewriteError(str(e)) + raise WrenError(ErrorCode.INVALID_SQL, str(e), ErrorPhase.SQL_PLANNING) def get_session_properties(self, properties: dict) -> frozenset | None: if properties is None: @@ -162,14 +177,4 @@ def get_session_properties(self, properties: dict) -> frozenset | None: @staticmethod def handle_extract_exception(e: Exception): - raise RewriteError(str(e)) - - -class RewriteError(UnprocessableEntityError): - def __init__(self, message: str): - super().__init__(message) - - -class WrenEngineError(InternalServerError): - def __init__(self, message: str): - super().__init__(message) + raise WrenError(ErrorCode.INVALID_MDL, str(e), ErrorPhase.MDL_EXTRACTION) diff --git a/ibis-server/app/mdl/substitute.py b/ibis-server/app/mdl/substitute.py index 4a32ec0c0..914c626ee 100644 --- a/ibis-server/app/mdl/substitute.py +++ b/ibis-server/app/mdl/substitute.py @@ -4,8 +4,8 @@ from sqlglot import exp, parse_one from sqlglot.optimizer.scope import build_scope -from app.model import UnprocessableEntityError from app.model.data_source import DataSource +from app.model.error import ErrorCode, ErrorPhase, WrenError from app.util import base64_to_dict tracer = trace.get_tracer(__name__) @@ -40,12 +40,18 @@ def substitute(self, sql: str, write: str | None = None) -> str: # if model name is ambiguous, raise an error duplicate_keys = get_case_insensitive_duplicate_keys(self.model_dict) if model is not None and key.lower() in duplicate_keys: - raise SubstituteError( - f"Ambiguous model: found multiple matches for {source}" + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + f"Ambiguous model: found multiple matches for {source}", + phase=ErrorPhase.SQL_SUBSTITUTE, ) if model is None: - raise SubstituteError(f"Model not found: {source}") + raise WrenError( + ErrorCode.NOT_FOUND, + f"Model not found: {source}", + phase=ErrorPhase.SQL_SUBSTITUTE, + ) source.replace( exp.Table( @@ -129,8 +135,3 @@ def get_case_insensitive_duplicate_keys(d): duplicates = [key for keys in key_map.values() if len(keys) > 1 for key in keys] return duplicates - - -class SubstituteError(UnprocessableEntityError): - def __init__(self, message: str): - super().__init__(message) diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index 58b4c8183..4d7fb10b8 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -1,15 +1,9 @@ from __future__ import annotations -from abc import ABC from enum import Enum from typing import Annotated, Any, Literal, Union from pydantic import BaseModel, Field, SecretStr -from starlette.status import ( - HTTP_404_NOT_FOUND, - HTTP_422_UNPROCESSABLE_ENTITY, - HTTP_500_INTERNAL_SERVER_ERROR, -) manifest_str_field = Field(alias="manifestStr", description="Base64 manifest") connection_info_field = Field(alias="connectionInfo") @@ -526,35 +520,6 @@ class ConfigModel(BaseModel): diagnose: bool -class UnknownIbisError(Exception): - def __init__(self, message): - self.message = f"Unknown ibis error: {message!s}" - - -class CustomHttpError(ABC, Exception): - status_code: int - - -class InternalServerError(CustomHttpError): - status_code = HTTP_500_INTERNAL_SERVER_ERROR - - -class UnprocessableEntityError(CustomHttpError): - status_code = HTTP_422_UNPROCESSABLE_ENTITY - - -class DatabaseTimeoutError(CustomHttpError): - status_code = 504 - - def __init__(self, message: str): - super().__init__(message) - self.message = f"Database timeout error: {message!s}.\nIt seems your database is not responding or the query is taking too long to execute. Please check your database status and query performance." - - -class NotFoundError(CustomHttpError): - status_code = HTTP_404_NOT_FOUND - - class SSLMode(str, Enum): DISABLED = "disabled" ENABLED = "enabled" diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index 500c85fb9..31f266c50 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -8,13 +8,25 @@ from json import loads from typing import Any +try: + import clickhouse_connect + + ClickHouseDbError = clickhouse_connect.driver.exceptions.DatabaseError +except ImportError: # pragma: no cover + + class ClickHouseDbError(Exception): + pass + + import ibis import ibis.expr.datatypes as dt import ibis.expr.schema as sch import opendal import pandas as pd +import psycopg import pyarrow as pa import sqlglot.expressions as sge +import trino from duckdb import HTTPException, IOException from google.cloud import bigquery from google.oauth2 import service_account @@ -34,10 +46,14 @@ RedshiftConnectionUnion, RedshiftIAMConnectionInfo, S3FileConnectionInfo, - UnknownIbisError, - UnprocessableEntityError, ) from app.model.data_source import DataSource +from app.model.error import ( + DIALECT_SQL, + ErrorCode, + ErrorPhase, + WrenError, +) from app.model.utils import init_duckdb_gcs, init_duckdb_minio, init_duckdb_s3 # Override datatypes of ibis @@ -52,19 +68,6 @@ def _get_pg_type_names(connection: BaseBackend) -> dict[int, str]: return dict(cur.fetchall()) -class QueryDryRunError(UnprocessableEntityError): - pass - - -class GenericUserError(UnprocessableEntityError): - def __init__(self, message: str): - super().__init__(message) - self.message = message - - def __str__(self) -> str: - return self.message - - class Connector: @tracer.start_as_current_span("connector_init", kind=trace.SpanKind.INTERNAL) def __init__(self, data_source: DataSource, connection_info: ConnectionInfo): @@ -89,13 +92,74 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo): self._connector = SimpleConnector(data_source, connection_info) def query(self, sql: str, limit: int | None = None) -> pa.Table: - return self._connector.query(sql, limit) + try: + return self._connector.query(sql, limit) + except ( + WrenError, + TimeoutError, + psycopg.errors.QueryCanceled, + ): + raise + except trino.exceptions.TrinoQueryError as e: + if not e.error_name == "EXCEEDED_TIME_LIMIT": + raise WrenError( + ErrorCode.INVALID_SQL, + str(e), + phase=ErrorPhase.SQL_DRY_RUN, + metadata={DIALECT_SQL: sql}, + ) from e + raise + except ClickHouseDbError as e: + if "TIMEOUT_EXCEEDED" not in str(e): + raise WrenError( + ErrorCode.INVALID_SQL, + str(e), + phase=ErrorPhase.SQL_EXECUTION, + metadata={DIALECT_SQL: sql}, + ) from e + raise e + except Exception as e: + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + str(e), + phase=ErrorPhase.SQL_EXECUTION, + metadata={DIALECT_SQL: sql}, + ) from e def dry_run(self, sql: str) -> None: try: self._connector.dry_run(sql) + except ( + WrenError, + TimeoutError, + psycopg.errors.QueryCanceled, + ): + raise + except trino.exceptions.TrinoQueryError as e: + if not e.error_name == "EXCEEDED_TIME_LIMIT": + raise WrenError( + ErrorCode.INVALID_SQL, + str(e), + phase=ErrorPhase.SQL_DRY_RUN, + metadata={DIALECT_SQL: sql}, + ) from e + raise + except ClickHouseDbError as e: + if "TIMEOUT_EXCEEDED" not in str(e): + raise WrenError( + ErrorCode.INVALID_SQL, + str(e), + phase=ErrorPhase.SQL_DRY_RUN, + metadata={DIALECT_SQL: sql}, + ) from e + raise except Exception as e: - raise QueryDryRunError(f"Exception: {type(e)}, message: {e!s}") + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + str(e), + phase=ErrorPhase.SQL_DRY_RUN, + metadata={DIALECT_SQL: sql}, + ) from e def close(self) -> None: """Close the underlying connection.""" @@ -269,8 +333,17 @@ def dry_run(self, sql: str) -> None: # Workaround for ibis issue #10331 if e.args[0] == "'NoneType' object has no attribute 'lower'": error_message = self._describe_sql_for_error_message(sql) - raise QueryDryRunError(f"The sql dry run failed. {error_message}.") - raise UnknownIbisError(e) + raise WrenError( + error_code=ErrorCode.INVALID_SQL, + message=f"The sql dry run failed. {error_message}.", + phase=ErrorPhase.SQL_DRY_RUN, + metadata={DIALECT_SQL: sql}, + ) from e + raise WrenError( + error_code=ErrorCode.IBIS_PROJECT_ERROR, + message=str(e), + phase=ErrorPhase.SQL_DRY_RUN, + ) from e @tracer.start_as_current_span( "describe_sql_for_error_message", kind=trace.SpanKind.CLIENT @@ -412,36 +485,25 @@ def __init__(self, connection_info: ConnectionInfo): @tracer.start_as_current_span("duckdb_query", kind=trace.SpanKind.INTERNAL) def query(self, sql: str, limit: int | None) -> pa.Table: - try: - if limit is None: - # If no limit is specified, we return the full result - return self.connection.execute(sql).fetch_arrow_table() - else: - # If a limit is specified, we slice the result - # DuckDB does not support LIMIT in fetch_arrow_table, so we use slice - # to limit the number of rows returned - return ( - self.connection.execute(sql).fetch_arrow_table().slice(length=limit) - ) - except IOException as e: - raise UnprocessableEntityError(f"Failed to execute query: {e!s}") - except HTTPException as e: - raise UnprocessableEntityError(f"Failed to execute query: {e!s}") + if limit is None: + # If no limit is specified, we return the full result + return self.connection.execute(sql).fetch_arrow_table() + else: + # If a limit is specified, we slice the result + # DuckDB does not support LIMIT in fetch_arrow_table, so we use slice + # to limit the number of rows returned + return self.connection.execute(sql).fetch_arrow_table().slice(length=limit) @tracer.start_as_current_span("duckdb_dry_run", kind=trace.SpanKind.INTERNAL) def dry_run(self, sql: str) -> None: - try: - self.connection.execute(sql) - except IOException as e: - raise QueryDryRunError(f"Failed to execute query: {e!s}") - except HTTPException as e: - raise QueryDryRunError(f"Failed to execute query: {e!s}") + self.connection.execute(sql) def _attach_database(self, connection_info: ConnectionInfo) -> None: db_files = self._list_duckdb_files(connection_info) if not db_files: - raise UnprocessableEntityError( - "No DuckDB files found in the specified path." + raise WrenError( + ErrorCode.DUCKDB_FILE_NOT_FOUND, + "No DuckDB files found in the specified path.", ) for file in db_files: @@ -450,9 +512,13 @@ def _attach_database(self, connection_info: ConnectionInfo) -> None: f"ATTACH DATABASE '{file}' AS \"{os.path.splitext(os.path.basename(file))[0]}\" (READ_ONLY);" ) except IOException as e: - raise UnprocessableEntityError(f"Failed to attach database: {e!s}") + raise WrenError( + ErrorCode.ATTACH_DUCKDB_ERROR, f"Failed to attach database: {e!s}" + ) except HTTPException as e: - raise UnprocessableEntityError(f"Failed to attach database: {e!s}") + raise WrenError( + ErrorCode.ATTACH_DUCKDB_ERROR, f"Failed to attach database: {e!s}" + ) def _list_duckdb_files(self, connection_info: ConnectionInfo) -> list[str]: # This method should return a list of file paths in the DuckDB database @@ -468,7 +534,9 @@ def _list_duckdb_files(self, connection_info: ConnectionInfo) -> list[str]: ) files.append(full_path) except Exception as e: - raise UnprocessableEntityError(f"Failed to list files: {e!s}") + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, f"Failed to list files: {e!s}" + ) return files @@ -503,7 +571,10 @@ def __init__(self, connection_info: RedshiftConnectionUnion): password=connection_info.password.get_secret_value(), ) else: - raise ValueError("Invalid Redshift connection_info type") + raise WrenError( + ErrorCode.GENERIC_INTERNAL_ERROR, + "Invalid Redshift connection_info type", + ) # Enable autocommit to prevent holding AccessShareLock indefinitely # This ensures locks are released immediately after query execution diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index a7fa6482c..04af55ff6 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -49,6 +49,7 @@ SSLMode, TrinoConnectionInfo, ) +from app.model.error import ErrorCode, WrenError X_WREN_DB_STATEMENT_TIMEOUT = "x-wren-db-statement_timeout" @@ -177,7 +178,10 @@ def _handle_clickhouse_url( self, parsed: urllib.parse.ParseResult ) -> ClickHouseConnectionInfo: if not parsed.scheme or parsed.scheme != "clickhouse": - raise ValueError("Invalid connection URL for ClickHouse") + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, + "Invalid connection URL for ClickHouse", + ) kwargs = {} if parsed.username: kwargs["user"] = parsed.username @@ -232,6 +236,10 @@ def get_connection(self, info: ConnectionInfo) -> BaseBackend: return getattr(self, f"get_{self.name}_connection")(info) except KeyError: raise NotImplementedError(f"Unsupported data source: {self}") + except WrenError: + raise + except Exception as e: + raise WrenError(ErrorCode.GET_CONNECTION_ERROR, f"{e!s}") from e @staticmethod def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend: @@ -401,7 +409,10 @@ def _create_ssl_context(info: ConnectionInfo) -> ssl.SSLContext | None: ) if ssl_mode == SSLMode.VERIFY_CA and not info.ssl_ca: - raise ValueError("SSL CA must be provided when SSL mode is VERIFY CA") + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, + "SSL CA must be provided when SSL mode is VERIFY CA", + ) if not ssl_mode or ssl_mode == SSLMode.DISABLED: return None diff --git a/ibis-server/app/model/error.py b/ibis-server/app/model/error.py new file mode 100644 index 000000000..d64b3c16d --- /dev/null +++ b/ibis-server/app/model/error.py @@ -0,0 +1,125 @@ +from datetime import datetime +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field +from starlette.status import ( + HTTP_404_NOT_FOUND, + HTTP_422_UNPROCESSABLE_ENTITY, + HTTP_500_INTERNAL_SERVER_ERROR, + HTTP_501_NOT_IMPLEMENTED, + HTTP_502_BAD_GATEWAY, + HTTP_504_GATEWAY_TIMEOUT, +) + +DIALECT_SQL = "dialectSql" +PLANNED_SQL = "plannedSql" + + +class ErrorCode(int, Enum): + GENERIC_USER_ERROR = 1 + NOT_FOUND = 2 + MDL_NOT_FOUND = 3 + INVALID_SQL = 4 + INVALID_MDL = 5 + DUCKDB_FILE_NOT_FOUND = 6 + ATTACH_DUCKDB_ERROR = 7 + VALIDATION_RULE_NOT_FOUND = 8 + VALIDATION_ERROR = 9 + VALIDATION_PARAMETER_ERROR = 10 + GET_CONNECTION_ERROR = 11 + INVALID_CONNECTION_INFO = 12 + GENERIC_INTERNAL_ERROR = 100 + LEGACY_ENGINE_ERROR = 101 + NOT_IMPLEMENTED = 102 + IBIS_PROJECT_ERROR = 103 + SQLGLOT_ERROR = 104 + GENERIC_EXTERNAL_ERROR = 200 + DATABASE_TIMEOUT = 201 + + +class ErrorPhase(int, Enum): + REQUEST_RECEIVED = 1 + MDL_EXTRACTION = 2 + SQL_PARSING = 3 + SQL_PLANNING = 4 + SQL_TRANSPILE = 5 + SQL_EXECUTION = 6 + SQL_DRY_RUN = 7 + RESPONSE_GENERATION = 8 + METADATA_FETCHING = 9 + VALIDATION = 10 + SQL_SUBSTITUTE = 11 + + +class ErrorResponse(BaseModel): + model_config = {"populate_by_name": True} + error_code: str = Field(alias="errorCode") + message: str + metadata: dict[str, Any] | None = None + phase: str | None = None + timestamp: str + correlation_id: str | None = Field(alias="correlationId", default=None) + + +class WrenError(Exception): + error_code: ErrorCode + message: str + phase: ErrorPhase | None = None + metadata: dict[str, Any] | None = None + timestamp: str | None = None + + def __init__( + self, + error_code: ErrorCode, + message: str, + phase: ErrorPhase | None = None, + metadata: dict[str, Any] | None = None, + cause: Exception | None = None, + ): + self.error_code = error_code + self.message = message + self.phase = phase + self.metadata = metadata + self.timestamp = datetime.now().isoformat() + + def get_response(self, correlation_id: str | None = None) -> ErrorResponse: + return ErrorResponse( + error_code=self.error_code.name, + message=self.message, + metadata=self.metadata, + phase=self.phase.name if self.phase else None, + timestamp=self.timestamp, + correlation_id=correlation_id, + ) + + def get_http_status_code(self) -> int: + match self.error_code: + case ( + ErrorCode.NOT_FOUND + | ErrorCode.MDL_NOT_FOUND + | ErrorCode.VALIDATION_RULE_NOT_FOUND + ): + return HTTP_404_NOT_FOUND + case ErrorCode.NOT_IMPLEMENTED: + return HTTP_501_NOT_IMPLEMENTED + case ErrorCode.GENERIC_EXTERNAL_ERROR: + return HTTP_502_BAD_GATEWAY + case ErrorCode.DATABASE_TIMEOUT: + return HTTP_504_GATEWAY_TIMEOUT + case e: + if e.value < 100: + return HTTP_422_UNPROCESSABLE_ENTITY + return HTTP_500_INTERNAL_SERVER_ERROR + + +class DatabaseTimeoutError(WrenError): + def __init__( + self, + message: str, + ): + enhanced_message = f"{message!s}.\nIt seems your database is not responding or the query is taking too long to execute. Please check your database status and query performance." + super().__init__( + error_code=ErrorCode.DATABASE_TIMEOUT, + message=enhanced_message, + ) diff --git a/ibis-server/app/model/metadata/canner.py b/ibis-server/app/model/metadata/canner.py index a7699d59d..ed6cfda24 100644 --- a/ibis-server/app/model/metadata/canner.py +++ b/ibis-server/app/model/metadata/canner.py @@ -5,6 +5,7 @@ from gql.transport.aiohttp import AIOHTTPTransport from app.model import CannerConnectionInfo +from app.model.error import ErrorCode, WrenError from app.model.metadata.dto import ( Column, Constraint, @@ -68,7 +69,9 @@ def _get_workspace_id(self, ws_sql_name) -> str: try: return next(ws["id"] for ws in workspaces if ws["sqlName"] == ws_sql_name) except StopIteration: - raise ValueError(f"Workspace {ws_sql_name} not found") + raise WrenError( + ErrorCode.INVALID_CONNECTION_INFO, f"Workspace {ws_sql_name} not found" + ) def _get_metadata(self, workspace_id: str) -> dict: query = gql(""" diff --git a/ibis-server/app/model/metadata/object_storage.py b/ibis-server/app/model/metadata/object_storage.py index 24da66ee0..87197b6d4 100644 --- a/ibis-server/app/model/metadata/object_storage.py +++ b/ibis-server/app/model/metadata/object_storage.py @@ -10,9 +10,9 @@ LocalFileConnectionInfo, MinioFileConnectionInfo, S3FileConnectionInfo, - UnprocessableEntityError, ) from app.model.connector import DuckDBConnector +from app.model.error import ErrorCode, ErrorPhase, WrenError from app.model.metadata.dto import ( Column, RustWrenEngineColumnType, @@ -81,7 +81,11 @@ def get_table_list(self) -> list[Table]: ) unique_tables[table_name].columns = columns except Exception as e: - raise UnprocessableEntityError(f"Failed to list files: {e!s}") + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + f"Failed to list files: {e!s}", + phase=ErrorPhase.METADATA_FETCHING, + ) return list(unique_tables.values()) @@ -343,7 +347,9 @@ def get_constraints(self): def get_version(self): df: pa.Table = self.connection.query("SELECT version()") if df is None: - raise UnprocessableEntityError("Failed to get DuckDB version") + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, "Failed to get DuckDB version" + ) if df.num_rows == 0: - raise UnprocessableEntityError("DuckDB version is empty") + raise WrenError(ErrorCode.GENERIC_USER_ERROR, "DuckDB version is empty") return df.column(0).to_pylist()[0] diff --git a/ibis-server/app/model/validator.py b/ibis-server/app/model/validator.py index d792a480c..c432396f1 100644 --- a/ibis-server/app/model/validator.py +++ b/ibis-server/app/model/validator.py @@ -8,8 +8,8 @@ ) from app.mdl.rewriter import Rewriter -from app.model import NotFoundError, UnprocessableEntityError from app.model.connector import Connector +from app.model.error import ErrorCode, ErrorPhase, WrenError from app.util import base64_to_dict rules = ["column_is_valid", "relationship_is_valid", "rlac_condition_syntax_is_valid"] @@ -22,13 +22,18 @@ def __init__(self, connector: Connector, rewriter: Rewriter): async def validate(self, rule: str, parameters: dict, manifest_str: str): if rule not in rules: - raise RuleNotFoundError(rule) + raise WrenError( + ErrorCode.VALIDATION_RULE_NOT_FOUND, + f"The rule `{rule}` is not in the rules, rules: {rules}", + ) try: await getattr(self, f"_validate_{rule}")(parameters, manifest_str) - except ValidationError as e: - raise e + except WrenError: + raise except Exception as e: - raise ValidationError(f"Unknown exception: {type(e)}, message: {e!s}") + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, str(e), phase=ErrorPhase.VALIDATION + ) from e async def _validate_column_is_valid( self, parameters: dict[str, str], manifest_str: str @@ -36,23 +41,37 @@ async def _validate_column_is_valid( model_name = parameters.get("modelName") column_name = parameters.get("columnName") if model_name is None: - raise MissingRequiredParameterError("modelName") + raise WrenError( + ErrorCode.VALIDATION_PARAMETER_ERROR, + "modelName is required", + phase=ErrorPhase.VALIDATION, + ) if column_name is None: - raise MissingRequiredParameterError("columnName") + raise WrenError( + ErrorCode.VALIDATION_PARAMETER_ERROR, + "columnName is required", + phase=ErrorPhase.VALIDATION, + ) try: sql = f'SELECT "{column_name}" FROM "{model_name}" LIMIT 1' rewritten_sql = await self.rewriter.rewrite(sql) self.connector.dry_run(rewritten_sql) except Exception as e: - raise ValidationError(f"Exception: {type(e)}, message: {e!s}") + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, str(e), phase=ErrorPhase.VALIDATION + ) from e async def _validate_relationship_is_valid( self, parameters: dict[str, str], manifest_str: str ): relationship_name = parameters.get("relationshipName") if relationship_name is None: - raise MissingRequiredParameterError("relationship") + raise WrenError( + ErrorCode.VALIDATION_PARAMETER_ERROR, + "relationshipName is required", + phase=ErrorPhase.VALIDATION, + ) manifest = base64_to_dict(manifest_str) @@ -61,8 +80,10 @@ async def _validate_relationship_is_valid( ) if len(relationship) == 0: - raise ValidationError( - f"Relationship {relationship_name} not found in manifest" + raise WrenError( + ErrorCode.VALIDATION_PARAMETER_ERROR, + f"Relationship {relationship_name} not found in manifest", + phase=ErrorPhase.VALIDATION, ) left_model = self._get_model(manifest, relationship[0]["models"][0]) @@ -117,7 +138,11 @@ def generate_sql_from_type( joinexist.result is_related FROM joinexist""" else: - raise ValidationError(f"Unknown relationship type: {relationship_type}") + raise WrenError( + ErrorCode.VALIDATION_PARAMETER_ERROR, + f"Unknown relationship type: {relationship_type}", + phase=ErrorPhase.VALIDATION, + ) def format_result(result): output = {} @@ -144,22 +169,39 @@ def format_result(result): rewritten_sql = await self.rewriter.rewrite(sql) result = self.connector.query(rewritten_sql, limit=1).to_pandas() if not result.get("result").get(0): - raise ValidationError( - f"Relationship {relationship_name} is not valid: {format_result(result)}" + raise WrenError( + ErrorCode.VALIDATION_ERROR, + f"Relationship {relationship_name} is not valid: {format_result(result)}", + phase=ErrorPhase.VALIDATION, ) - + except WrenError: + raise except Exception as e: - raise ValidationError(f"Exception: {type(e)}, message: {e!s}") + raise WrenError( + ErrorCode.GENERIC_INTERNAL_ERROR, str(e), phase=ErrorPhase.VALIDATION + ) from e async def _validate_rlac_condition_syntax_is_valid( self, parameters: dict, manifest_str: str ): if parameters.get("modelName") is None: - raise MissingRequiredParameterError("modelName") + raise WrenError( + ErrorCode.VALIDATION_PARAMETER_ERROR, + "modelName is required", + phase=ErrorPhase.VALIDATION, + ) if parameters.get("requiredProperties") is None: - raise MissingRequiredParameterError("requiredProperties") + raise WrenError( + ErrorCode.VALIDATION_PARAMETER_ERROR, + "requiredProperties is required", + phase=ErrorPhase.VALIDATION, + ) if parameters.get("condition") is None: - raise MissingRequiredParameterError("condition") + raise WrenError( + ErrorCode.VALIDATION_PARAMETER_ERROR, + "condition is required", + phase=ErrorPhase.VALIDATION, + ) model_name = parameters.get("modelName") required_properties = parameters.get("requiredProperties") @@ -183,30 +225,25 @@ async def _validate_rlac_condition_syntax_is_valid( manifest = to_manifest(manifest_str) model = manifest.get_model(model_name) if model is None: - raise ValueError(f"Model {model_name} not found in manifest") + raise WrenError( + ErrorCode.VALIDATION_PARAMETER_ERROR, + f"Model {model_name} not found in manifest", + phase=ErrorPhase.VALIDATION, + ) try: validate_rlac_rule(rlac, model) except Exception as e: - raise ValidationError(e) + raise WrenError( + ErrorCode.VALIDATION_ERROR, str(e), phase=ErrorPhase.VALIDATION + ) def _get_model(self, manifest, model_name): models = list(filter(lambda m: m["name"] == model_name, manifest["models"])) if len(models) == 0: - raise ValidationError(f"Model {model_name} not found in manifest") + raise WrenError( + ErrorCode.VALIDATION_PARAMETER_ERROR, + f"Model {model_name} not found in manifest", + phase=ErrorPhase.VALIDATION, + ) return models[0] - - -class ValidationError(UnprocessableEntityError): - def __init__(self, message: str): - super().__init__(message) - - -class RuleNotFoundError(NotFoundError): - def __init__(self, rule: str): - super().__init__(f"The rule `{rule}` is not in the rules, rules: {rules}") - - -class MissingRequiredParameterError(ValidationError): - def __init__(self, parameter: str): - super().__init__(f"Missing required parameter: `{parameter}`") diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index c6ecf20b0..85dda6503 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -23,7 +23,6 @@ from app.mdl.rewriter import Rewriter from app.mdl.substitute import ModelSubstitute from app.model import ( - DatabaseTimeoutError, DryPlanDTO, QueryDTO, TranspileDTO, @@ -31,6 +30,7 @@ ) from app.model.connector import Connector from app.model.data_source import DataSource +from app.model.error import DatabaseTimeoutError from app.model.validator import Validator from app.query_cache import QueryCacheManager from app.routers import v2 diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index 8aebb0b4b..d134e874a 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -2,7 +2,16 @@ import base64 import time -import clickhouse_connect +try: + import clickhouse_connect + + ClickHouseDbError = clickhouse_connect.driver.exceptions.DatabaseError +except ImportError: # pragma: no cover + + class ClickHouseDbError(Exception): + pass + + import datafusion import orjson import pandas as pd @@ -31,8 +40,8 @@ X_CACHE_OVERRIDE_AT, X_WREN_TIMEZONE, ) -from app.model import DatabaseTimeoutError from app.model.data_source import DataSource +from app.model.error import DatabaseTimeoutError from app.model.metadata.metadata import Metadata tracer = trace.get_tracer(__name__) @@ -253,12 +262,14 @@ async def execute_with_timeout(operation, operation_name: str): raise DatabaseTimeoutError( f"{operation_name} timeout after {app_timeout_seconds} seconds" ) - except clickhouse_connect.driver.exceptions.DatabaseError as e: + except ClickHouseDbError as e: if "TIMEOUT_EXCEEDED" in str(e): raise DatabaseTimeoutError(f"{operation_name} was cancelled: {e}") + raise except trino.exceptions.TrinoQueryError as e: if e.error_name == "EXCEEDED_TIME_LIMIT": raise DatabaseTimeoutError(f"{operation_name} was cancelled: {e}") + raise except psycopg.errors.QueryCanceled as e: raise DatabaseTimeoutError(f"{operation_name} was cancelled: {e}") @@ -314,30 +325,6 @@ async def execute_query_with_timeout( ) -async def execute_sample_with_timeout( - connector, - sql: str, - sample_rate: int, - limit: int, - manifest_str: str, -): - """Execute a sample query with a timeout control.""" - task = asyncio.create_task( - asyncio.to_thread( - connector.sample, - sql, - sample_rate=sample_rate, - limit=limit, - manifest_str=manifest_str, - ), - ) - return await _safe_execute_task_with_timeout( - "Sample", - task, - connector, - ) - - async def execute_validate_with_timeout( validator, rule_name: str, diff --git a/ibis-server/tests/routers/v2/connector/test_athena.py b/ibis-server/tests/routers/v2/connector/test_athena.py index fe1b54b07..027776a54 100644 --- a/ibis-server/tests/routers/v2/connector/test_athena.py +++ b/ibis-server/tests/routers/v2/connector/test_athena.py @@ -4,8 +4,6 @@ import orjson import pytest -from app.model.validator import rules - pytestmark = pytest.mark.athena base_url = "/v2/connector/athena" @@ -265,96 +263,6 @@ async def test_query_with_dry_run_and_invalid_sql(client, manifest_str): assert response.text is not None -async def test_validate_with_unknown_rule(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/unknown_rule", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 404 - assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" - ) - - -async def test_validate_rule_column_is_valid(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 204 - - -async def test_validate_rule_column_is_valid_with_invalid_parameters( - client, manifest_str -): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "X", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "X"}, - }, - ) - assert response.status_code == 422 - - -async def test_validate_rule_column_is_valid_without_parameters(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_one_parameter( - client, manifest_str -): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" - - async def test_metadata_list_tables(client): response = await client.post( url=f"{base_url}/metadata/tables", diff --git a/ibis-server/tests/routers/v2/connector/test_bigquery.py b/ibis-server/tests/routers/v2/connector/test_bigquery.py index 955338ff9..223c7629b 100644 --- a/ibis-server/tests/routers/v2/connector/test_bigquery.py +++ b/ibis-server/tests/routers/v2/connector/test_bigquery.py @@ -5,8 +5,6 @@ import orjson import pytest -from app.model.validator import rules - pytestmark = pytest.mark.bigquery base_url = "/v2/connector/bigquery" @@ -366,97 +364,6 @@ async def test_custom_datatypes_no_overrides(client, manifest_str): assert result["dtypes"] == {"col": "month_day_nano_interval"} -async def test_validate_with_unknown_rule(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/unknown_rule", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - - assert response.status_code == 404 - assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" - ) - - -async def test_validate_rule_column_is_valid(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 204 - - -async def test_validate_rule_column_is_valid_with_invalid_parameters( - client, manifest_str -): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "X", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "X"}, - }, - ) - assert response.status_code == 422 - - -async def test_validate_rule_column_is_valid_without_parameters(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_one_parameter( - client, manifest_str -): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" - - async def test_metadata_list_tables(client): def _assert_nested_column(column): if column["nestedColumns"] is not None: diff --git a/ibis-server/tests/routers/v2/connector/test_canner.py b/ibis-server/tests/routers/v2/connector/test_canner.py index be2f6bea3..59f0c7b76 100644 --- a/ibis-server/tests/routers/v2/connector/test_canner.py +++ b/ibis-server/tests/routers/v2/connector/test_canner.py @@ -4,8 +4,6 @@ import pytest from orjson import orjson -from app.model.validator import rules - """ The Canner Enterprise must setup below: - A user with PAT @@ -230,96 +228,6 @@ async def test_query_with_dry_run_and_invalid_sql(client, manifest_str): assert response.text is not None -async def test_validate_with_unknown_rule(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/unknown_rule", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" - ) - - -async def test_validate_rule_column_is_valid(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 204, response.text - - -async def test_validate_rule_column_is_valid_with_invalid_parameters( - client, manifest_str -): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "X", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "X"}, - }, - ) - assert response.status_code == 422 - - -async def test_validate_rule_column_is_valid_without_parameters(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_one_parameter( - client, manifest_str -): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" - - async def test_dry_plan(client, manifest_str): response = await client.post( url=f"{base_url}/dry-plan", diff --git a/ibis-server/tests/routers/v2/connector/test_clickhouse.py b/ibis-server/tests/routers/v2/connector/test_clickhouse.py index 827db230d..a676f0520 100644 --- a/ibis-server/tests/routers/v2/connector/test_clickhouse.py +++ b/ibis-server/tests/routers/v2/connector/test_clickhouse.py @@ -7,7 +7,7 @@ from testcontainers.clickhouse import ClickHouseContainer from app.model.data_source import X_WREN_DB_STATEMENT_TIMEOUT -from app.model.validator import rules +from app.model.error import ErrorCode from tests.conftest import file_path pytestmark = pytest.mark.clickhouse @@ -318,15 +318,17 @@ async def test_query_to_many_relationship( async def test_query_alias_join(client, manifest_str, clickhouse: ClickHouseContainer): connection_info = _to_connection_info(clickhouse) # ClickHouse does not support alias join - with pytest.raises(Exception): - await client.post( - url=f"{base_url}/query", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "sql": 'SELECT orderstatus FROM ("Orders" o JOIN "Customer" c ON o.custkey = c.custkey) j1 LIMIT 1', - }, - ) + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT orderstatus FROM ("Orders" o JOIN "Customer" c ON o.custkey = c.custkey) j1 LIMIT 1', + }, + ) + + assert response.status_code == 422 + assert response.json()["errorCode"] == ErrorCode.INVALID_SQL.name async def test_query_without_manifest(client, clickhouse: ClickHouseContainer): @@ -409,107 +411,6 @@ async def test_query_with_dry_run_and_invalid_sql( assert response.text is not None -async def test_validate_with_unknown_rule( - client, manifest_str, clickhouse: ClickHouseContainer -): - connection_info = _to_connection_info(clickhouse) - response = await client.post( - url=f"{base_url}/validate/unknown_rule", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 404 - assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" - ) - - -async def test_validate_rule_column_is_valid( - client, manifest_str, clickhouse: ClickHouseContainer -): - connection_info = _to_connection_info(clickhouse) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 204 - - -async def test_validate_rule_column_is_valid_with_invalid_parameters( - client, manifest_str, clickhouse: ClickHouseContainer -): - connection_info = _to_connection_info(clickhouse) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "X", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "X"}, - }, - ) - assert response.status_code == 422 - - -async def test_validate_rule_column_is_valid_without_parameters( - client, manifest_str, clickhouse: ClickHouseContainer -): - connection_info = _to_connection_info(clickhouse) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_one_parameter( - client, manifest_str, clickhouse: ClickHouseContainer -): - connection_info = _to_connection_info(clickhouse) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" - - async def test_metadata_list_tables(client, clickhouse: ClickHouseContainer): connection_info = _to_connection_info(clickhouse) response = await client.post( diff --git a/ibis-server/tests/routers/v2/connector/test_gcs_file.py b/ibis-server/tests/routers/v2/connector/test_gcs_file.py index cddcffef8..f3f1e87f0 100644 --- a/ibis-server/tests/routers/v2/connector/test_gcs_file.py +++ b/ibis-server/tests/routers/v2/connector/test_gcs_file.py @@ -292,7 +292,10 @@ async def test_unsupported_format(client): }, ) assert response.status_code == 422 - assert response.text == "Failed to list files: Unsupported format: unsupported" + assert ( + response.json()["message"] + == "Failed to list files: Unsupported format: unsupported" + ) async def test_list_parquet_files(client): diff --git a/ibis-server/tests/routers/v2/connector/test_local_file.py b/ibis-server/tests/routers/v2/connector/test_local_file.py index 4055d1b63..682d4bbf9 100644 --- a/ibis-server/tests/routers/v2/connector/test_local_file.py +++ b/ibis-server/tests/routers/v2/connector/test_local_file.py @@ -245,7 +245,10 @@ async def test_unsupported_format(client): }, ) assert response.status_code == 422 - assert response.text == "Failed to list files: Unsupported format: unsupported" + assert ( + response.json()["message"] + == "Failed to list files: Unsupported format: unsupported" + ) async def test_list_parquet_files(client): diff --git a/ibis-server/tests/routers/v2/connector/test_minio_file.py b/ibis-server/tests/routers/v2/connector/test_minio_file.py index b5d79b437..4c7fc28ec 100644 --- a/ibis-server/tests/routers/v2/connector/test_minio_file.py +++ b/ibis-server/tests/routers/v2/connector/test_minio_file.py @@ -348,7 +348,10 @@ async def test_unsupported_format(client, minio): }, ) assert response.status_code == 422 - assert response.text == "Failed to list files: Unsupported format: unsupported" + assert ( + response.json()["message"] + == "Failed to list files: Unsupported format: unsupported" + ) async def test_list_parquet_files(client, minio): diff --git a/ibis-server/tests/routers/v2/connector/test_mssql.py b/ibis-server/tests/routers/v2/connector/test_mssql.py index da17b241c..d82ac8c29 100644 --- a/ibis-server/tests/routers/v2/connector/test_mssql.py +++ b/ibis-server/tests/routers/v2/connector/test_mssql.py @@ -8,7 +8,6 @@ from sqlalchemy import text from testcontainers.mssql import SqlServerContainer -from app.model.validator import rules from tests.conftest import file_path pytestmark = pytest.mark.mssql @@ -289,107 +288,6 @@ async def test_query_non_ascii_column(client, manifest_str, mssql: SqlServerCont assert result["columns"] == ["калона"] -async def test_validate_with_unknown_rule( - client, manifest_str, mssql: SqlServerContainer -): - connection_info = _to_connection_info(mssql) - response = await client.post( - url=f"{base_url}/validate/unknown_rule", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 404 - assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" - ) - - -async def test_validate_rule_column_is_valid( - client, manifest_str, mssql: SqlServerContainer -): - connection_info = _to_connection_info(mssql) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 204 - - -async def test_validate_rule_column_is_valid_with_invalid_parameters( - client, manifest_str, mssql: SqlServerContainer -): - connection_info = _to_connection_info(mssql) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "X", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "X"}, - }, - ) - assert response.status_code == 422 - - -async def test_validate_rule_column_is_valid_without_parameters( - client, manifest_str, mssql: SqlServerContainer -): - connection_info = _to_connection_info(mssql) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_one_parameter( - client, manifest_str, mssql: SqlServerContainer -): - connection_info = _to_connection_info(mssql) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" - - async def test_metadata_list_tables(client, mssql: SqlServerContainer): connection_info = _to_connection_info(mssql) response = await client.post( diff --git a/ibis-server/tests/routers/v2/connector/test_mysql.py b/ibis-server/tests/routers/v2/connector/test_mysql.py index 87fbf17b2..18301a099 100644 --- a/ibis-server/tests/routers/v2/connector/test_mysql.py +++ b/ibis-server/tests/routers/v2/connector/test_mysql.py @@ -5,12 +5,11 @@ import pymysql import pytest import sqlalchemy -from MySQLdb import OperationalError from sqlalchemy import text from testcontainers.mysql import MySqlContainer from app.model import SSLMode -from app.model.validator import rules +from app.model.error import ErrorCode from tests.conftest import file_path pytestmark = pytest.mark.mysql @@ -277,105 +276,6 @@ async def test_query_with_dry_run_and_invalid_sql( assert response.text is not None -async def test_validate_with_unknown_rule(client, manifest_str, mysql: MySqlContainer): - connection_info = _to_connection_info(mysql) - response = await client.post( - url=f"{base_url}/validate/unknown_rule", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 404 - assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" - ) - - -async def test_validate_rule_column_is_valid( - client, manifest_str, mysql: MySqlContainer -): - connection_info = _to_connection_info(mysql) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 204 - - -async def test_validate_rule_column_is_valid_with_invalid_parameters( - client, manifest_str, mysql: MySqlContainer -): - connection_info = _to_connection_info(mysql) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "X", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "X"}, - }, - ) - assert response.status_code == 422 - - -async def test_validate_rule_column_is_valid_without_parameters( - client, manifest_str, mysql: MySqlContainer -): - connection_info = _to_connection_info(mysql) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_one_parameter( - client, manifest_str, mysql: MySqlContainer -): - connection_info = _to_connection_info(mysql) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" - - async def test_metadata_list_tables(client, mysql: MySqlContainer): connection_info = _to_connection_info(mysql) response = await client.post( @@ -436,32 +336,35 @@ async def test_metadata_db_version(client, mysql: MySqlContainer): @pytest.mark.parametrize( - "ssl_mode, expected_exception, expected_error", + "ssl_mode, error_code, expected_error", [ ( SSLMode.ENABLED, - OperationalError, - "SSL connection error: SSL is required but the server doesn't support it", + ErrorCode.GET_CONNECTION_ERROR, + '(2026, "SSL connection error: SSL is required but the server doesn\'t support it")', ), ( SSLMode.VERIFY_CA, - ValueError, + ErrorCode.INVALID_CONNECTION_INFO, "SSL CA must be provided when SSL mode is VERIFY CA", ), ], ) async def test_connection_invalid_ssl_mode( - client, mysql_ssl_off: MySqlContainer, ssl_mode, expected_exception, expected_error + client, mysql_ssl_off: MySqlContainer, ssl_mode, error_code, expected_error ): connection_info = _to_connection_info(mysql_ssl_off) connection_info["sslMode"] = ssl_mode - with pytest.raises(expected_exception) as excinfo: - await client.post( - url=f"{base_url}/metadata/version", - json={"connectionInfo": connection_info}, - ) - assert expected_error in str(excinfo.value) + response = await client.post( + url=f"{base_url}/metadata/version", + json={"connectionInfo": connection_info}, + ) + + assert response.status_code == 422 + result = response.json() + assert result["errorCode"] == error_code.name + assert result["message"] == expected_error async def test_connection_valid_ssl_mode(client, mysql_ssl_off: MySqlContainer): diff --git a/ibis-server/tests/routers/v2/connector/test_oracle.py b/ibis-server/tests/routers/v2/connector/test_oracle.py index d1e5e6a46..1e857d8e2 100644 --- a/ibis-server/tests/routers/v2/connector/test_oracle.py +++ b/ibis-server/tests/routers/v2/connector/test_oracle.py @@ -7,7 +7,6 @@ from sqlalchemy import text from testcontainers.oracle import OracleDbContainer -from app.model.validator import rules from tests.conftest import file_path pytestmark = pytest.mark.oracle @@ -317,98 +316,6 @@ async def test_query_with_dry_run_and_invalid_sql( ) # Oracle ORA-00942 Error: Table or view does not exist -async def test_validate_with_unknown_rule( - client, manifest_str, oracle: OracleDbContainer -): - connection_info = _to_connection_info(oracle) - response = await client.post( - url=f"{base_url}/validate/unknown_rule", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 404 - assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" - ) - - -async def test_validate_rule_column_is_valid( - client, manifest_str, oracle: OracleDbContainer -): - connection_info = _to_connection_info(oracle) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 204 - - -async def test_validate_rule_column_is_valid_with_invalid_parameters( - client, manifest_str, oracle: OracleDbContainer -): - connection_info = _to_connection_info(oracle) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_parameters( - client, manifest_str, oracle: OracleDbContainer -): - connection_info = _to_connection_info(oracle) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_one_parameter( - client, manifest_str, oracle: OracleDbContainer -): - connection_info = _to_connection_info(oracle) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" - - async def test_metadata_list_tables(client, oracle: OracleDbContainer): connection_info = _to_connection_info(oracle) response = await client.post( @@ -509,7 +416,10 @@ async def test_model_substitute( }, ) assert response.status_code == 422 - assert response.text == 'Ambiguous model: found multiple matches for "ORDERS"' + assert ( + response.json()["message"] + == 'Ambiguous model: found multiple matches for "ORDERS"' + ) def _to_connection_info(oracle: OracleDbContainer): diff --git a/ibis-server/tests/routers/v2/connector/test_postgres.py b/ibis-server/tests/routers/v2/connector/test_postgres.py index ca140bce3..8fee49f9e 100644 --- a/ibis-server/tests/routers/v2/connector/test_postgres.py +++ b/ibis-server/tests/routers/v2/connector/test_postgres.py @@ -4,13 +4,13 @@ import geopandas as gpd import orjson import pandas as pd -import psycopg import pytest import sqlalchemy from sqlalchemy import text from testcontainers.postgres import PostgresContainer from app.model.data_source import X_WREN_DB_STATEMENT_TIMEOUT +from app.model.error import ErrorCode, ErrorPhase from app.model.validator import rules from tests.conftest import file_path @@ -426,19 +426,20 @@ async def test_dry_run_with_connection_url_and_password_with_bracket_should_not_ netloc=f"{part.username}:{password_with_bracket}@{part.hostname}:{part.port}" ).geturl() - with pytest.raises( - psycopg.OperationalError, - match=r'.*FATAL: password authentication failed for user "test".*', - ): - await client.post( - url=f"{base_url}/query", - params={"dryRun": True}, - json={ - "connectionInfo": {"connectionUrl": connection_url}, - "manifestStr": manifest_str, - "sql": 'SELECT * FROM "Orders" LIMIT 1', - }, - ) + response = await client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": {"connectionUrl": connection_url}, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, + ) + + assert response.status_code == 422 + result = response.json() + assert result["errorCode"] == ErrorCode.GET_CONNECTION_ERROR.name + assert 'password authentication failed for user "test"' in result["message"] async def test_query_with_limit(client, manifest_str, postgres: PostgresContainer): @@ -594,7 +595,8 @@ async def test_validate_with_unknown_rule( ) assert response.status_code == 404 assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" + response.json()["message"] + == f"The rule `unknown_rule` is not in the rules, rules: {rules}" ) @@ -667,7 +669,10 @@ async def test_validate_rule_column_is_valid_without_one_parameter( }, ) assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" + result = response.json() + assert result["errorCode"] == ErrorCode.VALIDATION_PARAMETER_ERROR.name + assert result["message"] == "columnName is required" + assert result["phase"] == ErrorPhase.VALIDATION.name response = await client.post( url=f"{base_url}/validate/column_is_valid", @@ -678,7 +683,7 @@ async def test_validate_rule_column_is_valid_without_one_parameter( }, ) assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" + assert response.json()["message"] == "modelName is required" async def test_metadata_list_tables(client, postgres: PostgresContainer): @@ -786,7 +791,7 @@ async def test_model_substitute( "sql": 'SELECT * FROM "orders"', }, ) - assert response.status_code == 422 + assert response.status_code == 404 # Test only have x-user-catalog but have schema in SQL response = await client.post( @@ -814,7 +819,7 @@ async def test_model_substitute( "sql": 'SELECT * FROM "orders"', }, ) - assert response.status_code == 422 + assert response.status_code == 404 # Test only have x-user-schema response = await client.post( @@ -945,8 +950,8 @@ async def test_model_substitute_out_of_scope( "sql": 'SELECT * FROM "Nation" LIMIT 1', }, ) - assert response.status_code == 422 - assert response.text == 'Model not found: "Nation"' + assert response.status_code == 404 + assert response.json()["message"] == 'Model not found: "Nation"' # Test without catalog and schema in SQL but in headers(x-user-xxx) response = await client.post( @@ -958,8 +963,8 @@ async def test_model_substitute_out_of_scope( "sql": 'SELECT * FROM "Nation" LIMIT 1', }, ) - assert response.status_code == 422 - assert response.text == 'Model not found: "Nation"' + assert response.status_code == 404 + assert response.json()["message"] == 'Model not found: "Nation"' async def test_model_substitute_non_existent_column( @@ -976,7 +981,8 @@ async def test_model_substitute_non_existent_column( }, ) assert response.status_code == 422 - assert 'column "x" does not exist' in response.text + result = response.json() + assert 'column "x" does not exist' in result["message"] # Test without catalog and schema in SQL but in headers(x-user-xxx) response = await client.post( @@ -989,7 +995,8 @@ async def test_model_substitute_non_existent_column( }, ) assert response.status_code == 422 - assert 'column "x" does not exist' in response.text + result = response.json() + assert 'column "x" does not exist' in result["message"] async def test_postgis_geometry(client, manifest_str, postgis: PostgresContainer): @@ -1040,9 +1047,11 @@ async def test_connection_timeout(client, manifest_str, postgres: PostgresContai headers={X_WREN_DB_STATEMENT_TIMEOUT: "1"}, # Set timeout to 1 second ) assert response.status_code == 504 # Gateway Timeout + result = response.json() + assert result["errorCode"] == ErrorCode.DATABASE_TIMEOUT.name assert ( "Query was cancelled: canceling statement due to statement timeout" - in response.text + in result["message"] ) diff --git a/ibis-server/tests/routers/v2/connector/test_redshift.py b/ibis-server/tests/routers/v2/connector/test_redshift.py index 14b838269..f450f5774 100644 --- a/ibis-server/tests/routers/v2/connector/test_redshift.py +++ b/ibis-server/tests/routers/v2/connector/test_redshift.py @@ -4,8 +4,6 @@ import orjson import pytest -from app.model.validator import rules - pytestmark = pytest.mark.redshift base_url = "/v2/connector/redshift" @@ -268,96 +266,6 @@ async def test_query_with_dry_run_and_invalid_sql(client, manifest_str): assert response.text is not None -async def test_validate_with_unknown_rule(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/unknown_rule", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 404 - assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" - ) - - -async def test_validate_rule_column_is_valid(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 204 - - -async def test_validate_rule_column_is_valid_with_invalid_parameters( - client, manifest_str -): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "X", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "X"}, - }, - ) - assert response.status_code == 422 - - -async def test_validate_rule_column_is_valid_without_parameters(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_one_parameter( - client, manifest_str -): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" - - async def test_metadata_list_tables(client): response = await client.post( url=f"{base_url}/metadata/tables", diff --git a/ibis-server/tests/routers/v2/connector/test_s3_file.py b/ibis-server/tests/routers/v2/connector/test_s3_file.py index 7f2b2be98..42cb92323 100644 --- a/ibis-server/tests/routers/v2/connector/test_s3_file.py +++ b/ibis-server/tests/routers/v2/connector/test_s3_file.py @@ -292,7 +292,10 @@ async def test_unsupported_format(client): }, ) assert response.status_code == 422 - assert response.text == "Failed to list files: Unsupported format: unsupported" + assert ( + response.json()["message"] + == "Failed to list files: Unsupported format: unsupported" + ) async def test_list_parquet_files(client): diff --git a/ibis-server/tests/routers/v2/connector/test_snowflake.py b/ibis-server/tests/routers/v2/connector/test_snowflake.py index 6c3789da5..771a56f5f 100644 --- a/ibis-server/tests/routers/v2/connector/test_snowflake.py +++ b/ibis-server/tests/routers/v2/connector/test_snowflake.py @@ -4,8 +4,6 @@ import orjson import pytest -from app.model.validator import rules - pytestmark = pytest.mark.snowflake base_url = "/v2/connector/snowflake" @@ -272,96 +270,6 @@ async def test_query_with_dry_run_and_invalid_sql(client, manifest_str): assert response.text is not None -async def test_validate_with_unknown_rule(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/unknown_rule", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 404 - assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" - ) - - -async def test_validate_rule_column_is_valid(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 204 - - -async def test_validate_rule_column_is_valid_with_invalid_parameters( - client, manifest_str -): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "X", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "X"}, - }, - ) - assert response.status_code == 422 - - -async def test_validate_rule_column_is_valid_without_parameters(client, manifest_str): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_one_parameter( - client, manifest_str -): - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" - - async def test_metadata_list_tables(client): response = await client.post( url=f"{base_url}/metadata/tables", diff --git a/ibis-server/tests/routers/v2/connector/test_trino.py b/ibis-server/tests/routers/v2/connector/test_trino.py index 08c7cc8f5..2817dfedb 100644 --- a/ibis-server/tests/routers/v2/connector/test_trino.py +++ b/ibis-server/tests/routers/v2/connector/test_trino.py @@ -7,7 +7,6 @@ from trino.dbapi import connect from app.model.data_source import X_WREN_DB_STATEMENT_TIMEOUT -from app.model.validator import rules pytestmark = pytest.mark.trino @@ -272,105 +271,6 @@ async def test_query_with_dry_run_and_invalid_sql( assert response.text is not None -async def test_validate_with_unknown_rule(client, manifest_str, trino: TrinoContainer): - connection_info = _to_connection_info(trino) - response = await client.post( - url=f"{base_url}/validate/unknown_rule", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 404 - assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" - ) - - -async def test_validate_rule_column_is_valid( - client, manifest_str, trino: TrinoContainer -): - connection_info = _to_connection_info(trino) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 204 - - -async def test_validate_rule_column_is_valid_with_invalid_parameters( - client, manifest_str, trino: TrinoContainer -): - connection_info = _to_connection_info(trino) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "X", "columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders", "columnName": "X"}, - }, - ) - assert response.status_code == 422 - - -async def test_validate_rule_column_is_valid_without_parameters( - client, manifest_str, trino: TrinoContainer -): - connection_info = _to_connection_info(trino) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={"connectionInfo": connection_info, "manifestStr": manifest_str}, - ) - assert response.status_code == 422 - result = response.json() - assert result["detail"][0] is not None - assert result["detail"][0]["type"] == "missing" - assert result["detail"][0]["loc"] == ["body", "parameters"] - assert result["detail"][0]["msg"] == "Field required" - - -async def test_validate_rule_column_is_valid_without_one_parameter( - client, manifest_str, trino: TrinoContainer -): - connection_info = _to_connection_info(trino) - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"modelName": "Orders"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" - - response = await client.post( - url=f"{base_url}/validate/column_is_valid", - json={ - "connectionInfo": connection_info, - "manifestStr": manifest_str, - "parameters": {"columnName": "orderkey"}, - }, - ) - assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" - - async def test_metadata_list_tables(client, trino: TrinoContainer): connection_info = _to_connection_info(trino) response = await client.post( diff --git a/ibis-server/tests/routers/v2/test_relationship_valid.py b/ibis-server/tests/routers/v2/test_relationship_valid.py index 50837bdfc..9ab003f2b 100644 --- a/ibis-server/tests/routers/v2/test_relationship_valid.py +++ b/ibis-server/tests/routers/v2/test_relationship_valid.py @@ -118,7 +118,7 @@ async def test_validation_relationship_not_found( ) assert response.status_code == 422 - assert response.text == "Relationship not_found not found in manifest" + assert response.json()["message"] == "Relationship not_found not found in manifest" connection_info = _to_connection_info(postgres) response = await client.post( @@ -131,7 +131,7 @@ async def test_validation_relationship_not_found( ) assert response.status_code == 422 - assert response.text == "Missing required parameter: `relationship`" + assert response.json()["message"] == "relationshipName is required" async def test_validation_failure(client, manifest_str, postgres: PostgresContainer): @@ -147,9 +147,8 @@ async def test_validation_failure(client, manifest_str, postgres: PostgresContai assert response.status_code == 422 assert ( - response.text - == "Exception: , message: Relationship invalid_t1_many_t2_id is not valid: {'result': 'False', 'is_related': 'True', " - "'left_table_unique': 'False', 'right_table_unique': 'True'}" + response.json()["message"] + == "Relationship invalid_t1_many_t2_id is not valid: {'result': 'False', 'is_related': 'True', 'left_table_unique': 'False', 'right_table_unique': 'True'}" ) diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py b/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py index d3736e500..1fb400c18 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_model_substitute.py @@ -103,7 +103,7 @@ async def test_model_substitute( "sql": 'SELECT * FROM "orders"', }, ) - assert response.status_code == 422 + assert response.status_code == 404 # Test only have x-user-catalog but have schema in SQL response = await client.post( @@ -135,7 +135,7 @@ async def test_model_substitute( "sql": 'SELECT * FROM "orders"', }, ) - assert response.status_code == 422 + assert response.status_code == 404 # Test only have x-user-schema with no catalog mdl response = await client.post( @@ -267,8 +267,8 @@ async def test_model_substitute_out_of_scope(client, manifest_str, connection_in "sql": 'SELECT * FROM "Nation" LIMIT 1', }, ) - assert response.status_code == 422 - assert response.text == 'Model not found: "Nation"' + assert response.status_code == 404 + assert response.json()["message"] == 'Model not found: "Nation"' # Test without catalog and schema in SQL but in headers(x-user-xxx) response = await client.post( @@ -283,8 +283,8 @@ async def test_model_substitute_out_of_scope(client, manifest_str, connection_in "sql": 'SELECT * FROM "Nation" LIMIT 1', }, ) - assert response.status_code == 422 - assert response.text == 'Model not found: "Nation"' + assert response.status_code == 404 + assert response.json()["message"] == 'Model not found: "Nation"' async def test_model_substitute_non_existent_column( diff --git a/ibis-server/tests/routers/v3/connector/postgres/test_validate.py b/ibis-server/tests/routers/v3/connector/postgres/test_validate.py index f6a12014d..f3a342774 100644 --- a/ibis-server/tests/routers/v3/connector/postgres/test_validate.py +++ b/ibis-server/tests/routers/v3/connector/postgres/test_validate.py @@ -40,7 +40,8 @@ async def test_validate_with_unknown_rule(client, manifest_str, connection_info) ) assert response.status_code == 404 assert ( - response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" + response.json()["message"] + == f"The rule `unknown_rule` is not in the rules, rules: {rules}" ) @@ -107,7 +108,7 @@ async def test_validate_rule_column_is_valid_without_one_parameter( }, ) assert response.status_code == 422 - assert response.text == "Missing required parameter: `columnName`" + assert response.json()["message"] == "columnName is required" response = await client.post( url=f"{base_url}/validate/column_is_valid", @@ -118,7 +119,7 @@ async def test_validate_rule_column_is_valid_without_one_parameter( }, ) assert response.status_code == 422 - assert response.text == "Missing required parameter: `modelName`" + assert response.json()["message"] == "modelName is required" async def test_validate_rlac_condition_syntax_is_valid( @@ -173,6 +174,6 @@ async def test_validate_rlac_condition_syntax_is_valid( assert response.status_code == 422 assert ( - response.text + response.json()["message"] == "Error during planning: The session property @session_not_found is used for `rlac_validation` rule, but not found in the session properties" ) diff --git a/ibis-server/wren/__init__.py b/ibis-server/wren/__init__.py index b72fc2af2..667897129 100644 --- a/ibis-server/wren/__init__.py +++ b/ibis-server/wren/__init__.py @@ -9,6 +9,7 @@ from app import model from app.model import ConnectionInfo from app.model.data_source import DataSource +from app.model.error import ErrorCode, WrenError __all__ = ["Context", "Task", "create_session_context", "model"] @@ -46,16 +47,18 @@ def create_session_context( from .session import Context # noqa: PLC0415 if not mdl_path: - raise ValueError("mdl_path must be provided") + raise WrenError(ErrorCode.GENERIC_USER_ERROR, "mdl_path must be provided") if not data_source: - raise ValueError("data_source must be provided") + raise WrenError(ErrorCode.GENERIC_USER_ERROR, "data_source must be provided") data_source = DataSource(data_source) with open(mdl_path) as f: if not f.readable(): - raise ValueError(f"Cannot read MDL file at {mdl_path}") + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, f"Cannot read MDL file at {mdl_path}" + ) try: manifest = json.load(f) manifest_base64 = ( @@ -64,7 +67,10 @@ def create_session_context( else None ) except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON in MDL file at {mdl_path}: {e}") from e + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + f"Invalid JSON in MDL file at {mdl_path}: {e}", + ) from e return Context( data_source=data_source, diff --git a/ibis-server/wren/__main__.py b/ibis-server/wren/__main__.py index cabf26f8b..5568175e9 100644 --- a/ibis-server/wren/__main__.py +++ b/ibis-server/wren/__main__.py @@ -5,6 +5,7 @@ import sys from app.model.data_source import DataSource +from app.model.error import ErrorCode, WrenError from wren import create_session_context @@ -22,12 +23,13 @@ def main(): if len(sys.argv) > 3: with open(connection_info_path := sys.argv[3]) as f: if not f.readable(): - raise ValueError( - f"Cannot read connection info file at {connection_info_path}" + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + f"Cannot read connection info file at {connection_info_path}", ) connection_info = json.load(f) - # The connection_info file proeduced by Wren AI dbt integration + # The connection_info file produced by Wren AI dbt integration # contains a "type" field to indicate the data source type. # If it is present, we need to use the `get_connection_info` method # of the DataSource class to get the connection info. diff --git a/ibis-server/wren/session/__init__.py b/ibis-server/wren/session/__init__.py index 6a3e14304..63ce75b7c 100644 --- a/ibis-server/wren/session/__init__.py +++ b/ibis-server/wren/session/__init__.py @@ -9,6 +9,7 @@ from app.model import ConnectionInfo from app.model.connector import Connector from app.model.data_source import DataSource +from app.model.error import ErrorCode, WrenError from app.util import to_json @@ -135,7 +136,10 @@ def _get_write_dialect( def dry_run(self): """Perform a dry run of the dialect SQL without executing it.""" if self.dialect_sql is None: - raise ValueError("Dialect SQL is not set. Call transpile() first.") + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + "Dialect SQL is not set. Call transpile() first.", + ) self.context.get_connector().dry_run(self.dialect_sql) def execute(self, limit: int | None = None): @@ -147,17 +151,24 @@ def execute(self, limit: int | None = None): The maximum number of rows to return. If None, returns all rows. """ if self.context.connection_info is None: - raise ValueError( - "Connection info is not set. Cannot execute without connection info." + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + "Connection info is not set. Cannot execute without connection info.", ) if self.dialect_sql is None: - raise ValueError("Dialect SQL is not set. Call transpile() first.") + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + "Dialect SQL is not set. Call transpile() first.", + ) self.results = self.context.get_connector().query(self.dialect_sql, limit) return self def formatted_result(self): """Get the formatted result of the executed task.""" if self.results is None: - raise ValueError("Results are not set. Call execute() first.") + raise WrenError( + ErrorCode.GENERIC_USER_ERROR, + "Results are not set. Call execute() first.", + ) return to_json(self.results, self.properties, self.context.data_source)