Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ class QueryLocalFileDTO(QueryDTO):
connection_info: LocalFileConnectionInfo = connection_info_field


class QueryDuckDBDTO(QueryDTO):
    """Query payload for the ``duckdb`` data source.

    Reuses ``LocalFileConnectionInfo`` — a DuckDB source is addressed the
    same way as a local file source (a path/url plus a format hint).
    """

    connection_info: LocalFileConnectionInfo = connection_info_field


class QueryS3FileDTO(QueryDTO):
connection_info: S3FileConnectionInfo = connection_info_field

Expand Down
1 change: 1 addition & 0 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo):
DataSource.s3_file,
DataSource.minio_file,
DataSource.gcs_file,
DataSource.duckdb,
}:
self._connector = DuckDBConnector(connection_info)
elif data_source == DataSource.redshift:
Expand Down
7 changes: 6 additions & 1 deletion ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
QueryClickHouseDTO,
QueryDatabricksDTO,
QueryDTO,
QueryDuckDBDTO,
QueryGcsFileDTO,
QueryLocalFileDTO,
QueryMinioFileDTO,
Expand Down Expand Up @@ -78,6 +79,7 @@ class DataSource(StrEnum):
s3_file = auto()
minio_file = auto()
gcs_file = auto()
duckdb = auto()
spark = auto()
databricks = auto()

Expand Down Expand Up @@ -179,6 +181,8 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo:
return SnowflakeConnectionInfo.model_validate(data)
case DataSource.trino:
return TrinoConnectionInfo.model_validate(data)
case DataSource.duckdb:
return LocalFileConnectionInfo.model_validate(data)
case DataSource.local_file:
return LocalFileConnectionInfo.model_validate(data)
case DataSource.s3_file:
Expand Down Expand Up @@ -242,6 +246,7 @@ class DataSourceExtension(Enum):
snowflake = QuerySnowflakeDTO
trino = QueryTrinoDTO
local_file = QueryLocalFileDTO
duckdb = QueryDuckDBDTO
s3_file = QueryS3FileDTO
minio_file = QueryMinioFileDTO
gcs_file = QueryGcsFileDTO
Expand All @@ -256,7 +261,7 @@ def get_connection(self, info: ConnectionInfo) -> BaseBackend:
if hasattr(info, "connection_url"):
kwargs = info.kwargs if info.kwargs else {}
return ibis.connect(info.connection_url.get_secret_value(), **kwargs)
if self.name in {"local_file", "redshift", "spark"}:
if self.name in {"local_file", "redshift", "spark", "duckdb"}:
raise NotImplementedError(
f"{self.name} connection is not implemented to get ibis backend"
)
Expand Down
1 change: 1 addition & 0 deletions ibis-server/app/model/metadata/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
DataSource.gcs_file: GcsFileMetadata,
DataSource.databricks: DatabricksMetadata,
DataSource.spark: SparkMetadata,
DataSource.duckdb: DuckDBMetadata,
}


Expand Down
31 changes: 30 additions & 1 deletion ibis-server/app/routers/v3/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
from app.model.connector import Connector
from app.model.data_source import DataSource
from app.model.error import DatabaseTimeoutError
from app.model.metadata.dto import Catalog, MetadataDTO, Table, get_filter_info
from app.model.metadata.dto import (
Catalog,
Constraint,
MetadataDTO,
Table,
get_filter_info,
)
from app.model.metadata.factory import MetadataFactory
from app.model.validator import Validator
from app.query_cache import QueryCacheManager
Expand All @@ -42,6 +48,7 @@
append_fallback_context,
build_context,
execute_dry_run_with_timeout,
execute_get_constraints_with_timeout,
execute_get_schema_list_with_timeout,
execute_get_table_list_with_timeout,
execute_query_with_timeout,
Expand Down Expand Up @@ -605,3 +612,25 @@ async def get_schema_list(
filter_info,
dto.table_limit,
)


@router.post(
    "/{data_source}/metadata/constraints",
    response_model=list[Constraint],
    description="get the constraints of the specified data source",
)
async def get_constraints(
    data_source: DataSource,
    dto: MetadataDTO,
    headers: Annotated[Headers, Depends(get_wren_headers)],
) -> list[Constraint]:
    """Return the constraints reported by the target data source's metadata backend."""
    with tracer.start_as_current_span(
        name=f"v3_metadata_constraints_{data_source}",
        kind=trace.SpanKind.SERVER,
        context=build_context(headers),
    ) as span:
        set_attribute(headers, span)
        # Resolve the raw connection info from the DTO first, then let the
        # data source enrich it with any header-derived settings.
        raw_info = resolve_connection_info(dto)
        connection_info = data_source.get_connection_info(raw_info, dict(headers))
        metadata = MetadataFactory.get_metadata(data_source, connection_info)
        return await execute_get_constraints_with_timeout(metadata)
26 changes: 17 additions & 9 deletions ibis-server/app/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import base64
import json
import pathlib
import time

try:
Expand Down Expand Up @@ -78,19 +77,28 @@ def resolve_connection_info(dto) -> dict:
ErrorCode.INVALID_CONNECTION_INFO,
"connectionFilePath requires the CONNECTION_FILE_ROOT environment variable to be set",
)
allowed_root_resolved = str(pathlib.Path(allowed_root).resolve())
path = pathlib.Path(dto.connection_file_path).resolve()
# Explicit string prefix check — recognized by static analysis as a path sanitizer
if (
not str(path).startswith(allowed_root_resolved + os.sep)
and str(path) != allowed_root_resolved
):
# Resolve the trusted root (no user input involved)
allowed_root_str = os.path.realpath(allowed_root)
# Build the candidate path by joining the trusted root with the user
# value, then normalise. Using os.path.normpath(os.path.join(base, user))
# is the pattern recognised by CodeQL as safe for path-injection checks.
# realpath additionally resolves symlinks so a symlink inside the allowed
# root cannot escape to a file outside it.
fullpath = os.path.realpath(
os.path.normpath(os.path.join(allowed_root_str, dto.connection_file_path))
)
root_prefix = (
allowed_root_str
if allowed_root_str.endswith(os.sep)
else allowed_root_str + os.sep
)
if not fullpath.startswith(root_prefix):
raise WrenError(
ErrorCode.INVALID_CONNECTION_INFO,
f"Connection file path is outside the allowed directory: {dto.connection_file_path}",
)
try:
with open(path) as f:
with open(fullpath) as f:
return _normalize_port(json.load(f))
except FileNotFoundError:
raise WrenError(
Expand Down
1 change: 1 addition & 0 deletions ibis-server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ markers = [
"trino: mark a test as a trino test",
"databricks: mark a test as a databricks test",
"spark: mark a test as a spark test",
"duckdb: mark a test as a duckdb test",
"local_file: mark a test as a local file test",
"s3_file: mark a test as a s3 file test",
"minio_file: mark a test as a minio file test",
Expand Down
Binary file not shown.
Empty file.
25 changes: 25 additions & 0 deletions ibis-server/tests/routers/v3/connector/duckdb/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pathlib

import pytest

# Marker applied to every test collected from this directory (see
# pytest_collection_modifyitems below), enabling `pytest -m duckdb`.
pytestmark = pytest.mark.duckdb

# Prefix of every duckdb connector endpoint exercised by these tests.
base_url = "/v3/connector/duckdb"


def pytest_collection_modifyitems(items):
    """Tag tests collected from this directory (and below) with `pytestmark`.

    Items collected from elsewhere are left untouched; membership is tested
    by attempting a `relative_to` against this conftest's directory.
    """
    this_dir = pathlib.Path(__file__).resolve().parent
    for item in items:
        try:
            pathlib.Path(item.fspath).relative_to(this_dir)
            item.add_marker(pytestmark)
        except ValueError:
            # fspath lies outside this directory — not one of ours.
            pass


@pytest.fixture(scope="module")
def connection_info() -> dict[str, str]:
    # Points at the pre-built DuckDB database shipped under tests/resource;
    # "format": "duckdb" tells the connector to open it as a native DuckDB
    # file rather than e.g. CSV/parquet.
    return {
        "url": "tests/resource/duckdb",
        "format": "duckdb",
    }
44 changes: 44 additions & 0 deletions ibis-server/tests/routers/v3/connector/duckdb/test_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from tests.routers.v3.connector.duckdb.conftest import base_url

# Alias of the shared conftest base_url, kept for readability in the requests below.
v3_base_url = base_url


async def test_metadata_list_tables(client, connection_info):
    """The tables endpoint lists `main.customers` with the expected metadata."""
    resp = await client.post(
        url=f"{v3_base_url}/metadata/tables",
        json={"connectionInfo": connection_info},
    )
    assert resp.status_code == 200

    tables = resp.json()
    assert len(tables) > 0

    customers = None
    for table in tables:
        if table["name"] == "main.customers":
            customers = table
            break
    assert customers is not None
    # DuckDB reports no primary key for this table; properties echo the
    # catalog/schema/table triple the metadata backend discovered.
    assert customers["primaryKey"] == ""
    assert customers["properties"] == {
        "catalog": "jaffle_shop",
        "schema": "main",
        "table": "customers",
        "path": None,
    }
    assert len(customers["columns"]) > 0

    customer_id_col = None
    for col in customers["columns"]:
        if col["name"] == "customer_id":
            customer_id_col = col
            break
    assert customer_id_col is not None
    assert customer_id_col["nestedColumns"] is None
    assert customer_id_col["properties"] is None


async def test_metadata_list_constraints(client, connection_info):
    """The sample DuckDB database exposes no constraints."""
    resp = await client.post(
        url=f"{v3_base_url}/metadata/constraints",
        json={"connectionInfo": connection_info},
    )
    assert resp.status_code == 200
    assert resp.json() == []
137 changes: 137 additions & 0 deletions ibis-server/tests/routers/v3/connector/duckdb/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import base64

import orjson
import pytest

from tests.routers.v3.connector.duckdb.conftest import base_url

# Wren MDL manifest used by the query tests: two models backed by tables in
# the `jaffle_shop` DuckDB catalog, related one-to-many on customer_id.
manifest = {
    "catalog": "wren",
    "schema": "public",
    "models": [
        {
            "name": "customers",
            "tableReference": {
                "catalog": "jaffle_shop",
                "schema": "main",
                "table": "customers",
            },
            "columns": [
                {"name": "customer_id", "type": "integer"},
                {"name": "first_name", "type": "varchar"},
                {"name": "last_name", "type": "varchar"},
                {"name": "first_order", "type": "date"},
                {"name": "most_recent_order", "type": "date"},
                {"name": "number_of_orders", "type": "bigint"},
                {"name": "customer_lifetime_value", "type": "double"},
            ],
            "primaryKey": "customer_id",
        },
        {
            "name": "orders",
            "tableReference": {
                "catalog": "jaffle_shop",
                "schema": "main",
                "table": "orders",
            },
            "columns": [
                {"name": "order_id", "type": "integer"},
                {"name": "customer_id", "type": "integer"},
                {"name": "order_date", "type": "date"},
                {"name": "status", "type": "varchar"},
                {"name": "amount", "type": "double"},
            ],
            "primaryKey": "order_id",
        },
    ],
    "relationships": [
        {
            "name": "CustomersOrders",
            "models": ["customers", "orders"],
            "joinType": "ONE_TO_MANY",
            "condition": '"customers".customer_id = "orders".customer_id',
        }
    ],
}


@pytest.fixture(scope="module")
def manifest_str():
    """Base64-encoded JSON manifest, the wire format the connector expects."""
    encoded = base64.b64encode(orjson.dumps(manifest))
    return encoded.decode("utf-8")


async def test_query(client, manifest_str, connection_info):
    """A basic SELECT through the modeled `customers` table returns row 1."""
    payload = {
        "manifestStr": manifest_str,
        "sql": 'SELECT customer_id, first_name, last_name FROM "customers" ORDER BY customer_id LIMIT 1',
        "connectionInfo": connection_info,
    }
    response = await client.post(f"{base_url}/query", json=payload)
    assert response.status_code == 200

    body = response.json()
    assert body["columns"] == ["customer_id", "first_name", "last_name"]
    assert len(body["data"]) == 1
    assert body["data"][0] == [1, "Michael", "P."]
    # dtypes describe the Arrow-facing types of each projected column.
    assert body["dtypes"] == {
        "customer_id": "int32",
        "first_name": "string",
        "last_name": "string",
    }


async def test_query_with_limit(client, manifest_str, connection_info):
    """The `limit` query param caps the row count below the SQL's own LIMIT."""
    response = await client.post(
        f"{base_url}/query",
        params={"limit": 1},
        json={
            "manifestStr": manifest_str,
            "sql": 'SELECT * FROM "customers" LIMIT 5',
            "connectionInfo": connection_info,
        },
    )
    assert response.status_code == 200
    assert len(response.json()["data"]) == 1


async def test_query_orders(client, manifest_str, connection_info):
    """Query the modeled `orders` table and pin the first row's values."""
    response = await client.post(
        f"{base_url}/query",
        json={
            "manifestStr": manifest_str,
            "sql": 'SELECT order_id, customer_id, status, amount FROM "orders" ORDER BY order_id LIMIT 1',
            "connectionInfo": connection_info,
        },
    )
    assert response.status_code == 200
    result = response.json()
    # Column order mirrors the SELECT list; values come from the bundled
    # jaffle_shop data (order 1, customer 1, "returned", 10.0).
    assert result["columns"] == ["order_id", "customer_id", "status", "amount"]
    assert len(result["data"]) == 1
    assert result["data"][0] == [1, 1, "returned", 10.0]


async def test_dry_run(client, manifest_str, connection_info):
    """Dry-run returns 204 for valid SQL and 422 for an unknown table."""
    ok = await client.post(
        f"{base_url}/query",
        params={"dryRun": True},
        json={
            "manifestStr": manifest_str,
            "sql": 'SELECT * FROM "customers" LIMIT 1',
            "connectionInfo": connection_info,
        },
    )
    # A successful dry run validates the SQL without returning data.
    assert ok.status_code == 204

    bad = await client.post(
        f"{base_url}/query",
        params={"dryRun": True},
        json={
            "manifestStr": manifest_str,
            "sql": 'SELECT * FROM "NotFound" LIMIT 1',
            "connectionInfo": connection_info,
        },
    )
    assert bad.status_code == 422
    assert bad.text is not None