Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,9 @@ class LocalFileConnectionInfo(BaseConnectionInfo):
description="the root path of the local file", default="/", examples=["/data"]
)
format: str = Field(
description="File format", default="csv", examples=["csv", "parquet", "json"]
description="File format",
default="csv",
examples=["csv", "parquet", "json", "duckdb"],
)


Expand All @@ -387,7 +389,9 @@ class S3FileConnectionInfo(BaseConnectionInfo):
description="the root path of the s3 bucket", default="/", examples=["/data"]
)
format: str = Field(
description="File format", default="csv", examples=["csv", "parquet", "json"]
description="File format",
default="csv",
examples=["csv", "parquet", "json", "duckdb"],
)
bucket: SecretStr = Field(
description="the name of the s3 bucket", examples=["my-bucket"]
Expand All @@ -408,7 +412,9 @@ class MinioFileConnectionInfo(BaseConnectionInfo):
description="the root path of the minio bucket", default="/", examples=["/data"]
)
format: str = Field(
description="File format", default="csv", examples=["csv", "parquet", "json"]
description="File format",
default="csv",
examples=["csv", "parquet", "json", "duckdb"],
)
ssl_enabled: bool = Field(
description="use the ssl connection or not",
Expand All @@ -434,7 +440,9 @@ class GcsFileConnectionInfo(BaseConnectionInfo):
description="the root path of the gcs bucket", default="/", examples=["/data"]
)
format: str = Field(
description="File format", default="csv", examples=["csv", "parquet", "json"]
description="File format",
default="csv",
examples=["csv", "parquet", "json", "duckdb"],
)
bucket: SecretStr = Field(
description="the name of the gcs bucket", examples=["my-bucket"]
Expand Down
54 changes: 52 additions & 2 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import importlib
import os
from contextlib import closing
from functools import cache
from json import loads
Expand All @@ -8,6 +9,7 @@
import ibis
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
import opendal
import pandas as pd
import pyarrow as pa
import sqlglot.expressions as sge
Expand Down Expand Up @@ -204,10 +206,23 @@ def __init__(self, connection_info: ConnectionInfo):
if isinstance(connection_info, GcsFileConnectionInfo):
init_duckdb_gcs(self.connection, connection_info)

if connection_info.format == "duckdb":
# For duckdb format, we attach the database files
self._attach_database(connection_info)

@tracer.start_as_current_span("duckdb_query", kind=trace.SpanKind.INTERNAL)
def query(self, sql: str, limit: int) -> pa.Table:
def query(self, sql: str, limit: int | None) -> pa.Table:
try:
return self.connection.execute(sql).fetch_arrow_table().slice(length=limit)
if limit is None:
# If no limit is specified, we return the full result
return self.connection.execute(sql).fetch_arrow_table()
else:
# If a limit is specified, we slice the result
# DuckDB does not support LIMIT in fetch_arrow_table, so we use slice
# to limit the number of rows returned
return (
self.connection.execute(sql).fetch_arrow_table().slice(length=limit)
)
except IOException as e:
raise UnprocessableEntityError(f"Failed to execute query: {e!s}")
except HTTPException as e:
Expand All @@ -222,6 +237,41 @@ def dry_run(self, sql: str) -> None:
except HTTPException as e:
raise QueryDryRunError(f"Failed to execute query: {e!s}")

def _attach_database(self, connection_info: ConnectionInfo) -> None:
    """Attach every ``*.duckdb`` file under the connection root, read-only.

    Each file is attached under an alias derived from its base filename with
    the extension stripped, e.g. ``/data/jaffle_shop.duckdb`` -> ``jaffle_shop``.

    Raises:
        UnprocessableEntityError: if no DuckDB files are found under the root,
            or if attaching any of them fails.
    """
    db_files = self._list_duckdb_files(connection_info)
    if not db_files:
        raise UnprocessableEntityError(
            "No DuckDB files found in the specified path."
        )

    for file in db_files:
        # Alias = base filename without its extension.
        alias = os.path.splitext(os.path.basename(file))[0]
        # Escape quote characters so a path or filename containing quotes
        # cannot break out of the SQL literal — the values come from a
        # directory listing, which may be user-controlled.
        safe_path = file.replace("'", "''")
        safe_alias = alias.replace('"', '""')
        try:
            self.connection.execute(
                f"ATTACH DATABASE '{safe_path}' AS \"{safe_alias}\" (READ_ONLY);"
            )
        except (IOException, HTTPException) as e:
            raise UnprocessableEntityError(f"Failed to attach database: {e!s}")

def _list_duckdb_files(self, connection_info: ConnectionInfo) -> list[str]:
    """Return full paths of the ``*.duckdb`` files directly under the root.

    Uses an opendal filesystem operator rooted at the connection URL; entries
    whose path ends in ``.duckdb`` and are not directories are returned as
    ``<root>/<relative path>`` strings.

    Raises:
        UnprocessableEntityError: if listing the root fails (bad path,
            permissions, ...).
    """
    root = connection_info.url.get_secret_value()
    op = opendal.Operator("fs", root=root)
    files: list[str] = []
    try:
        for entry in op.list("/"):
            if entry.path == "/":
                continue
            # Apply the cheap suffix filter first so we only pay the
            # per-entry stat() call for candidate .duckdb files.
            if not entry.path.endswith(".duckdb"):
                continue
            if op.stat(entry.path).mode.is_dir():
                continue
            files.append(f"{root}/{entry.path}")
    except Exception as e:
        # Surface listing failures as a 422 instead of an unhandled 500.
        raise UnprocessableEntityError(f"Failed to list files: {e!s}")

    return files


class RedshiftConnector:
def __init__(self, connection_info: RedshiftConnectionUnion):
Expand Down
14 changes: 14 additions & 0 deletions ibis-server/app/model/metadata/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from app.model.metadata.mssql import MSSQLMetadata
from app.model.metadata.mysql import MySQLMetadata
from app.model.metadata.object_storage import (
DuckDBMetadata,
GcsFileMetadata,
LocalFileMetadata,
MinioFileMetadata,
Expand Down Expand Up @@ -41,6 +42,19 @@ class MetadataFactory:
@staticmethod
def get_metadata(data_source: DataSource, connection_info) -> Metadata:
    """Resolve and instantiate the Metadata implementation for *data_source*."""
    file_sources = (
        DataSource.local_file,
        DataSource.s3_file,
        DataSource.minio_file,
        DataSource.gcs_file,
    )
    try:
        # Every object-storage source shares a single metadata reader when
        # the connection is in "duckdb" format.
        if data_source in file_sources and connection_info.format == "duckdb":
            return DuckDBMetadata(connection_info)
        return mapping[data_source](connection_info)
    except KeyError:
        raise NotImplementedError(f"Unsupported data source: {data_source}")
76 changes: 76 additions & 0 deletions ibis-server/app/model/metadata/object_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import duckdb
import opendal
import pyarrow as pa
from loguru import logger

from app.model import (
Expand All @@ -11,6 +12,7 @@
S3FileConnectionInfo,
UnprocessableEntityError,
)
from app.model.connector import DuckDBConnector
from app.model.metadata.dto import (
Column,
RustWrenEngineColumnType,
Expand Down Expand Up @@ -271,3 +273,77 @@ def _get_full_path(self, path):
path = path[1:]

return f"gs://{self.connection_info.bucket.get_secret_value()}/{path}"


class DuckDBMetadata(ObjectStorageMetadata):
    """Metadata reader for object-storage sources whose ``format`` is ``duckdb``.

    Instead of scanning csv/parquet/json files, this relies on the
    ``DuckDBConnector`` attaching the ``*.duckdb`` database files, then reads
    table and column information from DuckDB's ``information_schema``.
    """

    def __init__(self, connection_info: LocalFileConnectionInfo):
        super().__init__(connection_info)
        # Constructing the connector attaches the DuckDB files found under
        # the connection root (see DuckDBConnector._attach_database).
        self.connection = DuckDBConnector(connection_info)

    def get_table_list(self) -> list[Table]:
        """Return one Table (with its columns) per attached table or view."""
        sql = """
            SELECT
                t.table_catalog,
                t.table_schema,
                t.table_name,
                c.column_name,
                c.data_type,
                c.is_nullable,
                c.ordinal_position
            FROM
                information_schema.tables t
            JOIN
                information_schema.columns c
                ON t.table_schema = c.table_schema
                AND t.table_name = c.table_name
            WHERE
                t.table_type IN ('BASE TABLE', 'VIEW')
                AND t.table_schema NOT IN ('information_schema', 'pg_catalog');
        """
        response = (
            self.connection.query(sql, limit=None).to_pandas().to_dict(orient="records")
        )

        # One row per column; group rows into tables keyed by "schema.table".
        unique_tables: dict[str, Table] = {}
        for row in response:
            schema_table = self._format_compact_table_name(
                row["table_schema"], row["table_name"]
            )
            if schema_table not in unique_tables:
                unique_tables[schema_table] = Table(
                    name=schema_table,
                    columns=[],
                    properties=TableProperties(
                        schema=row["table_schema"],
                        catalog=row["table_catalog"],
                        table=row["table_name"],
                    ),
                    primaryKey="",
                )

            unique_tables[schema_table].columns.append(
                Column(
                    name=row["column_name"],
                    type=self._to_column_type(row["data_type"]),
                    notNull=row["is_nullable"].lower() == "no",
                    properties=None,
                )
            )
        return list(unique_tables.values())

    def _format_compact_table_name(self, schema: str, table: str) -> str:
        """Build the compact display name used as the table key."""
        return f"{schema}.{table}"

    def get_constraints(self) -> list:
        """No constraint metadata is exposed for DuckDB files; always empty."""
        return []

    def get_version(self) -> str:
        """Return the DuckDB engine version string (first cell of ``SELECT version()``).

        Raises:
            UnprocessableEntityError: if the query returns nothing.
        """
        # BUG FIX: DuckDBConnector.query requires a `limit` argument (it has
        # no default); the original call omitted it and raised TypeError
        # before any SQL was executed.
        result: pa.Table = self.connection.query("SELECT version()", limit=None)
        if result is None:
            raise UnprocessableEntityError("Failed to get DuckDB version")
        if result.num_rows == 0:
            raise UnprocessableEntityError("DuckDB version is empty")
        return result.column(0).to_pylist()[0]
Binary file not shown.
33 changes: 33 additions & 0 deletions ibis-server/tests/routers/v2/connector/test_local_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,36 @@ async def test_list_json_files(client):
assert columns[21]["type"] == "UUID"
assert columns[22]["name"] == "c_varchar"
assert columns[22]["type"] == "STRING"


async def test_duckdb_metadata_list_tables(client):
    """The metadata endpoint lists tables from attached DuckDB files."""
    payload = {
        "connectionInfo": {
            "url": "tests/resource/test_file_source",
            "format": "duckdb",
        },
    }
    response = await client.post(url=f"{base_url}/metadata/tables", json=payload)
    assert response.status_code == 200

    tables = response.json()
    customers = next(t for t in tables if t["name"] == "main.customers")
    assert customers["name"] == "main.customers"
    assert customers["primaryKey"] is not None
    assert customers["description"] is None

    expected_properties = {
        "catalog": "jaffle_shop",
        "schema": "main",
        "table": "customers",
        "path": None,
    }
    assert customers["properties"] == expected_properties

    assert len(customers["columns"]) == 7
    expected_column = {
        "name": "number_of_orders",
        "nestedColumns": None,
        "type": "INT64",
        "notNull": False,
        "description": None,
        "properties": None,
    }
    assert customers["columns"][1] == expected_column
43 changes: 43 additions & 0 deletions ibis-server/tests/routers/v3/connector/local_file/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,46 @@ async def test_dry_run(client, manifest_str):
)
assert response.status_code == 422
assert response.text is not None


async def test_query_duckdb_format(client):
    """Querying a model backed by a DuckDB file returns the expected shape."""
    column_types = [
        ("customer_id", "integer"),
        ("customer_lifetime_value", "double"),
        ("first_name", "varchar"),
        ("first_order", "date"),
        ("last_name", "varchar"),
        ("most_recent_order", "date"),
        ("number_of_orders", "integer"),
    ]
    manifest = {
        "catalog": "wren",
        "schema": "public",
        "models": [
            {
                "name": "customers",
                "tableReference": {
                    "catalog": "jaffle_shop",
                    "schema": "main",
                    "table": "customers",
                },
                "columns": [
                    {"name": name, "type": col_type}
                    for name, col_type in column_types
                ],
            },
        ],
        "relationships": [],
        "views": [],
    }
    encoded_manifest = base64.b64encode(orjson.dumps(manifest)).decode("utf-8")
    response = await client.post(
        f"{base_url}/query",
        json={
            "manifestStr": encoded_manifest,
            "sql": "SELECT * FROM customers LIMIT 1",
            "connectionInfo": {
                "url": "tests/resource/test_file_source",
                "format": "duckdb",
            },
        },
    )
    assert response.status_code == 200
    result = response.json()
    assert len(result["columns"]) == len(manifest["models"][0]["columns"])
    assert len(result["data"]) == 1