Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/skills-check.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Checks that skills/versions.json stays in sync with the SKILL.md frontmatter.
name: Skills Version Check

# Only pull requests that touch files under skills/ trigger this workflow.
on:
  pull_request:
    paths:
      - "skills/**"

# Least privilege: the steps below only check out the repository and run a
# script from it, so the token is limited to read-only repository contents.
# NOTE(review): confirm skills/check-versions.sh needs no write scopes.
permissions:
  contents: read

jobs:
  version-parity:
    name: Check skills/versions.json parity
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Verify versions.json matches SKILL.md frontmatter
        run: bash skills/check-versions.sh
3 changes: 3 additions & 0 deletions ibis-server/app/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

# Validate the dto by building the specific connection info from the data source
def verify_query_dto(data_source: DataSource, dto: QueryDTO):
    """Validate the inline connection info of a query request.

    Args:
        data_source: Data source whose ``get_connection_info`` is used to
            parse the connection info.
        dto: Query request body. When ``dto.connection_file_path`` is truthy
            no inline validation is performed.

    Any error comes from ``data_source.get_connection_info``; its exception
    types are not visible here.
    """
    # Skip inline validation when using a file path; connection info is validated at query time
    if dto.connection_file_path:
        return
    # Use data_source.get_connection_info to validate the connection_info
    # This will ensure the connection_info can be properly parsed for the specific data source
    # The second argument is an empty dict here; the routers pass the request
    # headers (``dict(headers)``) in that position.
    data_source.get_connection_info(dto.connection_info, {})
Expand Down
40 changes: 35 additions & 5 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from enum import Enum
from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, SecretStr, model_validator

from app.model.error import ErrorCode, WrenError

manifest_str_field = Field(alias="manifestStr", description="Base64 manifest")
connection_info_field = Field(alias="connectionInfo")
connection_info_field = Field(alias="connectionInfo", default=None)


class BaseConnectionInfo(BaseModel):
Expand All @@ -26,7 +26,17 @@ def to_key_string(self) -> str:
class QueryDTO(BaseModel):
sql: str
manifest_str: str = manifest_str_field
connection_info: dict[str, Any] | ConnectionInfo = connection_info_field
connection_info: dict[str, Any] | ConnectionInfo | None = connection_info_field
connection_file_path: str | None = Field(alias="connectionFilePath", default=None)

@model_validator(mode="after")
def check_connection_source(self):
    """Require a usable connection source on the request body.

    Runs after field validation. The request is accepted when
    ``connection_info`` is set or ``connection_file_path`` is non-empty.

    Returns:
        The validated model instance.

    Raises:
        WrenError: ``ErrorCode.INVALID_CONNECTION_INFO`` when neither source
            is provided.
    """
    # A blank connectionFilePath counts as "not provided": the code that
    # consumes this field tests it for truthiness, so an empty string would
    # otherwise pass here and later resolve to a missing connectionInfo.
    if self.connection_info is None and not self.connection_file_path:
        raise WrenError(
            ErrorCode.INVALID_CONNECTION_INFO,
            "Either connectionInfo or connectionFilePath must be provided",
        )
    return self


class QueryBigQueryDTO(QueryDTO):
Expand Down Expand Up @@ -654,7 +664,17 @@ class GcsFileConnectionInfo(BaseConnectionInfo):
class ValidateDTO(BaseModel):
manifest_str: str = manifest_str_field
parameters: dict
connection_info: dict[str, Any] | ConnectionInfo = connection_info_field
connection_info: dict[str, Any] | ConnectionInfo | None = connection_info_field
connection_file_path: str | None = Field(alias="connectionFilePath", default=None)

@model_validator(mode="after")
def check_connection_source(self):
    """Require a usable connection source on the validate request body.

    Runs after field validation. The request is accepted when
    ``connection_info`` is set or ``connection_file_path`` is non-empty.

    Returns:
        The validated model instance.

    Raises:
        WrenError: ``ErrorCode.INVALID_CONNECTION_INFO`` when neither source
            is provided.
    """
    # A blank connectionFilePath counts as "not provided": the code that
    # consumes this field tests it for truthiness, so an empty string would
    # otherwise pass here and later resolve to a missing connectionInfo.
    if self.connection_info is None and not self.connection_file_path:
        raise WrenError(
            ErrorCode.INVALID_CONNECTION_INFO,
            "Either connectionInfo or connectionFilePath must be provided",
        )
    return self


