From 76c21ce20b9f282a84c5b128dabb3b1c14acf29d Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 16 Jan 2025 17:11:39 +0900 Subject: [PATCH 1/5] support s3 file connector --- ibis-server/app/model/__init__.py | 16 +++++++ ibis-server/app/model/connector.py | 14 +++++- ibis-server/app/model/data_source.py | 3 ++ ibis-server/app/model/metadata/factory.py | 3 +- .../app/model/metadata/object_storage.py | 44 ++++++++++++++++++- ibis-server/app/model/utils.py | 19 ++++++++ 6 files changed, 94 insertions(+), 5 deletions(-) create mode 100644 ibis-server/app/model/utils.py diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index 933d2908c..24e93282a 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -56,6 +56,10 @@ class QueryLocalFileDTO(QueryDTO): connection_info: LocalFileConnectionInfo = connection_info_field +class QueryS3FileDTO(QueryDTO): + connection_info: S3FileConnectionInfo = connection_info_field + + class BigQueryConnectionInfo(BaseModel): project_id: SecretStr dataset_id: SecretStr @@ -147,6 +151,17 @@ class LocalFileConnectionInfo(BaseModel): ) +class S3FileConnectionInfo(BaseModel): + url: SecretStr = Field(description="the root path of the s3 bucket", default="/") + format: str = Field( + description="File format", default="csv", examples=["csv", "parquet", "json"] + ) + bucket: SecretStr + region: SecretStr + access_key: SecretStr + secret_key: SecretStr + + ConnectionInfo = ( BigQueryConnectionInfo | CannerConnectionInfo @@ -157,6 +172,7 @@ class LocalFileConnectionInfo(BaseModel): | SnowflakeConnectionInfo | TrinoConnectionInfo | LocalFileConnectionInfo + | S3FileConnectionInfo ) diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index 307c40f09..fcb460410 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -15,8 +15,14 @@ from ibis import BaseBackend from ibis.backends.sql.compilers.postgres import compiler as postgres_compiler -from app.model import ConnectionInfo, UnknownIbisError, UnprocessableEntityError +from app.model import ( + ConnectionInfo, + S3FileConnectionInfo, + UnknownIbisError, + UnprocessableEntityError, +) from app.model.data_source import DataSource +from app.model.utils import init_duckdb_s3 # Override datatypes of ibis importlib.import_module("app.custom_ibis.backends.sql.datatypes") @@ -32,6 +38,8 @@ def __init__(self, data_source: DataSource, connection_info: ConnectionInfo): self._connector = BigQueryConnector(connection_info) elif data_source == DataSource.local_file: self._connector = DuckDBConnector(connection_info) + elif data_source == DataSource.s3_file: + self._connector = DuckDBConnector(connection_info) else: self._connector = SimpleConnector(data_source, connection_info) @@ -147,10 +155,12 @@ def query(self, sql: str, limit: int) -> pd.DataFrame: class DuckDBConnector: - def __init__(self, _connection_info: ConnectionInfo): + def __init__(self, connection_info: ConnectionInfo): import duckdb self.connection = duckdb.connect() + if isinstance(connection_info, S3FileConnectionInfo): + init_duckdb_s3(connection_info) def query(self, sql: str, limit: int) -> pd.DataFrame: return self.connection.execute(sql).fetch_df().head(limit) diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index 60b757251..d1aa3b41e 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -26,6 +26,7 @@ QueryMSSqlDTO, QueryMySqlDTO, QueryPostgresDTO, + QueryS3FileDTO, QuerySnowflakeDTO, QueryTrinoDTO, SnowflakeConnectionInfo, @@ -44,6 +45,7 @@ class DataSource(StrEnum): snowflake = auto() trino = auto() local_file = auto() + s3_file = auto() def get_connection(self, info: ConnectionInfo) -> BaseBackend: try: @@ -68,6 +70,7 @@ class DataSourceExtension(Enum): snowflake = QuerySnowflakeDTO trino = QueryTrinoDTO local_file = QueryLocalFileDTO + s3_file = QueryS3FileDTO def __init__(self, dto: QueryDTO): self.dto = dto diff --git a/ibis-server/app/model/metadata/factory.py b/ibis-server/app/model/metadata/factory.py index ad6dcb50f..89f4b5fdd 100644 --- a/ibis-server/app/model/metadata/factory.py +++ b/ibis-server/app/model/metadata/factory.py @@ -5,7 +5,7 @@ from app.model.metadata.metadata import Metadata from app.model.metadata.mssql import MSSQLMetadata from app.model.metadata.mysql import MySQLMetadata -from app.model.metadata.object_storage import LocalFileMetadata +from app.model.metadata.object_storage import LocalFileMetadata, S3FileMetadata from app.model.metadata.postgres import PostgresMetadata from app.model.metadata.snowflake import SnowflakeMetadata from app.model.metadata.trino import TrinoMetadata @@ -20,6 +20,7 @@ DataSource.trino: TrinoMetadata, DataSource.snowflake: SnowflakeMetadata, DataSource.local_file: LocalFileMetadata, + DataSource.s3_file: S3FileMetadata, } diff --git a/ibis-server/app/model/metadata/object_storage.py b/ibis-server/app/model/metadata/object_storage.py index 744523ee4..427ddfcf0 100644 --- a/ibis-server/app/model/metadata/object_storage.py +++ b/ibis-server/app/model/metadata/object_storage.py @@ -4,7 +4,7 @@ import opendal from loguru import logger -from app.model import LocalFileConnectionInfo +from app.model import LocalFileConnectionInfo, S3FileConnectionInfo from app.model.metadata.dto import ( Column, RustWrenEngineColumnType, @@ -12,6 +12,7 @@ TableProperties, ) from app.model.metadata.metadata import Metadata +from app.model.utils import init_duckdb_s3 class ObjectStorageMetadata(Metadata): @@ -19,7 +20,7 @@ def __init__(self, connection_info): super().__init__(connection_info) def get_table_list(self) -> list[Table]: - op = opendal.Operator("fs", root=self.connection_info.url.get_secret_value()) + op = self._get_dal_operator() conn = self._get_connection() unique_tables = {} for file in op.list("/"): @@ -36,6 +37,8 @@ def get_table_list(self) -> list[Table]: f"{self.connection_info.url.get_secret_value()}/{file.path}" ) + # add required prefix for object storage + full_path = self._get_full_path(full_path) # read the file with the target format if unreadable, skip the file df = self._read_df(conn, full_path) if df is None: @@ -147,6 +150,12 @@ def _to_column_type(self, col_type: str) -> RustWrenEngineColumnType: def _get_connection(self): return duckdb.connect() + def _get_dal_operator(self): + return opendal.Operator("fs", root=self.connection_info.url.get_secret_value()) + + def _get_full_path(self, path): + return path + class LocalFileMetadata(ObjectStorageMetadata): def __init__(self, connection_info: LocalFileConnectionInfo): @@ -154,3 +163,34 @@ def __init__(self, connection_info: LocalFileConnectionInfo): def get_version(self): return "Local File System" + + +class S3FileMetadata(ObjectStorageMetadata): + def __init__(self, connection_info): + super().__init__(connection_info) + + def get_version(self): + return "S3" + + def _get_connection(self): + conn = duckdb.connect() + init_duckdb_s3(conn, self.connection_info) + logger.debug("Initialized duckdb s3") + return conn + + def _get_dal_operator(self): + info: S3FileConnectionInfo = self.connection_info + return opendal.Operator( + "s3", + root=info.url.get_secret_value(), + bucket=info.bucket.get_secret_value(), + region=info.region.get_secret_value(), + secret_access_key=info.secret_key.get_secret_value(), + access_key_id=info.access_key.get_secret_value(), + ) + + def _get_full_path(self, path): + if path.startswith("/"): + path = path[1:] + + return f"s3://{self.connection_info.bucket.get_secret_value()}/{path}" diff --git a/ibis-server/app/model/utils.py b/ibis-server/app/model/utils.py new file mode 100644 index 000000000..c2414e593 --- /dev/null +++ b/ibis-server/app/model/utils.py @@ -0,0 +1,19 @@ +from duckdb import DuckDBPyConnection + +from app.model import S3FileConnectionInfo + + +def init_duckdb_s3( + connection: DuckDBPyConnection, connection_info: S3FileConnectionInfo +): + create_secret = f""" + CREATE SECRET wren_s3 ( + TYPE S3, + KEY_ID '{connection_info.access_key.get_secret_value()}', + SECRET '{connection_info.secret_key.get_secret_value()}', + REGION '{connection_info.region.get_secret_value()}' + ) + """ + result = connection.execute(create_secret).fetchone() + if result is None or not result[0]: + raise Exception("Failed to create secret") From 6d9c31bc9ab00d9279bb562153a6e40bc11da3bb Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 16 Jan 2025 18:58:24 +0900 Subject: [PATCH 2/5] add v2 test for s3 file --- .github/workflows/ibis-ci.yml | 4 + ibis-server/app/mdl/rewriter.py | 2 +- ibis-server/app/model/connector.py | 2 +- .../app/model/metadata/object_storage.py | 2 +- ibis-server/pyproject.toml | 1 + .../routers/v2/connector/test_s3_file.py | 473 ++++++++++++++++++ 6 files changed, 481 insertions(+), 3 deletions(-) create mode 100644 ibis-server/tests/routers/v2/connector/test_s3_file.py diff --git a/.github/workflows/ibis-ci.yml b/.github/workflows/ibis-ci.yml index 8c7e35ec8..2f14e4003 100644 --- a/.github/workflows/ibis-ci.yml +++ b/.github/workflows/ibis-ci.yml @@ -67,6 +67,10 @@ jobs: - name: Run tests env: WREN_ENGINE_ENDPOINT: http://localhost:8080 + AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }} + AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} + AWS_REGION: ${{ secrets.AWS_REGION }} + AWS_S3_BUCKET: ${{ secrets.AWS_S3_BUCKET }} run: poetry run pytest -m "not bigquery and not snowflake and not canner" - name: Test bigquery if need if: contains(github.event.pull_request.labels.*.name, 'bigquery') diff --git a/ibis-server/app/mdl/rewriter.py b/ibis-server/app/mdl/rewriter.py index e3ebfd0e6..361da3ac1 100644 --- a/ibis-server/app/mdl/rewriter.py +++ b/ibis-server/app/mdl/rewriter.py @@ -72,7 +72,7 @@ def _get_read_dialect(cls, experiment) -> str | None: def _get_write_dialect(cls, data_source: DataSource) -> str: if data_source == DataSource.canner: return "trino" - elif data_source == DataSource.local_file: + elif data_source in {DataSource.local_file, DataSource.s3_file}: return "duckdb" return data_source.name diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index fcb460410..0df662905 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -160,7 +160,7 @@ def __init__(self, connection_info: ConnectionInfo): self.connection = duckdb.connect() if isinstance(connection_info, S3FileConnectionInfo): - init_duckdb_s3(connection_info) + init_duckdb_s3(self.connection, connection_info) def query(self, sql: str, limit: int) -> pd.DataFrame: return self.connection.execute(sql).fetch_df().head(limit) diff --git a/ibis-server/app/model/metadata/object_storage.py b/ibis-server/app/model/metadata/object_storage.py index 427ddfcf0..1b383dcbe 100644 --- a/ibis-server/app/model/metadata/object_storage.py +++ b/ibis-server/app/model/metadata/object_storage.py @@ -166,7 +166,7 @@ def get_version(self): class S3FileMetadata(ObjectStorageMetadata): - def __init__(self, connection_info): + def __init__(self, connection_info: S3FileConnectionInfo): super().__init__(connection_info) def get_version(self): diff --git a/ibis-server/pyproject.toml b/ibis-server/pyproject.toml index a75fec302..f0132ce5d 100644 --- a/ibis-server/pyproject.toml +++ b/ibis-server/pyproject.toml @@ -64,6 +64,7 @@ markers = [ "snowflake: mark a test as a snowflake test", "trino: mark a test as a trino test", "local_file: mark a test as a local file test", + "s3_file: mark a test as a s3 file test", "beta: mark a test as a test for beta versions of the engine", ] diff --git a/ibis-server/tests/routers/v2/connector/test_s3_file.py b/ibis-server/tests/routers/v2/connector/test_s3_file.py new file mode 100644 index 000000000..5ed290cb2 --- /dev/null +++ b/ibis-server/tests/routers/v2/connector/test_s3_file.py @@ -0,0 +1,473 @@ +import base64 +import os + +import orjson +import pytest + +pytestmark = pytest.mark.s3_file + +access_key = os.getenv("AWS_ACCESS_KEY_ID") +secret_key = os.getenv("AWS_SECRET_ACCESS_KEY") +region = os.getenv("AWS_REGION") +bucket = os.getenv("AWS_S3_BUCKET") + +base_url = "/v2/connector/s3_file" +manifest = { + "catalog": "my_calalog", + "schema": "my_schema", + "models": [ + { + "name": "Orders", + "tableReference": { + "table": f"s3://{bucket}/tpch/data/orders.parquet", + }, + "columns": [ + {"name": "orderkey", "expression": "o_orderkey", "type": "integer"}, + {"name": "custkey", "expression": "o_custkey", "type": "integer"}, + { + "name": "orderstatus", + "expression": "o_orderstatus", + "type": "varchar", + }, + { + "name": "totalprice", + "expression": "o_totalprice", + "type": "float", + }, + {"name": "orderdate", "expression": "o_orderdate", "type": "date"}, + { + "name": "order_cust_key", + "expression": "concat(o_orderkey, '_', o_custkey)", + "type": "varchar", + }, + ], + "primaryKey": "orderkey", + }, + { + "name": "Customer", + "tableReference": { + "table": f"s3://{bucket}/tpch/data/customer.parquet", + }, + "columns": [ + { + "name": "custkey", + "type": "integer", + "expression": "c_custkey", + }, + { + "name": "orders", + "type": "Orders", + "relationship": "CustomerOrders", + }, + { + "name": "sum_totalprice", + "type": "float", + "isCalculated": True, + "expression": "sum(orders.totalprice)", + }, + ], + "primaryKey": "custkey", + }, + ], + "relationships": [ + { + "name": "CustomerOrders", + "models": ["Customer", "Orders"], + "joinType": "ONE_TO_MANY", + "condition": "Customer.custkey = Orders.custkey", + } + ], +} + + +@pytest.fixture(scope="module") +def manifest_str(): + return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + +@pytest.fixture(scope="module") +def connection_info() -> dict[str, str]: + return { + "url": "/tpch/data", + "format": "parquet", + "bucket": bucket, + "region": region, + "access_key": access_key, + "secret_key": secret_key, + } + + +async def test_query(client, manifest_str, connection_info): + response = await client.post( + f"{base_url}/query", + json={ + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 1, + 370, + "O", + "172799.49", + "1996-01-02 00:00:00.000000", + "1_370", + ] + assert result["dtypes"] == { + "orderkey": "int32", + "custkey": "int32", + "orderstatus": "object", + "totalprice": "float64", + "orderdate": "object", + "order_cust_key": "object", + } + + +async def test_query_with_limit(client, manifest_str, connection_info): + response = await client.post( + f"{base_url}/query", + params={"limit": 1}, + json={ + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" limit 2', + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + + +async def test_query_calculated_field(client, manifest_str, connection_info): + response = await client.post( + f"{base_url}/query", + json={ + "manifestStr": manifest_str, + "sql": 'SELECT custkey, sum_totalprice FROM "Customer" WHERE custkey = 370', + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == 2 + assert len(result["data"]) == 1 + assert result["data"][0] == [ + 370, + "2860895.79", + ] + assert result["dtypes"] == { + "custkey": "int32", + "sum_totalprice": "float64", + } + + +async def test_dry_run(client, manifest_str, connection_info): + response = await client.post( + f"{base_url}/query", + params={"dryRun": True}, + json={ + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 204 + + response = await client.post( + f"{base_url}/query", + params={"dryRun": True}, + json={ + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "NotFound" LIMIT 1', + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 422 + assert response.text is not None + + +async def test_metadata_list_tables(client, connection_info): + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 200 + + result = next(filter(lambda x: x["name"] == "orders", response.json())) + assert result["name"] == "orders" + assert result["primaryKey"] is None + assert result["description"] is None + assert result["properties"] == { + "catalog": None, + "schema": None, + "table": "orders", + "path": f"s3://{bucket}/tpch/data/orders.parquet", + } + assert len(result["columns"]) == 9 + assert result["columns"][8] == { + "name": "o_comment", + "nestedColumns": None, + "type": "STRING", + "notNull": False, + "description": None, + "properties": None, + } + + +async def test_metadata_list_constraints(client, connection_info): + response = await client.post( + url=f"{base_url}/metadata/constraints", + json={ + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 200 + + +async def test_metadata_db_version(client, connection_info): + response = await client.post( + url=f"{base_url}/metadata/version", + json={ + "connectionInfo": connection_info, + }, + ) + assert response.status_code == 200 + assert "S3" in response.text + + +async def test_unsupported_format(client): + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": { + "url": "/tpch/data", + "format": "unsupported", + "bucket": bucket, + "region": region, + "access_key": access_key, + "secret_key": secret_key, + }, + }, + ) + assert response.status_code == 501 + assert response.text == "Unsupported format: unsupported" + + +async def test_list_parquet_files(client): + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": { + "url": "/test_file_source", + "format": "parquet", + "bucket": bucket, + "region": region, + "access_key": access_key, + "secret_key": secret_key, + }, + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result) == 2 + table_names = [table["name"] for table in result] + assert "type-test-parquet" in table_names + assert "type-test" in table_names + columns = result[0]["columns"] + assert len(columns) == 23 + assert columns[0]["name"] == "c_bigint" + assert columns[0]["type"] == "INT64" + assert columns[1]["name"] == "c_bit" + assert columns[1]["type"] == "STRING" + assert columns[2]["name"] == "c_blob" + assert columns[2]["type"] == "BYTES" + assert columns[3]["name"] == "c_boolean" + assert columns[3]["type"] == "BOOL" + assert columns[4]["name"] == "c_date" + assert columns[4]["type"] == "DATE" + assert columns[5]["name"] == "c_double" + assert columns[5]["type"] == "DOUBLE" + assert columns[6]["name"] == "c_float" + assert columns[6]["type"] == "FLOAT" + assert columns[7]["name"] == "c_integer" + assert columns[7]["type"] == "INT" + assert columns[8]["name"] == "c_hugeint" + assert columns[8]["type"] == "DOUBLE" + assert columns[9]["name"] == "c_interval" + assert columns[9]["type"] == "INTERVAL" + assert columns[10]["name"] == "c_json" + assert columns[10]["type"] == "JSON" + assert columns[11]["name"] == "c_smallint" + assert columns[11]["type"] == "INT2" + assert columns[12]["name"] == "c_time" + assert columns[12]["type"] == "TIME" + assert columns[13]["name"] == "c_timestamp" + assert columns[13]["type"] == "TIMESTAMP" + assert columns[14]["name"] == "c_timestamptz" + assert columns[14]["type"] == "TIMESTAMPTZ" + assert columns[15]["name"] == "c_tinyint" + assert columns[15]["type"] == "INT2" + assert columns[16]["name"] == "c_ubigint" + assert columns[16]["type"] == "INT64" + assert columns[17]["name"] == "c_uhugeint" + assert columns[17]["type"] == "DOUBLE" + assert columns[18]["name"] == "c_uinteger" + assert columns[18]["type"] == "INT" + assert columns[19]["name"] == "c_usmallint" + assert columns[19]["type"] == "INT2" + assert columns[20]["name"] == "c_utinyint" + assert columns[20]["type"] == "INT2" + assert columns[21]["name"] == "c_uuid" + assert columns[21]["type"] == "UUID" + assert columns[22]["name"] == "c_varchar" + assert columns[22]["type"] == "STRING" + + +async def test_list_csv_files(client): + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": { + "url": "/test_file_source", + "format": "csv", + "bucket": bucket, + "region": region, + "access_key": access_key, + "secret_key": secret_key, + }, + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result) == 3 + table_names = [table["name"] for table in result] + assert "type-test-csv" in table_names + assert "type-test" in table_names + # `invalid` will be considered as a one column csv file + assert "invalid" in table_names + columns = result[0]["columns"] + assert columns[0]["name"] == "c_bigint" + assert columns[0]["type"] == "INT64" + assert columns[1]["name"] == "c_bit" + assert columns[1]["type"] == "STRING" + assert columns[2]["name"] == "c_blob" + assert columns[2]["type"] == "STRING" + assert columns[3]["name"] == "c_boolean" + assert columns[3]["type"] == "BOOL" + assert columns[4]["name"] == "c_date" + assert columns[4]["type"] == "DATE" + assert columns[5]["name"] == "c_double" + assert columns[5]["type"] == "DOUBLE" + assert columns[6]["name"] == "c_float" + assert columns[6]["type"] == "DOUBLE" + assert columns[7]["name"] == "c_integer" + assert columns[7]["type"] == "INT64" + assert columns[8]["name"] == "c_hugeint" + assert columns[8]["type"] == "INT64" + assert columns[9]["name"] == "c_interval" + assert columns[9]["type"] == "STRING" + assert columns[10]["name"] == "c_json" + assert columns[10]["type"] == "STRING" + assert columns[11]["name"] == "c_smallint" + assert columns[11]["type"] == "INT64" + assert columns[12]["name"] == "c_time" + assert columns[12]["type"] == "TIME" + assert columns[13]["name"] == "c_timestamp" + assert columns[13]["type"] == "TIMESTAMP" + assert columns[14]["name"] == "c_timestamptz" + assert columns[14]["type"] == "TIMESTAMP" + assert columns[15]["name"] == "c_tinyint" + assert columns[15]["type"] == "INT64" + assert columns[16]["name"] == "c_ubigint" + assert columns[16]["type"] == "INT64" + assert columns[17]["name"] == "c_uhugeint" + assert columns[17]["type"] == "INT64" + assert columns[18]["name"] == "c_uinteger" + assert columns[18]["type"] == "INT64" + assert columns[19]["name"] == "c_usmallint" + assert columns[19]["type"] == "INT64" + assert columns[20]["name"] == "c_utinyint" + assert columns[20]["type"] == "INT64" + assert columns[21]["name"] == "c_uuid" + assert columns[21]["type"] == "STRING" + assert columns[22]["name"] == "c_varchar" + assert columns[22]["type"] == "STRING" + + +async def test_list_json_files(client): + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": { + "url": "/test_file_source", + "format": "json", + "bucket": bucket, + "region": region, + "access_key": access_key, + "secret_key": secret_key, + }, + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result) == 2 + table_names = [table["name"] for table in result] + assert "type-test-json" in table_names + assert "type-test" in table_names + + columns = result[0]["columns"] + assert columns[0]["name"] == "c_bigint" + assert columns[0]["type"] == "INT64" + # `c_bit` is a string in json which value is `00000000000000000000000000000001` + # It's considered as a UUID by DuckDB json reader. + assert columns[1]["name"] == "c_bit" + assert columns[1]["type"] == "UUID" + assert columns[2]["name"] == "c_blob" + assert columns[2]["type"] == "STRING" + assert columns[3]["name"] == "c_boolean" + assert columns[3]["type"] == "BOOL" + assert columns[4]["name"] == "c_date" + assert columns[4]["type"] == "DATE" + assert columns[5]["name"] == "c_double" + assert columns[5]["type"] == "DOUBLE" + assert columns[6]["name"] == "c_float" + assert columns[6]["type"] == "DOUBLE" + assert columns[7]["name"] == "c_integer" + assert columns[7]["type"] == "INT64" + assert columns[8]["name"] == "c_hugeint" + assert columns[8]["type"] == "DOUBLE" + assert columns[9]["name"] == "c_interval" + assert columns[9]["type"] == "STRING" + assert columns[10]["name"] == "c_json" + assert columns[10]["type"] == "UNKNOWN" + assert columns[11]["name"] == "c_smallint" + assert columns[11]["type"] == "INT64" + assert columns[12]["name"] == "c_time" + assert columns[12]["type"] == "TIME" + assert columns[13]["name"] == "c_timestamp" + assert columns[13]["type"] == "TIMESTAMP" + assert columns[14]["name"] == "c_timestamptz" + assert columns[14]["type"] == "STRING" + assert columns[15]["name"] == "c_tinyint" + assert columns[15]["type"] == "INT64" + assert columns[16]["name"] == "c_ubigint" + assert columns[16]["type"] == "INT64" + assert columns[17]["name"] == "c_uhugeint" + assert columns[17]["type"] == "DOUBLE" + assert columns[18]["name"] == "c_uinteger" + assert columns[18]["type"] == "INT64" + assert columns[19]["name"] == "c_usmallint" + assert columns[19]["type"] == "INT64" + assert columns[20]["name"] == "c_utinyint" + assert columns[20]["type"] == "INT64" + assert columns[21]["name"] == "c_uuid" + assert columns[21]["type"] == "UUID" + assert columns[22]["name"] == "c_varchar" + assert columns[22]["type"] == "STRING" From ea34d15eabdefa21940928c4db6fbb4ec0b74c48 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Thu, 16 Jan 2025 19:41:51 +0900 Subject: [PATCH 3/5] add negative tests --- ibis-server/app/model/connector.py | 11 +- .../app/model/metadata/object_storage.py | 101 ++++++++++-------- ibis-server/app/model/utils.py | 11 +- .../routers/v2/connector/test_s3_file.py | 34 ++++++ 4 files changed, 104 insertions(+), 53 deletions(-) diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index 0df662905..1d7e55861 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -10,6 +10,7 @@ import ibis.formats import pandas as pd import sqlglot.expressions as sge +from duckdb import HTTPException from google.cloud import bigquery from google.oauth2 import service_account from ibis import BaseBackend @@ -163,10 +164,16 @@ def __init__(self, connection_info: ConnectionInfo): init_duckdb_s3(self.connection, connection_info) def query(self, sql: str, limit: int) -> pd.DataFrame: - return self.connection.execute(sql).fetch_df().head(limit) + try: + return self.connection.execute(sql).fetch_df().head(limit) + except HTTPException as e: + raise UnprocessableEntityError(f"Failed to execute query: {e!s}") def dry_run(self, sql: str) -> None: - self.connection.execute(sql) + try: + self.connection.execute(sql) + except HTTPException as e: + raise QueryDryRunError(f"Failed to execute query: {e!s}") @cache diff --git a/ibis-server/app/model/metadata/object_storage.py b/ibis-server/app/model/metadata/object_storage.py index 1b383dcbe..6b17cee91 100644 --- a/ibis-server/app/model/metadata/object_storage.py +++ b/ibis-server/app/model/metadata/object_storage.py @@ -4,7 +4,11 @@ import opendal from loguru import logger -from app.model import LocalFileConnectionInfo, S3FileConnectionInfo +from app.model import ( + LocalFileConnectionInfo, + S3FileConnectionInfo, + UnprocessableEntityError, +) from app.model.metadata.dto import ( Column, RustWrenEngineColumnType, @@ -23,54 +27,57 @@ def get_table_list(self) -> list[Table]: op = self._get_dal_operator() conn = self._get_connection() unique_tables = {} - for file in op.list("/"): - if file.path != "/": - stat = op.stat(file.path) - if stat.mode.is_dir(): - # if the file is a directory, use the directory name as the table name - table_name = os.path.basename(os.path.normpath(file.path)) - full_path = f"{self.connection_info.url.get_secret_value()}/{table_name}/*.{self.connection_info.format}" - else: - # if the file is a file, use the file name as the table name - table_name = os.path.splitext(os.path.basename(file.path))[0] - full_path = ( - f"{self.connection_info.url.get_secret_value()}/{file.path}" - ) + try: + for file in op.list("/"): + if file.path != "/": + stat = op.stat(file.path) + if stat.mode.is_dir(): + # if the file is a directory, use the directory name as the table name + table_name = os.path.basename(os.path.normpath(file.path)) + full_path = f"{self.connection_info.url.get_secret_value()}/{table_name}/*.{self.connection_info.format}" + else: + # if the file is a file, use the file name as the table name + table_name = os.path.splitext(os.path.basename(file.path))[0] + full_path = ( + f"{self.connection_info.url.get_secret_value()}/{file.path}" + ) - # add required prefix for object storage - full_path = self._get_full_path(full_path) - # read the file with the target format if unreadable, skip the file - df = self._read_df(conn, full_path) - if df is None: - continue - columns = [] - try: - for col in df.columns: - duckdb_type = df[col].dtypes[0] - columns.append( - Column( - name=col, - type=self._to_column_type(duckdb_type.__str__()), - notNull=False, + # add required prefix for object storage + full_path = self._get_full_path(full_path) + # read the file with the target format if unreadable, skip the file + df = self._read_df(conn, full_path) + if df is None: + continue + columns = [] + try: + for col in df.columns: + duckdb_type = df[col].dtypes[0] + columns.append( + Column( + name=col, + type=self._to_column_type(duckdb_type.__str__()), + notNull=False, + ) ) - ) - except Exception as e: - logger.debug(f"Failed to read column types: {e}") - continue - - unique_tables[table_name] = Table( - name=table_name, - description=None, - columns=[], - properties=TableProperties( - table=table_name, - schema=None, - catalog=None, - path=full_path, - ), - primaryKey=None, - ) - unique_tables[table_name].columns = columns + except Exception as e: + logger.debug(f"Failed to read column types: {e}") + continue + + unique_tables[table_name] = Table( + name=table_name, + description=None, + columns=[], + properties=TableProperties( + table=table_name, + schema=None, + catalog=None, + path=full_path, + ), + primaryKey=None, + ) + unique_tables[table_name].columns = columns + except Exception as e: + raise UnprocessableEntityError(f"Failed to list files: {e!s}") return list(unique_tables.values()) diff --git a/ibis-server/app/model/utils.py b/ibis-server/app/model/utils.py index c2414e593..5dfe36f64 100644 --- a/ibis-server/app/model/utils.py +++ b/ibis-server/app/model/utils.py @@ -1,4 +1,4 @@ -from duckdb import DuckDBPyConnection +from duckdb import DuckDBPyConnection, HTTPException from app.model import S3FileConnectionInfo @@ -14,6 +14,9 @@ def init_duckdb_s3( REGION '{connection_info.region.get_secret_value()}' ) """ - result = connection.execute(create_secret).fetchone() - if result is None or not result[0]: - raise Exception("Failed to create secret") + try: + result = connection.execute(create_secret).fetchone() + if result is None or not result[0]: + raise Exception("Failed to create secret") + except HTTPException as e: + raise Exception("Failed to create secret", e) diff --git a/ibis-server/tests/routers/v2/connector/test_s3_file.py b/ibis-server/tests/routers/v2/connector/test_s3_file.py index 5ed290cb2..3ea8898ff 100644 --- a/ibis-server/tests/routers/v2/connector/test_s3_file.py +++ b/ibis-server/tests/routers/v2/connector/test_s3_file.py @@ -191,6 +191,40 @@ async def test_dry_run(client, manifest_str, connection_info): assert response.text is not None +async def test_query_with_invalid_connection_info(client, manifest_str): + response = await client.post( + f"{base_url}/query", + json={ + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + "connectionInfo": { + "url": "/tpch/data", + "format": "parquet", + "bucket": bucket, + "region": region, + "access_key": "invalid", + "secret_key": "invalid", + }, + }, + ) + assert response.status_code == 422 + + response = await client.post( + url=f"{base_url}/metadata/tables", + json={ + "connectionInfo": { + "url": "/tpch/data", + "format": "parquet", + "bucket": bucket, + "region": region, + "access_key": "invalid", + "secret_key": "invalid", + }, + }, + ) + assert response.status_code == 422 + + async def test_metadata_list_tables(client, connection_info): response = await client.post( url=f"{base_url}/metadata/tables", From bc167c5070656ba1c4bb3801880f27fcb7c04726 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 17 Jan 2025 15:00:25 +0900 Subject: [PATCH 4/5] fix negative tests --- ibis-server/tests/routers/v2/connector/test_local_file.py | 4 ++-- ibis-server/tests/routers/v2/connector/test_s3_file.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ibis-server/tests/routers/v2/connector/test_local_file.py b/ibis-server/tests/routers/v2/connector/test_local_file.py index a4c394d70..e882307e2 100644 --- a/ibis-server/tests/routers/v2/connector/test_local_file.py +++ b/ibis-server/tests/routers/v2/connector/test_local_file.py @@ -243,8 +243,8 @@ async def test_unsupported_format(client): }, }, ) - assert response.status_code == 501 - assert response.text == "Unsupported format: unsupported" + assert response.status_code == 422 + assert response.text == "Failed to list files: Unsupported format: unsupported" async def test_list_parquet_files(client): diff --git a/ibis-server/tests/routers/v2/connector/test_s3_file.py b/ibis-server/tests/routers/v2/connector/test_s3_file.py index 3ea8898ff..8e855fc9b 100644 --- a/ibis-server/tests/routers/v2/connector/test_s3_file.py +++ b/ibis-server/tests/routers/v2/connector/test_s3_file.py @@ -290,8 +290,8 @@ async def test_unsupported_format(client): }, }, ) - assert response.status_code == 501 - assert response.text == "Unsupported format: unsupported" + assert response.status_code == 422 + assert response.text == "Failed to list files: Unsupported format: unsupported" async def test_list_parquet_files(client): From 26d9d3f1f402103ef714a1345213cc145689b8a1 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Fri, 17 Jan 2025 16:07:27 +0900 Subject: [PATCH 5/5] disable s3_file test in ci --- .github/workflows/ibis-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ibis-ci.yml b/.github/workflows/ibis-ci.yml index 2f14e4003..829c8577b 100644 --- a/.github/workflows/ibis-ci.yml +++ b/.github/workflows/ibis-ci.yml @@ -71,7 +71,7 @@ jobs: AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} AWS_REGION: ${{ secrets.AWS_REGION }} AWS_S3_BUCKET: ${{ secrets.AWS_S3_BUCKET }} - run: poetry run pytest -m "not bigquery and not snowflake and not canner" + run: poetry run pytest -m "not bigquery and not snowflake and not canner and not s3_file" - name: Test bigquery if need if: contains(github.event.pull_request.labels.*.name, 'bigquery') env: