diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index af9fadefd..6f7692075 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -99,6 +99,10 @@ class QueryLocalFileDTO(QueryDTO): connection_info: LocalFileConnectionInfo = connection_info_field +class QueryDuckDBDTO(QueryDTO): + connection_info: LocalFileConnectionInfo = connection_info_field + + class QueryS3FileDTO(QueryDTO): connection_info: S3FileConnectionInfo = connection_info_field diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index c86ff6877..be9f52ee8 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -92,6 +92,7 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo): DataSource.s3_file, DataSource.minio_file, DataSource.gcs_file, + DataSource.duckdb, }: self._connector = DuckDBConnector(connection_info) elif data_source == DataSource.redshift: diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index c30aa8827..beb6b00a9 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -37,6 +37,7 @@ QueryClickHouseDTO, QueryDatabricksDTO, QueryDTO, + QueryDuckDBDTO, QueryGcsFileDTO, QueryLocalFileDTO, QueryMinioFileDTO, @@ -78,6 +79,7 @@ class DataSource(StrEnum): s3_file = auto() minio_file = auto() gcs_file = auto() + duckdb = auto() spark = auto() databricks = auto() @@ -179,6 +181,8 @@ def _build_connection_info(self, data: dict) -> ConnectionInfo: return SnowflakeConnectionInfo.model_validate(data) case DataSource.trino: return TrinoConnectionInfo.model_validate(data) + case DataSource.duckdb: + return LocalFileConnectionInfo.model_validate(data) case DataSource.local_file: return LocalFileConnectionInfo.model_validate(data) case DataSource.s3_file: @@ -242,6 +246,7 @@ class DataSourceExtension(Enum): snowflake = QuerySnowflakeDTO trino = QueryTrinoDTO local_file = QueryLocalFileDTO + duckdb = QueryDuckDBDTO s3_file = QueryS3FileDTO minio_file = QueryMinioFileDTO gcs_file = QueryGcsFileDTO @@ -256,7 +261,7 @@ def get_connection(self, info: ConnectionInfo) -> BaseBackend: if hasattr(info, "connection_url"): kwargs = info.kwargs if info.kwargs else {} return ibis.connect(info.connection_url.get_secret_value(), **kwargs) - if self.name in {"local_file", "redshift", "spark"}: + if self.name in {"local_file", "redshift", "spark", "duckdb"}: raise NotImplementedError( f"{self.name} connection is not implemented to get ibis backend" ) diff --git a/ibis-server/app/model/metadata/factory.py b/ibis-server/app/model/metadata/factory.py index 4c8d49c6d..7725de9fc 100644 --- a/ibis-server/app/model/metadata/factory.py +++ b/ibis-server/app/model/metadata/factory.py @@ -39,6 +39,7 @@ DataSource.gcs_file: GcsFileMetadata, DataSource.databricks: DatabricksMetadata, DataSource.spark: SparkMetadata, + DataSource.duckdb: DuckDBMetadata, } diff --git a/ibis-server/app/routers/v3/connector.py b/ibis-server/app/routers/v3/connector.py index c44d0d38c..54ac347a2 100644 --- a/ibis-server/app/routers/v3/connector.py +++ b/ibis-server/app/routers/v3/connector.py @@ -32,7 +32,13 @@ from app.model.connector import Connector from app.model.data_source import DataSource from app.model.error import DatabaseTimeoutError -from app.model.metadata.dto import Catalog, MetadataDTO, Table, get_filter_info +from app.model.metadata.dto import ( + Catalog, + Constraint, + MetadataDTO, + Table, + get_filter_info, +) from app.model.metadata.factory import MetadataFactory from app.model.validator import Validator from app.query_cache import QueryCacheManager @@ -42,6 +48,7 @@ append_fallback_context, build_context, execute_dry_run_with_timeout, + execute_get_constraints_with_timeout, execute_get_schema_list_with_timeout, execute_get_table_list_with_timeout, execute_query_with_timeout, @@ -605,3 +612,25 @@ async def get_schema_list( filter_info, dto.table_limit, ) + + +@router.post( + "/{data_source}/metadata/constraints", + response_model=list[Constraint], + description="get the constraints of the specified data source", +) +async def get_constraints( + data_source: DataSource, + dto: MetadataDTO, + headers: Annotated[Headers, Depends(get_wren_headers)], +) -> list[Constraint]: + span_name = f"v3_metadata_constraints_{data_source}" + with tracer.start_as_current_span( + name=span_name, kind=trace.SpanKind.SERVER, context=build_context(headers) + ) as span: + set_attribute(headers, span) + connection_info = data_source.get_connection_info( + resolve_connection_info(dto), dict(headers) + ) + metadata = MetadataFactory.get_metadata(data_source, connection_info) + return await execute_get_constraints_with_timeout(metadata) diff --git a/ibis-server/app/util.py b/ibis-server/app/util.py index d5bf61a61..ba96ba7c0 100644 --- a/ibis-server/app/util.py +++ b/ibis-server/app/util.py @@ -1,7 +1,6 @@ import asyncio import base64 import json -import pathlib import time try: @@ -78,19 +77,28 @@ def resolve_connection_info(dto) -> dict: ErrorCode.INVALID_CONNECTION_INFO, "connectionFilePath requires the CONNECTION_FILE_ROOT environment variable to be set", ) - allowed_root_resolved = str(pathlib.Path(allowed_root).resolve()) - path = pathlib.Path(dto.connection_file_path).resolve() - # Explicit string prefix check — recognized by static analysis as a path sanitizer - if ( - not str(path).startswith(allowed_root_resolved + os.sep) - and str(path) != allowed_root_resolved - ): + # Resolve the trusted root (no user input involved) + allowed_root_str = os.path.realpath(allowed_root) + # Build the candidate path by joining the trusted root with the user + # value, then normalise. Using os.path.normpath(os.path.join(base, user)) + # is the pattern recognised by CodeQL as safe for path-injection checks. + # realpath additionally resolves symlinks so a symlink inside the allowed + # root cannot escape to a file outside it. + fullpath = os.path.realpath( + os.path.normpath(os.path.join(allowed_root_str, dto.connection_file_path)) + ) + root_prefix = ( + allowed_root_str + if allowed_root_str.endswith(os.sep) + else allowed_root_str + os.sep + ) + if not fullpath.startswith(root_prefix): raise WrenError( ErrorCode.INVALID_CONNECTION_INFO, f"Connection file path is outside the allowed directory: {dto.connection_file_path}", ) try: - with open(path) as f: + with open(fullpath) as f: return _normalize_port(json.load(f)) except FileNotFoundError: raise WrenError( diff --git a/ibis-server/pyproject.toml b/ibis-server/pyproject.toml index d2bca6f05..94df1eb7c 100644 --- a/ibis-server/pyproject.toml +++ b/ibis-server/pyproject.toml @@ -104,6 +104,7 @@ markers = [ "trino: mark a test as a trino test", "databricks: mark a test as a databricks test", "spark: mark a test as a spark test", + "duckdb: mark a test as a duckdb test", "local_file: mark a test as a local file test", "s3_file: mark a test as a s3 file test", "minio_file: mark a test as a minio file test", diff --git a/ibis-server/tests/resource/duckdb/jaffle_shop.duckdb b/ibis-server/tests/resource/duckdb/jaffle_shop.duckdb new file mode 100644 index 000000000..ca7ed9367 Binary files /dev/null and b/ibis-server/tests/resource/duckdb/jaffle_shop.duckdb differ diff --git a/ibis-server/tests/routers/v3/connector/duckdb/__init__.py b/ibis-server/tests/routers/v3/connector/duckdb/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ibis-server/tests/routers/v3/connector/duckdb/conftest.py b/ibis-server/tests/routers/v3/connector/duckdb/conftest.py new file mode 100644 index 000000000..596af0364 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/duckdb/conftest.py @@ -0,0 +1,25 @@ +import pathlib + +import pytest + +pytestmark = pytest.mark.duckdb + +base_url = "/v3/connector/duckdb" + + +def pytest_collection_modifyitems(items): + current_file_dir = pathlib.Path(__file__).resolve().parent + for item in items: + try: + pathlib.Path(item.fspath).relative_to(current_file_dir) + item.add_marker(pytestmark) + except ValueError: + pass + + +@pytest.fixture(scope="module") +def connection_info() -> dict[str, str]: + return { + "url": "tests/resource/duckdb", + "format": "duckdb", + } diff --git a/ibis-server/tests/routers/v3/connector/duckdb/test_metadata.py b/ibis-server/tests/routers/v3/connector/duckdb/test_metadata.py new file mode 100644 index 000000000..0f5d32163 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/duckdb/test_metadata.py @@ -0,0 +1,44 @@ +from tests.routers.v3.connector.duckdb.conftest import base_url + +v3_base_url = base_url + + +async def test_metadata_list_tables(client, connection_info): + response = await client.post( + url=f"{v3_base_url}/metadata/tables", + json={"connectionInfo": connection_info}, + ) + assert response.status_code == 200 + + tables = response.json() + assert len(tables) > 0 + + result = next( + filter(lambda x: x["name"] == "main.customers", tables), + None, + ) + assert result is not None + assert result["primaryKey"] == "" + assert result["properties"] == { + "catalog": "jaffle_shop", + "schema": "main", + "table": "customers", + "path": None, + } + assert len(result["columns"]) > 0 + + customer_id_col = next( + filter(lambda c: c["name"] == "customer_id", result["columns"]), None + ) + assert customer_id_col is not None + assert customer_id_col["nestedColumns"] is None + assert customer_id_col["properties"] is None + + +async def test_metadata_list_constraints(client, connection_info): + response = await client.post( + url=f"{v3_base_url}/metadata/constraints", + json={"connectionInfo": connection_info}, + ) + assert response.status_code == 200 + assert response.json() == [] diff --git a/ibis-server/tests/routers/v3/connector/duckdb/test_query.py b/ibis-server/tests/routers/v3/connector/duckdb/test_query.py new file mode 100644 index 000000000..aac8fb61e --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/duckdb/test_query.py @@ -0,0 +1,137 @@ +import base64 + +import orjson +import pytest + +from tests.routers.v3.connector.duckdb.conftest import base_url + +manifest = { + "catalog": "wren", + "schema": "public", + "models": [ + { + "name": "customers", + "tableReference": { + "catalog": "jaffle_shop", + "schema": "main", + "table": "customers", + }, + "columns": [ + {"name": "customer_id", "type": "integer"}, + {"name": "first_name", "type": "varchar"}, + {"name": "last_name", "type": "varchar"}, + {"name": "first_order", "type": "date"}, + {"name": "most_recent_order", "type": "date"}, + {"name": "number_of_orders", "type": "bigint"}, + {"name": "customer_lifetime_value", "type": "double"}, + ], + "primaryKey": "customer_id", + }, + { + "name": "orders", + "tableReference": { + "catalog": "jaffle_shop", + "schema": "main", + "table": "orders", + }, + "columns": [ + {"name": "order_id", "type": "integer"}, + {"name": "customer_id", "type": "integer"}, + {"name": "order_date", "type": "date"}, + {"name": "status", "type": "varchar"}, + {"name": "amount", "type": "double"}, + ], + "primaryKey": "order_id", + }, + ], + "relationships": [ + { + "name": "CustomersOrders", + "models": ["customers", "orders"], + "joinType": "ONE_TO_MANY", + "condition": '"customers".customer_id = "orders".customer_id', + } + ], +} + + +@pytest.fixture(scope="module") +def manifest_str(): + return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + +async def test_query(client, manifest_str, connection_info): + response = await client.post( + f"{base_url}/query", + json={ + "manifestStr": manifest_str, + "sql": 'SELECT customer_id, first_name, last_name FROM "customers" ORDER BY customer_id LIMIT 1', + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 200 + result = response.json() + assert result["columns"] == ["customer_id", "first_name", "last_name"] + assert len(result["data"]) == 1 + assert result["data"][0] == [1, "Michael", "P."] + assert result["dtypes"] == { + "customer_id": "int32", + "first_name": "string", + "last_name": "string", + } + + +async def test_query_with_limit(client, manifest_str, connection_info): + response = await client.post( + f"{base_url}/query", + params={"limit": 1}, + json={ + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "customers" LIMIT 5', + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + + +async def test_query_orders(client, manifest_str, connection_info): + response = await client.post( + f"{base_url}/query", + json={ + "manifestStr": manifest_str, + "sql": 'SELECT order_id, customer_id, status, amount FROM "orders" ORDER BY order_id LIMIT 1', + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 200 + result = response.json() + assert result["columns"] == ["order_id", "customer_id", "status", "amount"] + assert len(result["data"]) == 1 + assert result["data"][0] == [1, 1, "returned", 10.0] + + +async def test_dry_run(client, manifest_str, connection_info): + response = await client.post( + f"{base_url}/query", + params={"dryRun": True}, + json={ + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "customers" LIMIT 1', + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 204 + + response = await client.post( + f"{base_url}/query", + params={"dryRun": True}, + json={ + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "NotFound" LIMIT 1', + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 422 + assert response.text is not None