Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ibis-server/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
X_CACHE_OVERRIDE_AT = "X-Cache-Override-At"


# Validate the dto by building the specific connection info from the data source
def verify_query_dto(data_source: DataSource, dto: QueryDTO):
    """Validate *dto* by constructing the data-source-specific connection info.

    Uses ``data_source.get_connection_info`` to ensure ``dto.connection_info``
    can be parsed into the connection-info model for this data source; a
    pydantic ``ValidationError`` propagates to the caller when it cannot.
    No headers are needed for validation, so an empty dict is passed.
    """
    data_source.get_connection_info(dto.connection_info, {})


def get_wren_headers(request: Request) -> Headers:
Expand Down
21 changes: 16 additions & 5 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC
from enum import Enum
from typing import Annotated, Literal, Union
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field, SecretStr
from starlette.status import (
Expand Down Expand Up @@ -30,7 +30,7 @@ def to_key_string(self) -> str:
class QueryDTO(BaseModel):
sql: str
manifest_str: str = manifest_str_field
connection_info: ConnectionInfo = connection_info_field
connection_info: dict[str, Any] | ConnectionInfo = connection_info_field


class QueryBigQueryDTO(QueryDTO):
Expand Down Expand Up @@ -165,6 +165,16 @@ class ClickHouseConnectionInfo(BaseConnectionInfo):
password: SecretStr | None = Field(
description="the password of your database", examples=["password"], default=None
)
secure: bool = Field(
description="Whether or not to use an authenticated endpoint",
default=False,
examples=[True, False],
)
settings: dict[str, str] | None = Field(
description="Additional settings for ClickHouse connection",
default=None,
examples=[{"max_execution_time": "60"}],
)
kwargs: dict[str, str] | None = Field(
description="Client specific keyword arguments", default=None
)
Expand Down Expand Up @@ -211,7 +221,7 @@ class MySqlConnectionInfo(BaseConnectionInfo):
)
ssl_mode: SecretStr | None = Field(
alias="sslMode",
default="ENABLED",
default=SecretStr("ENABLED"),
description="Use ssl connection or not. The default value is `ENABLED` because MySQL uses `caching_sha2_password` by default and the driver MySQLdb support caching_sha2_password with ssl only.",
examples=["DISABLED", "ENABLED", "VERIFY_CA"],
)
Expand Down Expand Up @@ -468,6 +478,7 @@ class GcsFileConnectionInfo(BaseConnectionInfo):
AthenaConnectionInfo
| BigQueryConnectionInfo
| CannerConnectionInfo
| ClickHouseConnectionInfo
| ConnectionUrl
| MSSqlConnectionInfo
| MySqlConnectionInfo
Expand All @@ -487,7 +498,7 @@ class GcsFileConnectionInfo(BaseConnectionInfo):
class ValidateDTO(BaseModel):
manifest_str: str = manifest_str_field
parameters: dict
connection_info: ConnectionInfo = connection_info_field
connection_info: dict[str, Any] | ConnectionInfo = connection_info_field


class AnalyzeSQLDTO(BaseModel):
Expand All @@ -507,7 +518,7 @@ class DryPlanDTO(BaseModel):

class TranspileDTO(BaseModel):
manifest_str: str = manifest_str_field
connection_info: ConnectionInfo = connection_info_field
connection_info: dict[str, Any] | ConnectionInfo = connection_info_field
sql: str


Expand Down
50 changes: 48 additions & 2 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import base64
import ssl
import urllib
from enum import Enum, StrEnum, auto
from json import loads
from typing import Any
from urllib.parse import unquote_plus

import ibis
from google.oauth2 import service_account
Expand All @@ -15,6 +18,7 @@
CannerConnectionInfo,
ClickHouseConnectionInfo,
ConnectionInfo,
ConnectionUrl,
GcsFileConnectionInfo,
LocalFileConnectionInfo,
MinioFileConnectionInfo,
Expand Down Expand Up @@ -79,7 +83,7 @@ def get_dto_type(self):
raise NotImplementedError(f"Unsupported data source: {self}")