class AnalyzeSQLDTO(BaseModel):
Expand All @@ -674,9 +694,19 @@ class DryPlanDTO(BaseModel):

class TranspileDTO(BaseModel):
manifest_str: str = manifest_str_field
connection_info: dict[str, Any] | ConnectionInfo = connection_info_field
connection_info: dict[str, Any] | ConnectionInfo | None = connection_info_field
connection_file_path: str | None = Field(alias="connectionFilePath", default=None)
sql: str

@model_validator(mode="after")
def check_connection_source(self):
    """Require a usable connection source on the transpile request body.

    Runs after field validation. The request is accepted when
    ``connection_info`` is set or ``connection_file_path`` is non-empty.

    Returns:
        The validated model instance.

    Raises:
        WrenError: ``ErrorCode.INVALID_CONNECTION_INFO`` when neither source
            is provided.
    """
    # A blank connectionFilePath counts as "not provided": the code that
    # consumes this field tests it for truthiness, so an empty string would
    # otherwise pass here and later resolve to a missing connectionInfo.
    if self.connection_info is None and not self.connection_file_path:
        raise WrenError(
            ErrorCode.INVALID_CONNECTION_INFO,
            "Either connectionInfo or connectionFilePath must be provided",
        )
    return self


class ConfigModel(BaseModel):
diagnose: bool
Expand Down
17 changes: 15 additions & 2 deletions ibis-server/app/model/metadata/dto.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator

from app.model import ConnectionInfo
from app.model.data_source import DataSource
from app.model.error import ErrorCode, WrenError


class V2MetadataDTO(BaseModel):
Expand All @@ -16,10 +17,22 @@ class FilterInfo(BaseModel):


class MetadataDTO(BaseModel):
connection_info: dict[str, Any] | ConnectionInfo = Field(alias="connectionInfo")
connection_info: dict[str, Any] | ConnectionInfo | None = Field(
alias="connectionInfo", default=None
)
connection_file_path: str | None = Field(alias="connectionFilePath", default=None)
table_limit: int | None = Field(alias="limit", default=None)
filter_info: dict[str, Any] | None = Field(alias="filterInfo", default=None)

@model_validator(mode="after")
def check_connection_source(self):
    """Require a usable connection source on the metadata request body.

    Runs after field validation. The request is accepted when
    ``connection_info`` is set or ``connection_file_path`` is non-empty.

    Returns:
        The validated model instance.

    Raises:
        WrenError: ``ErrorCode.INVALID_CONNECTION_INFO`` when neither source
            is provided.
    """
    # A blank connectionFilePath counts as "not provided": the code that
    # consumes this field tests it for truthiness, so an empty string would
    # otherwise pass here and later resolve to a missing connectionInfo.
    if self.connection_info is None and not self.connection_file_path:
        raise WrenError(
            ErrorCode.INVALID_CONNECTION_INFO,
            "Either connectionInfo or connectionFilePath must be provided",
        )
    return self


