diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py
index c4b06edd6..96f2260f3 100644
--- a/ibis-server/app/model/__init__.py
+++ b/ibis-server/app/model/__init__.py
@@ -378,7 +378,9 @@ class LocalFileConnectionInfo(BaseConnectionInfo):
         description="the root path of the local file", default="/", examples=["/data"]
     )
     format: str = Field(
-        description="File format", default="csv", examples=["csv", "parquet", "json"]
+        description="File format",
+        default="csv",
+        examples=["csv", "parquet", "json", "duckdb"],
     )
 
 
@@ -387,7 +389,9 @@ class S3FileConnectionInfo(BaseConnectionInfo):
         description="the root path of the s3 bucket", default="/", examples=["/data"]
     )
     format: str = Field(
-        description="File format", default="csv", examples=["csv", "parquet", "json"]
+        description="File format",
+        default="csv",
+        examples=["csv", "parquet", "json", "duckdb"],
     )
     bucket: SecretStr = Field(
         description="the name of the s3 bucket", examples=["my-bucket"]
@@ -408,7 +412,9 @@ class MinioFileConnectionInfo(BaseConnectionInfo):
         description="the root path of the minio bucket", default="/", examples=["/data"]
     )
     format: str = Field(
-        description="File format", default="csv", examples=["csv", "parquet", "json"]
+        description="File format",
+        default="csv",
+        examples=["csv", "parquet", "json", "duckdb"],
     )
     ssl_enabled: bool = Field(
         description="use the ssl connection or not",
@@ -434,7 +440,9 @@ class GcsFileConnectionInfo(BaseConnectionInfo):
         description="the root path of the gcs bucket", default="/", examples=["/data"]
     )
     format: str = Field(
-        description="File format", default="csv", examples=["csv", "parquet", "json"]
+        description="File format",
+        default="csv",
+        examples=["csv", "parquet", "json", "duckdb"],
     )
     bucket: SecretStr = Field(
         description="the name of the gcs bucket", examples=["my-bucket"]
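The new "duckdb" value above is only an added example on the existing format
field, so the payload schema itself is unchanged; clients opt in by sending
format: "duckdb". As a minimal sketch (assuming the app.model import path used
elsewhere in this diff), building the local-file connection info looks like:

    from app.model import LocalFileConnectionInfo

    # "format" defaults to "csv"; "duckdb" opts in to the attach-based flow
    # implemented in connector.py below.
    info = LocalFileConnectionInfo(url="/data", format="duckdb")
    print(info.format)  # duckdb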
raise QueryDryRunError(f"Failed to execute query: {e!s}") + def _attach_database(self, connection_info: ConnectionInfo) -> None: + db_files = self._list_duckdb_files(connection_info) + if not db_files: + raise UnprocessableEntityError( + "No DuckDB files found in the specified path." + ) + + for file in db_files: + try: + self.connection.execute( + f"ATTACH DATABASE '{file}' AS \"{os.path.splitext(os.path.basename(file))[0]}\" (READ_ONLY);" + ) + except IOException as e: + raise UnprocessableEntityError(f"Failed to attach database: {e!s}") + except HTTPException as e: + raise UnprocessableEntityError(f"Failed to attach database: {e!s}") + + def _list_duckdb_files(self, connection_info: ConnectionInfo) -> list[str]: + # This method should return a list of file paths in the DuckDB database + op = opendal.Operator("fs", root=connection_info.url.get_secret_value()) + files = [] + try: + for file in op.list("/"): + if file.path != "/": + stat = op.stat(file.path) + if not stat.mode.is_dir() and file.path.endswith(".duckdb"): + full_path = ( + f"{connection_info.url.get_secret_value()}/{file.path}" + ) + files.append(full_path) + except Exception as e: + raise UnprocessableEntityError(f"Failed to list files: {e!s}") + + return files + class RedshiftConnector: def __init__(self, connection_info: RedshiftConnectionUnion): diff --git a/ibis-server/app/model/metadata/factory.py b/ibis-server/app/model/metadata/factory.py index 7f78d2a00..911d42747 100644 --- a/ibis-server/app/model/metadata/factory.py +++ b/ibis-server/app/model/metadata/factory.py @@ -7,6 +7,7 @@ from app.model.metadata.mssql import MSSQLMetadata from app.model.metadata.mysql import MySQLMetadata from app.model.metadata.object_storage import ( + DuckDBMetadata, GcsFileMetadata, LocalFileMetadata, MinioFileMetadata, @@ -41,6 +42,19 @@ class MetadataFactory: @staticmethod def get_metadata(data_source: DataSource, connection_info) -> Metadata: try: + if ( + data_source + in [ + DataSource.local_file, + DataSource.s3_file, + DataSource.minio_file, + DataSource.gcs_file, + ] + and connection_info.format == "duckdb" + ): + # DuckDBMetadata is used for local file, S3, Minio, and GCS with DuckDB format + return DuckDBMetadata(connection_info) + return mapping[data_source](connection_info) except KeyError: raise NotImplementedError(f"Unsupported data source: {data_source}") diff --git a/ibis-server/app/model/metadata/object_storage.py b/ibis-server/app/model/metadata/object_storage.py index 8ed8f5759..24da66ee0 100644 --- a/ibis-server/app/model/metadata/object_storage.py +++ b/ibis-server/app/model/metadata/object_storage.py @@ -2,6 +2,7 @@ import duckdb import opendal +import pyarrow as pa from loguru import logger from app.model import ( @@ -11,6 +12,7 @@ S3FileConnectionInfo, UnprocessableEntityError, ) +from app.model.connector import DuckDBConnector from app.model.metadata.dto import ( Column, RustWrenEngineColumnType, @@ -271,3 +273,77 @@ def _get_full_path(self, path): path = path[1:] return f"gs://{self.connection_info.bucket.get_secret_value()}/{path}" + + +class DuckDBMetadata(ObjectStorageMetadata): + def __init__(self, connection_info: LocalFileConnectionInfo): + super().__init__(connection_info) + self.connection = DuckDBConnector(connection_info) + + def get_table_list(self) -> list[Table]: + sql = """ + SELECT + t.table_catalog, + t.table_schema, + t.table_name, + c.column_name, + c.data_type, + c.is_nullable, + c.ordinal_position + FROM + information_schema.tables t + JOIN + information_schema.columns c + ON t.table_schema 
diff --git a/ibis-server/app/model/metadata/object_storage.py b/ibis-server/app/model/metadata/object_storage.py
index 8ed8f5759..24da66ee0 100644
--- a/ibis-server/app/model/metadata/object_storage.py
+++ b/ibis-server/app/model/metadata/object_storage.py
@@ -2,6 +2,7 @@
 
 import duckdb
 import opendal
+import pyarrow as pa
 from loguru import logger
 
 from app.model import (
@@ -11,6 +12,7 @@
     S3FileConnectionInfo,
     UnprocessableEntityError,
 )
+from app.model.connector import DuckDBConnector
 from app.model.metadata.dto import (
     Column,
     RustWrenEngineColumnType,
@@ -271,3 +273,77 @@ def _get_full_path(self, path):
             path = path[1:]
 
         return f"gs://{self.connection_info.bucket.get_secret_value()}/{path}"
+
+
+class DuckDBMetadata(ObjectStorageMetadata):
+    def __init__(self, connection_info: LocalFileConnectionInfo):
+        super().__init__(connection_info)
+        self.connection = DuckDBConnector(connection_info)
+
+    def get_table_list(self) -> list[Table]:
+        sql = """
+            SELECT
+                t.table_catalog,
+                t.table_schema,
+                t.table_name,
+                c.column_name,
+                c.data_type,
+                c.is_nullable,
+                c.ordinal_position
+            FROM
+                information_schema.tables t
+            JOIN
+                information_schema.columns c
+                ON t.table_schema = c.table_schema
+                AND t.table_name = c.table_name
+            WHERE
+                t.table_type IN ('BASE TABLE', 'VIEW')
+                AND t.table_schema NOT IN ('information_schema', 'pg_catalog');
+            """
+        response = (
+            self.connection.query(sql, limit=None).to_pandas().to_dict(orient="records")
+        )
+
+        unique_tables = {}
+        for row in response:
+            # generate unique table name
+            schema_table = self._format_compact_table_name(
+                row["table_schema"], row["table_name"]
+            )
+            # init table if not exists
+            if schema_table not in unique_tables:
+                unique_tables[schema_table] = Table(
+                    name=schema_table,
+                    columns=[],
+                    properties=TableProperties(
+                        schema=row["table_schema"],
+                        catalog=row["table_catalog"],
+                        table=row["table_name"],
+                    ),
+                    primaryKey="",
+                )
+
+            # table exists, and add column to the table
+            unique_tables[schema_table].columns.append(
+                Column(
+                    name=row["column_name"],
+                    type=self._to_column_type(row["data_type"]),
+                    notNull=row["is_nullable"].lower() == "no",
+                    properties=None,
+                )
+            )
+        return list(unique_tables.values())
+
+    def _format_compact_table_name(self, schema: str, table: str):
+        return f"{schema}.{table}"
+
+    def get_constraints(self):
+        return []
+
+    def get_version(self):
+        df: pa.Table = self.connection.query("SELECT version()", limit=None)
+        if df is None:
+            raise UnprocessableEntityError("Failed to get DuckDB version")
+        if df.num_rows == 0:
+            raise UnprocessableEntityError("DuckDB version is empty")
+        return df.column(0).to_pylist()[0]
diff --git a/ibis-server/tests/resource/test_file_source/jaffle_shop.duckdb b/ibis-server/tests/resource/test_file_source/jaffle_shop.duckdb
new file mode 100644
index 000000000..2a226b840
Binary files /dev/null and b/ibis-server/tests/resource/test_file_source/jaffle_shop.duckdb differ
diff --git a/ibis-server/tests/routers/v2/connector/test_local_file.py b/ibis-server/tests/routers/v2/connector/test_local_file.py
index 73c854335..4055d1b63 100644
--- a/ibis-server/tests/routers/v2/connector/test_local_file.py
+++ b/ibis-server/tests/routers/v2/connector/test_local_file.py
@@ -447,3 +447,36 @@ async def test_list_json_files(client):
     assert columns[21]["type"] == "UUID"
     assert columns[22]["name"] == "c_varchar"
     assert columns[22]["type"] == "STRING"
+
+
+async def test_duckdb_metadata_list_tables(client):
+    response = await client.post(
+        url=f"{base_url}/metadata/tables",
+        json={
+            "connectionInfo": {
+                "url": "tests/resource/test_file_source",
+                "format": "duckdb",
+            },
+        },
+    )
+    assert response.status_code == 200
+
+    result = next(filter(lambda x: x["name"] == "main.customers", response.json()))
+    assert result["name"] == "main.customers"
+    assert result["primaryKey"] is not None
+    assert result["description"] is None
+    assert result["properties"] == {
+        "catalog": "jaffle_shop",
+        "schema": "main",
+        "table": "customers",
+        "path": None,
+    }
+    assert len(result["columns"]) == 7
+    assert result["columns"][1] == {
+        "name": "number_of_orders",
+        "nestedColumns": None,
+        "type": "INT64",
+        "notNull": False,
+        "description": None,
+        "properties": None,
+    }
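For reference, the "main.customers" name asserted above comes straight from
_format_compact_table_name(schema, table), and the fixture can be inspected
directly with DuckDB to see the same information_schema rows get_table_list
consumes (a sketch, assuming it is run from the ibis-server directory):

    import duckdb

    conn = duckdb.connect(
        "tests/resource/test_file_source/jaffle_shop.duckdb", read_only=True
    )
    # The same catalog/schema/table triples DuckDBMetadata groups into Table DTOs.
    print(
        conn.execute(
            "SELECT table_catalog, table_schema, table_name "
            "FROM information_schema.tables"
        ).fetchall()
    )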
"catalog": "wren", + "schema": "public", + "models": [ + { + "name": "customers", + "tableReference": { + "catalog": "jaffle_shop", + "schema": "main", + "table": "customers", + }, + "columns": [ + {"name": "customer_id", "type": "integer"}, + {"name": "customer_lifetime_value", "type": "double"}, + {"name": "first_name", "type": "varchar"}, + {"name": "first_order", "type": "date"}, + {"name": "last_name", "type": "varchar"}, + {"name": "most_recent_order", "type": "date"}, + {"name": "number_of_orders", "type": "integer"}, + ], + }, + ], + "relationships": [], + "views": [], + } + response = await client.post( + f"{base_url}/query", + json={ + "manifestStr": base64.b64encode(orjson.dumps(manifest)).decode("utf-8"), + "sql": "SELECT * FROM customers LIMIT 1", + "connectionInfo": { + "url": "tests/resource/test_file_source", + "format": "duckdb", + }, + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1