def get_connection_info(
self, data: dict | ConnectionInfo, headers: dict
self, data: dict[str, Any] | ConnectionInfo, headers: dict[str, str]
) -> ConnectionInfo:
"""Build a ConnectionInfo object from the provided data and add requried configuration from headers."""
if isinstance(data, ConnectionInfo):
Expand All @@ -99,11 +103,26 @@ def get_connection_info(
options += f"-c statement_timeout={headers.get(X_WREN_DB_STATEMENT_TIMEOUT, 180)}s"
kwargs["options"] = options
info.kwargs = kwargs

case DataSource.clickhouse:
session_timeout = headers.get(X_WREN_DB_STATEMENT_TIMEOUT, 180)
if info.settings is None:
info.settings = {}
if "max_execution_time" not in info.settings:
info.settings["max_execution_time"] = int(session_timeout)
return info

def _build_connection_info(self, data: dict) -> ConnectionInfo:
"""Build a ConnectionInfo object from the provided data."""
# Check if data contains connectionUrl for connection string-based connections
if "connectionUrl" in data or "connection_url" in data:
if self == DataSource.clickhouse:
return self._handle_clickhouse_url(
urllib.parse.urlparse(
data.get("connectionUrl", data.get("connection_url"))
)
)
return ConnectionUrl.model_validate(data)

match self:
case DataSource.athena:
return AthenaConnectionInfo.model_validate(data)
Expand Down Expand Up @@ -140,6 +159,32 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo:
case _:
raise NotImplementedError(f"Unsupported data source: {self}")

def _handle_clickhouse_url(
    self, parsed: urllib.parse.ParseResult
) -> ClickHouseConnectionInfo:
    """Translate a ``clickhouse://`` connection URL into a ClickHouseConnectionInfo.

    Recognized URL components: username, password, host, port and the path
    (database name). A ``secure`` query parameter toggles TLS; all other
    query parameters are forwarded as client keyword arguments.

    Raises:
        ValueError: if the URL scheme is not ``clickhouse``.
    """
    # An empty/missing scheme is also != "clickhouse", so one check suffices.
    if parsed.scheme != "clickhouse":
        raise ValueError("Invalid connection URL for ClickHouse")
    kwargs = {}
    if parsed.username:
        # Percent-decode credentials so special characters survive URL encoding
        # (keeps username handling consistent with the password below).
        kwargs["user"] = unquote_plus(parsed.username)
    if parsed.password:
        kwargs["password"] = unquote_plus(parsed.password)
    if parsed.hostname:
        kwargs["host"] = parsed.hostname
    if parsed.port:
        kwargs["port"] = str(parsed.port)
    if database := parsed.path[1:]:
        kwargs["database"] = unquote_plus(database)
    parsed_kwargs = dict(urllib.parse.parse_qsl(parsed.query))
    # "secure" is a first-class field on the model, not a client kwarg.
    if (secure := parsed_kwargs.pop("secure", None)) is not None:
        kwargs["secure"] = self._safe_strtobool(secure)
    kwargs["kwargs"] = parsed_kwargs
    return ClickHouseConnectionInfo(**kwargs)

def _safe_strtobool(self, val: str) -> bool:
return val.lower() in {"1", "true", "yes", "y"}


class DataSourceExtension(Enum):
athena = QueryAthenaDTO
Expand Down Expand Up @@ -222,6 +267,7 @@ def get_clickhouse_connection(info: ClickHouseConnectionInfo) -> BaseBackend:
database=info.database.get_secret_value(),
user=info.user.get_secret_value(),
password=(info.password and info.password.get_secret_value()),
settings=info.settings if info.settings else dict(),
**info.kwargs if info.kwargs else dict(),
)

Expand Down
2 changes: 1 addition & 1 deletion ibis-server/app/model/metadata/dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class MetadataDTO(BaseModel):
connection_info: ConnectionInfo = Field(alias="connectionInfo")
connection_info: dict[str, Any] | ConnectionInfo = Field(alias="connectionInfo")