class BigQueryFilterInfo(FilterInfo):
projects: list["ProjectDatasets"] | None = None
Expand Down
13 changes: 7 additions & 6 deletions ibis-server/app/routers/v2/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
execute_validate_with_timeout,
get_fallback_message,
pushdown_limit,
resolve_connection_info,
set_attribute,
to_json,
update_response_headers,
Expand Down Expand Up @@ -93,7 +94,7 @@ async def query(
if cache_enable:
span_name += "_cache_enable"
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
# Convert headers to dict for cache manager
headers_dict = dict(headers) if headers else None
Expand Down Expand Up @@ -232,7 +233,7 @@ async def validate(
) as span:
set_attribute(headers, span)
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
validator = Validator(
Connector(data_source, connection_info),
Expand Down Expand Up @@ -273,7 +274,7 @@ async def get_table_list(
) as span:
set_attribute(headers, span)
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
if isinstance(connection_info, BigQueryProjectConnectionInfo):
raise WrenError(
Expand Down Expand Up @@ -302,7 +303,7 @@ async def get_constraints(
) as span:
set_attribute(headers, span)
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
if isinstance(connection_info, BigQueryProjectConnectionInfo):
raise WrenError(
Expand All @@ -324,7 +325,7 @@ async def get_db_version(
headers: Annotated[Headers, Depends(get_wren_headers)] = None,
) -> str:
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
metadata = MetadataFactory.get_metadata(data_source, connection_info)
return await execute_get_version_with_timeout(metadata)
Expand Down Expand Up @@ -398,7 +399,7 @@ async def model_substitute(
) as span:
set_attribute(headers, span)
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
sql = ModelSubstitute(data_source, dto.manifest_str, headers).substitute(
dto.sql, write="trino"
Expand Down
11 changes: 6 additions & 5 deletions ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
execute_query_with_timeout,
execute_validate_with_timeout,
pushdown_limit,
resolve_connection_info,
safe_strtobool,
set_attribute,
to_json,
Expand Down Expand Up @@ -91,7 +92,7 @@ async def query(
) as span:
set_attribute(headers, span)
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
# Convert headers to dict for cache manager
headers_dict = dict(headers) if headers else None
Expand Down Expand Up @@ -357,7 +358,7 @@ async def validate(
) as span:
set_attribute(headers, span)
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
try:
validator = Validator(
Expand Down Expand Up @@ -485,7 +486,7 @@ async def model_substitute(
) as span:
set_attribute(headers, span)
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
try:
sql = ModelSubstitute(data_source, dto.manifest_str, headers).substitute(
Expand Down Expand Up @@ -569,7 +570,7 @@ async def get_table_list(
) as span:
set_attribute(headers, span)
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
metadata = MetadataFactory.get_metadata(data_source, connection_info)
filter_info = get_filter_info(data_source, dto.filter_info or {})
Expand All @@ -595,7 +596,7 @@ async def get_schema_list(
) as span:
set_attribute(headers, span)
connection_info = data_source.get_connection_info(
dto.connection_info, dict(headers)
resolve_connection_info(dto), dict(headers)
)
metadata = MetadataFactory.get_metadata(data_source, connection_info)
filter_info = get_filter_info(data_source, dto.filter_info or {})
Expand Down
42 changes: 41 additions & 1 deletion ibis-server/app/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import base64
import json
import pathlib
import time

try:
Expand Down Expand Up @@ -41,7 +43,7 @@ class ClickHouseDbError(Exception):
X_WREN_TIMEZONE,
)
from app.model.data_source import DataSource
from app.model.error import DatabaseTimeoutError
from app.model.error import DatabaseTimeoutError, ErrorCode, WrenError
from app.model.metadata.bigquery import BigQueryMetadata
from app.model.metadata.dto import FilterInfo
from app.model.metadata.metadata import Metadata
Expand All @@ -53,6 +55,44 @@ class ClickHouseDbError(Exception):
Wren AI team are appreciate if you can provide the error messages and related logs for us."


def resolve_connection_info(dto) -> dict:
    """Return connection info from either a file path or the inline DTO field.

    When ``dto.connection_file_path`` is set, the file is read as UTF-8 JSON.
    CONNECTION_FILE_ROOT must then be set to the directory that is allowed to
    be read; requests are rejected if the env var is absent or the resolved
    path escapes that directory. Otherwise ``dto.connection_info`` is
    returned unchanged.

    Args:
        dto: Request object; ``connection_file_path`` is optional on it,
            ``connection_info`` is read when no file path is given.

    Returns:
        The parsed JSON object from the file, or ``dto.connection_info``.

    Raises:
        WrenError: ``ErrorCode.INVALID_CONNECTION_INFO`` when
            CONNECTION_FILE_ROOT is unset, the path lies outside the allowed
            root, the file is missing or unreadable, its content is not valid
            UTF-8 JSON, or the JSON document is not an object.
    """
    import os

    file_path = getattr(dto, "connection_file_path", None)
    if not file_path:
        return dto.connection_info

    allowed_root = os.environ.get("CONNECTION_FILE_ROOT")
    if not allowed_root:
        raise WrenError(
            ErrorCode.INVALID_CONNECTION_INFO,
            "connectionFilePath requires the CONNECTION_FILE_ROOT environment variable to be set",
        )

    # Resolve both sides so ".." segments and symlinks cannot step outside
    # the allowed directory.
    path = pathlib.Path(file_path).resolve()
    if not path.is_relative_to(pathlib.Path(allowed_root).resolve()):
        raise WrenError(
            ErrorCode.INVALID_CONNECTION_INFO,
            f"Connection file path is outside the allowed directory: {file_path}",
        )

    try:
        # Explicit encoding: JSON is UTF-8 regardless of the host locale.
        with open(path, encoding="utf-8") as f:
            connection_info = json.load(f)
    except FileNotFoundError as e:
        raise WrenError(
            ErrorCode.INVALID_CONNECTION_INFO,
            f"Connection file not found: {file_path}",
        ) from e
    except OSError as e:
        # Directories, permission problems and similar read failures.
        raise WrenError(
            ErrorCode.INVALID_CONNECTION_INFO,
            f"Unable to read connection file: {file_path}",
        ) from e
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise WrenError(
            ErrorCode.INVALID_CONNECTION_INFO,
            f"Invalid JSON in connection file: {e}",
        ) from e

    # The declared return type is a dict, so reject lists, strings, numbers
    # and null here with a clear message.
    if not isinstance(connection_info, dict):
        raise WrenError(
            ErrorCode.INVALID_CONNECTION_INFO,
            "Connection file must contain a JSON object",
        )
    return connection_info


@tracer.start_as_current_span("base64_to_dict", kind=trace.SpanKind.INTERNAL)
def base64_to_dict(base64_str: str) -> dict:
return orjson.loads(base64.b64decode(base64_str).decode("utf-8"))
Expand Down
7 changes: 3 additions & 4 deletions ibis-server/tests/routers/v2/connector/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,9 @@ async def test_query_without_connection_info(client, manifest_str):
)
assert response.status_code == 422
result = response.json()
assert result["detail"][0] is not None
assert result["detail"][0]["type"] == "missing"
assert result["detail"][0]["loc"] == ["body", "connectionInfo"]
assert result["detail"][0]["msg"] == "Field required"
assert result["errorCode"] == "INVALID_CONNECTION_INFO"
assert "connectionInfo" in result["message"]
assert "connectionFilePath" in result["message"]


async def test_query_with_dry_run(
Expand Down
7 changes: 3 additions & 4 deletions ibis-server/tests/routers/v2/connector/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,9 @@ async def test_query_without_connection_info(client, manifest_str):
)
assert response.status_code == 422
result = response.json()
assert result["detail"][0] is not None
assert result["detail"][0]["type"] == "missing"
assert result["detail"][0]["loc"] == ["body", "connectionInfo"]
assert result["detail"][0]["msg"] == "Field required"
assert result["errorCode"] == "INVALID_CONNECTION_INFO"
assert "connectionInfo" in result["message"]
assert "connectionFilePath" in result["message"]


async def test_query_with_dry_run(client, manifest_str, mssql: SqlServerContainer):
Expand Down
7 changes: 3 additions & 4 deletions ibis-server/tests/routers/v2/connector/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,9 @@ async def test_query_without_connection_info(client, manifest_str):
)
assert response.status_code == 422
result = response.json()
assert result["detail"][0] is not None
assert result["detail"][0]["type"] == "missing"
assert result["detail"][0]["loc"] == ["body", "connectionInfo"]
assert result["detail"][0]["msg"] == "Field required"
assert result["errorCode"] == "INVALID_CONNECTION_INFO"
assert "connectionInfo" in result["message"]
assert "connectionFilePath" in result["message"]


async def test_query_with_dry_run(client, manifest_str, mysql: MySqlContainer):
Expand Down
7 changes: 3 additions & 4 deletions ibis-server/tests/routers/v2/connector/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,9 @@ async def test_query_without_connection_info(
)
assert response.status_code == 422
result = response.json()
assert result["detail"][0] is not None
assert result["detail"][0]["type"] == "missing"
assert result["detail"][0]["loc"] == ["body", "connectionInfo"]
assert result["detail"][0]["msg"] == "Field required"
assert result["errorCode"] == "INVALID_CONNECTION_INFO"
assert "connectionInfo" in result["message"]
assert "connectionFilePath" in result["message"]


async def test_query_with_dry_run(client, manifest_str, oracle: OracleDbContainer):
Expand Down
Loading