class RustWrenEngineColumnType(Enum):
Expand Down
10 changes: 10 additions & 0 deletions ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from app.mdl.rewriter import Rewriter
from app.mdl.substitute import ModelSubstitute
from app.model import (
DatabaseTimeoutError,
DryPlanDTO,
QueryDTO,
TranspileDTO,
Expand Down Expand Up @@ -173,6 +174,9 @@ async def query(
response = ORJSONResponse(to_json(result, headers, data_source=data_source))
update_response_headers(response, cache_headers)
return response
except DatabaseTimeoutError:
# won't fallback to v2 if timeout
raise
except Exception as e:
is_fallback_disable = bool(
headers.get(X_WREN_FALLBACK_DISABLE)
Expand Down Expand Up @@ -330,6 +334,9 @@ async def validate(
dto.manifest_str,
)
return Response(status_code=204)
except DatabaseTimeoutError:
# won't fallback to v2 if timeout
raise
except Exception as e:
is_fallback_disable = bool(
headers.get(X_WREN_FALLBACK_DISABLE)
Expand Down Expand Up @@ -410,6 +417,9 @@ async def model_substitute(
rewritten_sql,
)
return sql
except DatabaseTimeoutError:
# won't fallback to v2 if timeout
raise
except Exception as e:
is_fallback_disable = bool(
headers.get(X_WREN_FALLBACK_DISABLE)
Expand Down
4 changes: 4 additions & 0 deletions ibis-server/app/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import base64
import time

import clickhouse_connect
import datafusion
import orjson
import pandas as pd
Expand Down Expand Up @@ -267,6 +268,9 @@ async def execute_with_timeout(operation, operation_name: str):
raise DatabaseTimeoutError(
f"{operation_name} timeout after {app_timeout_seconds} seconds"
)
except clickhouse_connect.driver.exceptions.DatabaseError as e:
if "TIMEOUT_EXCEEDED" in str(e):
raise DatabaseTimeoutError(f"{operation_name} was cancelled: {e}")
except psycopg.errors.QueryCanceled as e:
raise DatabaseTimeoutError(f"{operation_name} was cancelled: {e}")

Expand Down
1 change: 1 addition & 0 deletions ibis-server/resources/function_list/clickhouse.csv
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ scalar,rand,UInt32,,,"Returns random number."
scalar,rand64,UInt64,,,"Returns random 64-bit number."
scalar,e,Float,,,"Returns value of e."
scalar,yesterday,Date,,,"Returns yesterday's date."
scalar,sleep,,seconds,Int64,"Pauses execution for a specified number of seconds."
32 changes: 32 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
from testcontainers.clickhouse import ClickHouseContainer

from app.model.data_source import X_WREN_DB_STATEMENT_TIMEOUT
from app.model.validator import rules
from tests.conftest import file_path

Expand Down Expand Up @@ -566,6 +567,37 @@ async def test_metadata_db_version(client, clickhouse: ClickHouseContainer):
assert response.text is not None


async def test_connection_timeout(
    client, manifest_str, clickhouse: ClickHouseContainer
):
    """Queries outlasting X-Wren-DB-Statement-Timeout must return 504.

    Exercised for both connection-info styles: the field-based dict and the
    ``clickhouse://`` connection URL.
    """
    connection_infos = [
        _to_connection_info(clickhouse),
        {"connectionUrl": _to_connection_url(clickhouse)},
    ]
    for connection_info in connection_infos:
        response = await client.post(
            url=f"{base_url}/query",
            json={
                "connectionInfo": connection_info,
                "manifestStr": manifest_str,
                # sleep(3) outlasts the 1-second statement timeout set below
                "sql": "SELECT sleep(3)",
            },
            headers={X_WREN_DB_STATEMENT_TIMEOUT: "1"},
        )
        assert response.status_code == 504  # Gateway Timeout
        assert "Query was cancelled:" in response.text


def _to_connection_info(db: ClickHouseContainer):
return {
"host": db.get_container_host_ip(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def test_function_list(client):
response = await client.get(url=f"{base_url}/functions")
assert response.status_code == 200
result = response.json()
assert len(result) == DATAFUSION_FUNCTION_COUNT + 5
assert len(result) == DATAFUSION_FUNCTION_COUNT + 6
the_func = next(filter(lambda x: x["name"] == "uniq", result))
assert the_func == {
"name": "uniq",
Expand Down