From a3f1a6ef3db354ba207c96f01331ecefb7f4a673 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Mon, 2 Jun 2025 10:10:35 +0800 Subject: [PATCH 1/9] athena support --- .github/workflows/ibis-ci.yml | 8 + ibis-server/app/model/__init__.py | 28 +- ibis-server/app/model/data_source.py | 14 + ibis-server/app/model/metadata/athena.py | 139 +++++++ ibis-server/app/model/metadata/factory.py | 2 + ibis-server/poetry.lock | 120 +++++- ibis-server/pyproject.toml | 2 + .../tests/routers/v2/connector/test_athena.py | 351 ++++++++++++++++++ 8 files changed, 655 insertions(+), 9 deletions(-) create mode 100644 ibis-server/app/model/metadata/athena.py create mode 100644 ibis-server/tests/routers/v2/connector/test_athena.py diff --git a/.github/workflows/ibis-ci.yml b/.github/workflows/ibis-ci.yml index 53266b89b..34b8fe307 100644 --- a/.github/workflows/ibis-ci.yml +++ b/.github/workflows/ibis-ci.yml @@ -90,3 +90,11 @@ jobs: SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }} SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }} run: just test snowflake + - name: Test athena if needed + if: contains(github.event.pull_request.labels.*.name, 'athena') + env: + WREN_ENGINE_ENDPOINT: http://localhost:8080 + TEST_ATHENA_S3_STAGING_DIR: ${{ secrets.TEST_ATHENA_S3_STAGING_DIR }} + TEST_ATHENA_AWS_ACCESS_KEY_ID: ${{ secrets.TEST_ATHENA_AWS_ACCESS_KEY_ID }} + TEST_ATHENA_AWS_SECRET_ACCESS_KEY: ${{ secrets.TEST_ATHENA_AWS_SECRET_ACCESS_KEY }} + run: just test athena diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index af4b46964..4f714b7a8 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -36,6 +36,10 @@ class QueryBigQueryDTO(QueryDTO): connection_info: BigQueryConnectionInfo = connection_info_field +class QueryAthenaDTO(QueryDTO): + connection_info: AthenaConnectionInfo = connection_info_field + + class QueryCannerDTO(QueryDTO): connection_info: ConnectionUrl | CannerConnectionInfo = connection_info_field @@ -98,6 +102,27 @@ class BigQueryConnectionInfo(BaseConnectionInfo): ) +class AthenaConnectionInfo(BaseConnectionInfo): + s3_staging_dir: SecretStr = Field( + description="S3 staging directory for Athena queries", + examples=["s3://my-bucket/athena-staging/"], + ) + aws_access_key_id: SecretStr = Field( + description="AWS access key ID", examples=["AKIA..."] + ) + aws_secret_access_key: SecretStr = Field( + description="AWS secret access key", examples=["my-secret-key"] + ) + region_name: SecretStr = Field( + description="AWS region for Athena", examples=["us-west-2", "us-east-1"] + ) + schema_name: SecretStr = Field( + alias="schema_name", + description="The database name in Athena", + examples=["default"], + ) + + class CannerConnectionInfo(BaseConnectionInfo): host: SecretStr = Field( description="the hostname of your database", examples=["localhost"] ) @@ -339,7 +364,8 @@ class GcsFileConnectionInfo(BaseConnectionInfo): ConnectionInfo = ( - BigQueryConnectionInfo + AthenaConnectionInfo + | BigQueryConnectionInfo | CannerConnectionInfo | ConnectionUrl | MSSqlConnectionInfo diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index 4328fec6b..dfc832440 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -11,6 +11,7 @@ from ibis import BaseBackend from app.model import ( + AthenaConnectionInfo, BigQueryConnectionInfo, CannerConnectionInfo, ClickHouseConnectionInfo, @@ -19,6 +20,7 @@ MySqlConnectionInfo, OracleConnectionInfo, PostgresConnectionInfo, +
QueryAthenaDTO, QueryBigQueryDTO, QueryCannerDTO, QueryClickHouseDTO, @@ -40,6 +42,7 @@ class DataSource(StrEnum): + athena = auto() bigquery = auto() canner = auto() clickhouse = auto() @@ -68,6 +71,7 @@ def get_dto_type(self): class DataSourceExtension(Enum): + athena = QueryAthenaDTO bigquery = QueryBigQueryDTO canner = QueryCannerDTO clickhouse = QueryClickHouseDTO @@ -97,6 +101,16 @@ def get_connection(self, info: ConnectionInfo) -> BaseBackend: except KeyError: raise NotImplementedError(f"Unsupported data source: {self}") + @staticmethod + def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend: + return ibis.athena.connect( + s3_staging_dir=info.s3_staging_dir.get_secret_value(), + aws_access_key_id=info.aws_access_key_id.get_secret_value(), + aws_secret_access_key=info.aws_secret_access_key.get_secret_value(), + region_name=info.region_name.get_secret_value(), + schema_name=info.schema_name.get_secret_value(), + ) + @staticmethod def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend: credits_json = loads( diff --git a/ibis-server/app/model/metadata/athena.py b/ibis-server/app/model/metadata/athena.py new file mode 100644 index 000000000..b342e831c --- /dev/null +++ b/ibis-server/app/model/metadata/athena.py @@ -0,0 +1,139 @@ +import re + +import pandas as pd + +from app.model import AthenaConnectionInfo +from app.model.data_source import DataSource +from app.model.metadata.dto import ( + Column, + Constraint, + RustWrenEngineColumnType, + Table, + TableProperties, +) +from app.model.metadata.metadata import Metadata + + +class AthenaMetadata(Metadata): + def __init__(self, connection_info: AthenaConnectionInfo): + super().__init__(connection_info) + self.connection = DataSource.athena.get_connection(connection_info) + + def get_table_list(self) -> list[Table]: + schema_name = self.connection_info.schema_name.get_secret_value() + + sql = f""" + SELECT + t.table_catalog, + t.table_schema, + t.table_name, + c.column_name, + c.ordinal_position, + c.is_nullable, + c.data_type + FROM + information_schema.tables AS t + JOIN + information_schema.columns AS c + ON t.table_catalog = c.table_catalog + AND t.table_schema = c.table_schema + AND t.table_name = c.table_name + WHERE t.table_schema = '{schema_name}' + ORDER BY t.table_name + """ + + # We need to use raw_sql here because using the sql method causes Ibis to *create view* first, + # which does not work with information_schema queries. 
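+ # raw_sql returns a plain DB-API cursor: the rows come from cursor.fetchall() and the column names from cursor.description below.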
+ cursor = self.connection.raw_sql(sql) + response = pd.DataFrame( + cursor.fetchall(), columns=[col[0] for col in cursor.description] + ).to_dict(orient="records") + + def get_column(row) -> Column: + return Column( + name=row["column_name"], + type=self._transform_column_type(row["data_type"]), + notNull=row["is_nullable"].lower() == "no", + description="", # Athena doesn't provide column descriptions in information_schema + properties=None, + ) + + def get_table(row) -> Table: + return Table( + name=self._format_athena_compact_table_name( + row["table_schema"], row["table_name"] + ), + description="", # Athena doesn't provide table descriptions in information_schema + columns=[], + properties=TableProperties( + schema=row["table_schema"], + catalog=row["table_catalog"], + table=row["table_name"], + ), + primaryKey="", + ) + + unique_tables = {} + + for column_metadata in response: + # generate unique table name + table_name = self._format_athena_compact_table_name( + column_metadata["table_schema"], column_metadata["table_name"] + ) + # init table if not exists + if table_name not in unique_tables: + unique_tables[table_name] = get_table(column_metadata) + + current_table = unique_tables[table_name] + # add column to table + current_table.columns.append(get_column(column_metadata)) + + return list(unique_tables.values()) + + def get_constraints(self) -> list[Constraint]: + # Athena doesn't support foreign key constraints + return [] + + def get_version(self) -> str: + return "AWS Athena - Follow AWS service versioning" + + def _format_athena_compact_table_name(self, schema: str, table: str) -> str: + return f"{schema}.{table}" + + def _transform_column_type(self, data_type): + data_type = re.sub(r"\(.*\)", "", data_type).strip() + switcher = { + # String Types (ignore Binary and Spatial Types for now) + "char": RustWrenEngineColumnType.CHAR, + "varchar": RustWrenEngineColumnType.VARCHAR, + "tinytext": RustWrenEngineColumnType.TEXT, + "text": RustWrenEngineColumnType.TEXT, + "mediumtext": RustWrenEngineColumnType.TEXT, + "longtext": RustWrenEngineColumnType.TEXT, + "enum": RustWrenEngineColumnType.VARCHAR, + "set": RustWrenEngineColumnType.VARCHAR, + # Numeric Types(https://dev.mysql.com/doc/refman/8.4/en/numeric-types.html) + "bit": RustWrenEngineColumnType.TINYINT, + "tinyint": RustWrenEngineColumnType.TINYINT, + "smallint": RustWrenEngineColumnType.SMALLINT, + "mediumint": RustWrenEngineColumnType.INTEGER, + "int": RustWrenEngineColumnType.INTEGER, + "integer": RustWrenEngineColumnType.INTEGER, + "bigint": RustWrenEngineColumnType.BIGINT, + # boolean + "bool": RustWrenEngineColumnType.BOOL, + "boolean": RustWrenEngineColumnType.BOOL, + # Decimal + "float": RustWrenEngineColumnType.FLOAT4, + "double": RustWrenEngineColumnType.DOUBLE, + "decimal": RustWrenEngineColumnType.DECIMAL, + "numeric": RustWrenEngineColumnType.NUMERIC, + # Date and Time Types(https://dev.mysql.com/doc/refman/8.4/en/date-and-time-types.html) + "date": RustWrenEngineColumnType.DATE, + "datetime": RustWrenEngineColumnType.TIMESTAMP, + "timestamp": RustWrenEngineColumnType.TIMESTAMPTZ, + # JSON Type + "json": RustWrenEngineColumnType.JSON, + } + + return switcher.get(data_type.lower(), RustWrenEngineColumnType.UNKNOWN) diff --git a/ibis-server/app/model/metadata/factory.py b/ibis-server/app/model/metadata/factory.py index a49f58ba9..86ca85d69 100644 --- a/ibis-server/app/model/metadata/factory.py +++ b/ibis-server/app/model/metadata/factory.py @@ -1,4 +1,5 @@ from app.model.data_source import DataSource +from 
app.model.metadata.athena import AthenaMetadata from app.model.metadata.bigquery import BigQueryMetadata from app.model.metadata.canner import CannerMetadata from app.model.metadata.clickhouse import ClickHouseMetadata @@ -17,6 +18,7 @@ from app.model.metadata.trino import TrinoMetadata mapping = { + DataSource.athena: AthenaMetadata, DataSource.bigquery: BigQueryMetadata, DataSource.canner: CannerMetadata, DataSource.clickhouse: ClickHouseMetadata, diff --git a/ibis-server/poetry.lock b/ibis-server/poetry.lock index cc7df8814..2faf17fd2 100644 --- a/ibis-server/poetry.lock +++ b/ibis-server/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -1229,6 +1229,49 @@ files = [ {file = "frozenlist-1.6.0.tar.gz", hash = "sha256:b99655c32c1c8e06d111e7f41c06c29a5318cb1835df23a45518e02a47c63b68"}, ] +[[package]] +name = "fsspec" +version = "2025.5.1" +description = "File-system specification" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "fsspec-2025.5.1-py3-none-any.whl", hash = "sha256:24d3a2e663d5fc735ab256263c4075f374a174c3410c0b25e5bd1970bceaa462"}, + {file = "fsspec-2025.5.1.tar.gz", hash = "sha256:2e55e47a540b91843b755e83ded97c6e897fa0942b11490113f09e9c443c2475"}, +] + +[package.dependencies] +s3fs = {version = "*", optional = true, markers = "extra == \"s3\""} + +[package.extras] +abfs = ["adlfs"] +adl = ["adlfs"] +arrow = ["pyarrow (>=1)"] +dask = ["dask", "distributed"] +dev = ["pre-commit", "ruff"] +doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] +dropbox = ["dropbox", "dropboxdrivefs", "requests"] +full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] +fuse = ["fusepy"] +gcs = ["gcsfs"] +git = ["pygit2"] +github = ["requests"] +gs = ["gcsfs"] +gui = ["panel"] +hdfs = ["pyarrow (>=1)"] +http = ["aiohttp (!=4.0.0a0,!=4.0.0a1)"] +libarchive = ["libarchive-c"] +oci = ["ocifs"] +s3 = ["s3fs"] +sftp = ["paramiko"] +smb = ["smbprotocol"] +ssh = ["paramiko"] +test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"] +test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"] +test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] +tqdm = ["tqdm"] + [[package]] name = "geoalchemy2" version = "0.17.1" @@ -1848,22 +1891,24 @@ files = [ atpublic = ">=2.3" clickhouse-connect = {version = ">=0.5.23", extras = ["arrow", "numpy", "pandas"], optional = true, markers = "extra == \"clickhouse\""} db-dtypes = {version = ">=0.3", optional = true, markers = "extra == \"bigquery\""} +fsspec = {version = 
"*", extras = ["s3"], optional = true, markers = "extra == \"athena\""} google-cloud-bigquery = {version = ">=3", optional = true, markers = "extra == \"bigquery\""} google-cloud-bigquery-storage = {version = ">=2", optional = true, markers = "extra == \"bigquery\""} mysqlclient = {version = ">=2.2.4", optional = true, markers = "extra == \"mysql\""} -numpy = {version = ">=1.23.2,<3", optional = true, markers = "extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"postgres\" or extra == \"snowflake\" or extra == \"trino\""} +numpy = {version = ">=1.23.2,<3", optional = true, markers = "extra == \"athena\" or extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"postgres\" or extra == \"snowflake\" or extra == \"trino\""} oracledb = {version = ">=1.3.1", optional = true, markers = "extra == \"oracle\""} -packaging = {version = ">=21.3", optional = true, markers = "extra == \"oracle\""} -pandas = {version = ">=1.5.3,<3", optional = true, markers = "extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"postgres\" or extra == \"snowflake\" or extra == \"trino\""} +packaging = {version = ">=21.3", optional = true, markers = "extra == \"athena\" or extra == \"oracle\""} +pandas = {version = ">=1.5.3,<3", optional = true, markers = "extra == \"athena\" or extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"postgres\" or extra == \"snowflake\" or extra == \"trino\""} pandas-gbq = {version = ">=0.26.1", optional = true, markers = "extra == \"bigquery\""} parsy = ">=2" psycopg = {version = ">=3.2.0", extras = ["binary"], optional = true, markers = "extra == \"postgres\""} -pyarrow = {version = ">=10.0.1", optional = true, markers = "extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"postgres\" or extra == \"snowflake\" or extra == \"trino\""} -pyarrow-hotfix = {version = ">=0.4", optional = true, markers = "extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"postgres\" or extra == \"snowflake\" or extra == \"trino\""} +pyarrow = {version = ">=10.0.1", optional = true, markers = "extra == \"athena\" or extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"postgres\" or extra == \"snowflake\" or extra == \"trino\""} +pyarrow-hotfix = {version = ">=0.4", optional = true, markers = "extra == \"athena\" or extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"postgres\" or extra == \"snowflake\" or extra == \"trino\""} +pyathena = {version = ">=3.11.0", extras = ["arrow", "pandas"], optional = true, markers = "extra == \"athena\""} pydata-google-auth = {version = ">=1.4.0", optional = true, markers = "extra == \"bigquery\""} pyodbc = {version = ">=4.0.39", optional = true, markers = "extra == \"mssql\""} python-dateutil = ">=2.8.2" -rich = {version = ">=12.4.4", optional = true, markers = "extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"postgres\" or extra == \"snowflake\" or extra == \"trino\""} +rich = {version = 
">=12.4.4", optional = true, markers = "extra == \"athena\" or extra == \"bigquery\" or extra == \"clickhouse\" or extra == \"mssql\" or extra == \"mysql\" or extra == \"oracle\" or extra == \"postgres\" or extra == \"snowflake\" or extra == \"trino\""} snowflake-connector-python = {version = ">=3.0.2,<3.3.0b1 || >3.3.0b1", optional = true, markers = "extra == \"snowflake\""} sqlglot = ">=23.4" toolz = ">=0.11" @@ -3192,6 +3237,33 @@ files = [ [package.dependencies] pyasn1 = ">=0.6.1,<0.7.0" +[[package]] +name = "pyathena" +version = "3.14.1" +description = "Python DB API 2.0 (PEP 249) client for Amazon Athena" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "pyathena-3.14.1-py3-none-any.whl", hash = "sha256:cdba338da81cc982d9babdbf801631655a5fd4ce9bf4e44a257efa431d891b36"}, + {file = "pyathena-3.14.1.tar.gz", hash = "sha256:ff628261595b52bc2b74809c42ef89886f74be01371506e289fdb0bc4f653993"}, +] + +[package.dependencies] +boto3 = ">=1.26.4" +botocore = ">=1.29.4" +fsspec = "*" +pandas = {version = ">=1.3.0", optional = true, markers = "extra == \"pandas\""} +pyarrow = {version = ">=7.0.0", optional = true, markers = "extra == \"arrow\""} +python-dateutil = "*" +tenacity = ">=4.1.0" + +[package.extras] +arrow = ["pyarrow (>=7.0.0)"] +fastparquet = ["fastparquet (>=0.4.0)"] +pandas = ["pandas (>=1.3.0)"] +sqlalchemy = ["sqlalchemy (>=1.0.0)"] + [[package]] name = "pycparser" version = "2.22" @@ -3972,6 +4044,22 @@ files = [ {file = "ruff-0.11.2.tar.gz", hash = "sha256:ec47591497d5a1050175bdf4e1a4e6272cddff7da88a2ad595e1e326041d8d94"}, ] +[[package]] +name = "s3fs" +version = "0.4.2" +description = "Convenient Filesystem interface over S3" +optional = false +python-versions = ">= 3.5" +groups = ["main"] +files = [ + {file = "s3fs-0.4.2-py3-none-any.whl", hash = "sha256:91c1dfb45e5217bd441a7a560946fe865ced6225ff7eb0fb459fe6e601a95ed3"}, + {file = "s3fs-0.4.2.tar.gz", hash = "sha256:2ca5de8dc18ad7ad350c0bd01aef0406aa5d0fff78a561f0f710f9d9858abdd0"}, +] + +[package.dependencies] +botocore = ">=1.12.91" +fsspec = ">=0.6.0" + [[package]] name = "s3transfer" version = "0.13.0" @@ -4391,6 +4479,22 @@ anyio = ">=3.6.2,<5" [package.extras] full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] +[[package]] +name = "tenacity" +version = "9.1.2" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138"}, + {file = "tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "testcontainers" version = "4.9.2" @@ -5221,4 +5325,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.12" -content-hash = "e96635dc13ca9f20bae206feb05624ad9f60f943b459d284577fe146b8655873" +content-hash = "066754ad51a081f624a2b216ec600453ba6829b756901194c70d151428110311" diff --git a/ibis-server/pyproject.toml b/ibis-server/pyproject.toml index c227fc95d..42d4e97a0 100644 --- a/ibis-server/pyproject.toml +++ b/ibis-server/pyproject.toml @@ -11,6 +11,7 @@ python = ">=3.11,<3.12" fastapi = { version = "0.115.12", extras = ["standard"] } pydantic = "2.10.6" ibis-framework = { version = "10.3.1", extras = [ + "athena", "bigquery", "clickhouse", "mssql", @@ -66,6 
+67,7 @@ asgi-lifespan = "2.1.0" [tool.pytest.ini_options] addopts = ["--strict-markers"] markers = [ + "athena: mark a test as an athena test", "bigquery: mark a test as a bigquery test", "canner: mark a test as a canner test", "clickhouse: mark a test as a clickhouse test", diff --git a/ibis-server/tests/routers/v2/connector/test_athena.py b/ibis-server/tests/routers/v2/connector/test_athena.py new file mode 100644 index 000000000..cabdaa203 --- /dev/null +++ b/ibis-server/tests/routers/v2/connector/test_athena.py @@ -0,0 +1,351 @@ +import base64 +import os + +import orjson +import pytest + +from app.model.validator import rules + +pytestmark = pytest.mark.athena + +base_url = "/v2/connector/athena" + +connection_info = { + "s3_staging_dir": os.getenv("TEST_ATHENA_S3_STAGING_DIR"), + "aws_access_key_id": os.getenv("TEST_ATHENA_AWS_ACCESS_KEY_ID"), + "aws_secret_access_key": os.getenv("TEST_ATHENA_AWS_SECRET_ACCESS_KEY"), + "region_name": os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1"), + "schema_name": "test", +} +manifest = { + "catalog": "my_catalog", + "schema": "my_schema", + "models": [ + { + "name": "Orders", + "refSql": "select * from test.orders", + "columns": [ + {"name": "orderkey", "expression": "o_orderkey", "type": "integer"}, + {"name": "custkey", "expression": "o_custkey", "type": "integer"}, + { + "name": "orderstatus", + "expression": "o_orderstatus", + "type": "varchar", + }, + { + "name": "totalprice", + "expression": "o_totalprice", + "type": "float", + }, + {"name": "orderdate", "expression": "o_orderdate", "type": "date"}, + { + "name": "order_cust_key", + "expression": "concat(cast(o_orderkey as varchar), '_', cast(o_custkey as varchar))", + "type": "varchar", + }, + { + "name": "timestamp", + "expression": "TIMESTAMP '2024-01-01 23:59:59'", + "type": "timestamp", + }, + { + "name": "timestamptz", + "expression": "TIMESTAMP '2024-01-01 23:59:59 UTC'", + "type": "timestamp", + }, + { + "name": "test_null_time", + "expression": "cast(NULL as timestamp)", + "type": "timestamp", + }, + { + "name": "bytea_column", + "expression": "cast('abc' as bytea)", + "type": "bytea", + }, + ], + }, + ], +} + + +@pytest.fixture(scope="module") +def manifest_str(): + return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + +async def test_query(client, manifest_str): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" ORDER BY orderkey LIMIT 1', + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + + assert result["data"][0] == [ + 1, + 36901, + "O", + "173665.47", + "1996-01-02 00:00:00.000000", + "1_36901", + "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000", + None, + "616263", + ] + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "object", + "totalprice": "object", ### fixme this should be float64 + "orderdate": "object", + "order_cust_key": "object", + "timestamp": "object", + "timestamptz": "object", + "test_null_time": "datetime64[ns]", + "bytea_column": "object", + } + + +async def test_query_with_limit(client, manifest_str): + response = await client.post( + url=f"{base_url}/query", + params={"limit": 1}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders"', + }, + ) + assert response.status_code == 200 + 
result = response.json() + assert len(result["data"]) == 1 + + response = await client.post( + url=f"{base_url}/query", + params={"limit": 1}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 10', + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + + +async def test_query_without_manifest(client): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "manifestStr"] + assert result["detail"][0]["msg"] == "Field required" + + +async def test_query_without_sql(client, manifest_str): + response = await client.post( + url=f"{base_url}/query", + json={"connectionInfo": connection_info, "manifestStr": manifest_str}, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "sql"] + assert result["detail"][0]["msg"] == "Field required" + + +async def test_query_without_connection_info(client, manifest_str): + response = await client.post( + url=f"{base_url}/query", + json={ + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "connectionInfo"] + assert result["detail"][0]["msg"] == "Field required" + + +async def test_query_with_dry_run(client, manifest_str): + response = await client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" LIMIT 1', + }, + ) + assert response.status_code == 204 + + +async def test_query_with_dry_run_and_invalid_sql(client, manifest_str): + response = await client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM X", + }, + ) + assert response.status_code == 422 + assert response.text is not None + + +async def test_validate_with_unknown_rule(client, manifest_str): + response = await client.post( + url=f"{base_url}/validate/unknown_rule", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"modelName": "Orders", "columnName": "orderkey"}, + }, + ) + assert response.status_code == 404 + assert ( + response.text == f"The rule `unknown_rule` is not in the rules, rules: {rules}" + ) + + +async def test_validate_rule_column_is_valid(client, manifest_str): + response = await client.post( + url=f"{base_url}/validate/column_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"modelName": "Orders", "columnName": "orderkey"}, + }, + ) + assert response.status_code == 204 + + +async def test_validate_rule_column_is_valid_with_invalid_parameters( + client, manifest_str +): + response = await client.post( + url=f"{base_url}/validate/column_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"modelName": "X", 
"columnName": "orderkey"}, + }, + ) + assert response.status_code == 422 + + response = await client.post( + url=f"{base_url}/validate/column_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"modelName": "Orders", "columnName": "X"}, + }, + ) + assert response.status_code == 422 + + +async def test_validate_rule_column_is_valid_without_parameters(client, manifest_str): + response = await client.post( + url=f"{base_url}/validate/column_is_valid", + json={"connectionInfo": connection_info, "manifestStr": manifest_str}, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "parameters"] + assert result["detail"][0]["msg"] == "Field required" + + +async def test_validate_rule_column_is_valid_without_one_parameter( + client, manifest_str +): + response = await client.post( + url=f"{base_url}/validate/column_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"modelName": "Orders"}, + }, + ) + assert response.status_code == 422 + assert response.text == "Missing required parameter: `columnName`" + + response = await client.post( + url=f"{base_url}/validate/column_is_valid", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "parameters": {"columnName": "orderkey"}, + }, + ) + assert response.status_code == 422 + assert response.text == "Missing required parameter: `modelName`" + + +async def test_metadata_list_tables(client): + response = await client.post( + url=f"{base_url}/metadata/tables", + json={"connectionInfo": connection_info}, + ) + assert response.status_code == 200 + tables = response.json() + assert len(tables) >= 1 + + # Check if our test table exists + test_table = next(filter(lambda t: "orders" in t["name"].lower(), tables), None) + if test_table: + assert test_table["name"] is not None + assert test_table["properties"] is not None + assert test_table["properties"]["schema"] == connection_info["schema_name"] + assert len(test_table["columns"]) > 0 + + # Check column structure + column = test_table["columns"][0] + assert column["name"] is not None + assert column["type"] is not None + assert "notNull" in column + + +async def test_metadata_list_constraints(client): + response = await client.post( + url=f"{base_url}/metadata/constraints", + json={"connectionInfo": connection_info}, + ) + assert response.status_code == 200 + + result = response.json() + # Athena doesn't support foreign key constraints, so should return empty list + assert len(result) == 0 + + +async def test_metadata_db_version(client): + response = await client.post( + url=f"{base_url}/metadata/version", + json={"connectionInfo": connection_info}, + ) + assert response.status_code == 200 + assert response.text is not None + # Should return the AWS Athena version string + assert "AWS Athena" in response.text From 2cbed79a97f176e36bd48e8132ab8eef119f8538 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Mon, 2 Jun 2025 15:45:47 +0800 Subject: [PATCH 2/9] exclude athena in all test --- .github/workflows/ibis-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ibis-ci.yml b/.github/workflows/ibis-ci.yml index 34b8fe307..76f46ae09 100644 --- a/.github/workflows/ibis-ci.yml +++ b/.github/workflows/ibis-ci.yml @@ -74,7 +74,7 @@ jobs: AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 
AWS_REGION: ${{ secrets.AWS_REGION }} AWS_S3_BUCKET: ${{ secrets.AWS_S3_BUCKET }} - run: poetry run pytest -m "not bigquery and not snowflake and not canner and not s3_file and not gcs_file" + run: poetry run pytest -m "not bigquery and not snowflake and not canner and not s3_file and not gcs_file and not athena" - name: Test bigquery if need if: contains(github.event.pull_request.labels.*.name, 'bigquery') env: From a2effbdd8b8f9faaa2d216166710f1a0b52caaa5 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Mon, 2 Jun 2025 17:45:24 +0800 Subject: [PATCH 3/9] v3 test --- .../resources/function_list/athena.csv | 10 + .../tests/routers/v2/connector/test_athena.py | 1 + .../routers/v3/connector/athena/__init__.py | 0 .../routers/v3/connector/athena/conftest.py | 39 ++++ .../v3/connector/athena/test_functions.py | 107 +++++++++ .../routers/v3/connector/athena/test_query.py | 213 ++++++++++++++++++ 6 files changed, 370 insertions(+) create mode 100644 ibis-server/resources/function_list/athena.csv create mode 100644 ibis-server/tests/routers/v3/connector/athena/__init__.py create mode 100644 ibis-server/tests/routers/v3/connector/athena/conftest.py create mode 100644 ibis-server/tests/routers/v3/connector/athena/test_functions.py create mode 100644 ibis-server/tests/routers/v3/connector/athena/test_query.py diff --git a/ibis-server/resources/function_list/athena.csv b/ibis-server/resources/function_list/athena.csv new file mode 100644 index 000000000..cbac4f1a5 --- /dev/null +++ b/ibis-server/resources/function_list/athena.csv @@ -0,0 +1,10 @@ +function_type,name,return_type,param_names,param_types,description +scalar,from_base64,varbinary,,varchar,Converts base64 to binary +scalar,is_finite,boolean,,double or decimal,Tests if value is finite +scalar,is_infinite,boolean,,double or decimal,Tests if value is infinite +scalar,is_nan,boolean,,double or decimal,Tests if value is NaN +scalar,to_base64,varchar,,varbinary,Converts binary to base64 +scalar,try,same_as_input,,any,Returns null if evaluation fails +scalar,url_decode,varchar,,varchar,Decodes URL encoded string +scalar,url_encode,varchar,,varchar,URL encodes string +scalar,word_stem,varchar,,varchar,Returns word stem (English only) diff --git a/ibis-server/tests/routers/v2/connector/test_athena.py b/ibis-server/tests/routers/v2/connector/test_athena.py index cabdaa203..b964f4738 100644 --- a/ibis-server/tests/routers/v2/connector/test_athena.py +++ b/ibis-server/tests/routers/v2/connector/test_athena.py @@ -20,6 +20,7 @@ manifest = { "catalog": "my_catalog", "schema": "my_schema", + "dataSource": "athena", "models": [ { "name": "Orders", diff --git a/ibis-server/tests/routers/v3/connector/athena/__init__.py b/ibis-server/tests/routers/v3/connector/athena/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ibis-server/tests/routers/v3/connector/athena/conftest.py b/ibis-server/tests/routers/v3/connector/athena/conftest.py new file mode 100644 index 000000000..eeb5f7db1 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/athena/conftest.py @@ -0,0 +1,39 @@ +import os +import pathlib + +import pytest + +from app.config import get_config +from tests.conftest import file_path + +pytestmark = pytest.mark.athena + +base_url = "/v3/connector/athena" + +function_list_path = file_path("../resources/function_list") + + +def pytest_collection_modifyitems(items): + current_file_dir = pathlib.Path(__file__).resolve().parent + for item in items: + if pathlib.Path(item.fspath).is_relative_to(current_file_dir): + 
item.add_marker(pytestmark) + + +@pytest.fixture(scope="session") +def connection_info(): + return { + "s3_staging_dir": os.getenv("TEST_ATHENA_S3_STAGING_DIR"), + "aws_access_key_id": os.getenv("TEST_ATHENA_AWS_ACCESS_KEY_ID"), + "aws_secret_access_key": os.getenv("TEST_ATHENA_AWS_SECRET_ACCESS_KEY"), + "region_name": os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1"), + "schema_name": "test", + } + + +@pytest.fixture(autouse=True) +def set_remote_function_list_path(): + config = get_config() + config.set_remote_function_list_path(function_list_path) + yield + config.set_remote_function_list_path(None) diff --git a/ibis-server/tests/routers/v3/connector/athena/test_functions.py b/ibis-server/tests/routers/v3/connector/athena/test_functions.py new file mode 100644 index 000000000..5e3b6dcc6 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/athena/test_functions.py @@ -0,0 +1,107 @@ +import base64 + +import orjson +import pytest + +from app.config import get_config +from tests.conftest import DATAFUSION_FUNCTION_COUNT, file_path +from tests.routers.v3.connector.athena.conftest import base_url + +manifest = { + "catalog": "my_catalog", + "schema": "my_schema", + "models": [ + { + "name": "orders", + "tableReference": { + "schema": "test", + "table": "orders", + }, + "columns": [ + {"name": "o_orderkey", "type": "integer"}, + ], + }, + ], +} + +function_list_path = file_path("../resources/function_list") + + +@pytest.fixture(scope="module") +def manifest_str(): + return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + +@pytest.fixture(autouse=True) +def set_remote_function_list_path(): + config = get_config() + config.set_remote_function_list_path(function_list_path) + yield + config.set_remote_function_list_path(None) + + +async def test_function_list(client): + config = get_config() + + config.set_remote_function_list_path(None) + response = await client.get(url=f"{base_url}/functions") + assert response.status_code == 200 + result = response.json() + assert len(result) == DATAFUSION_FUNCTION_COUNT + + config.set_remote_function_list_path(function_list_path) + response = await client.get(url=f"{base_url}/functions") + assert response.status_code == 200 + result = response.json() + assert len(result) == DATAFUSION_FUNCTION_COUNT + 9 + the_func = next(filter(lambda x: x["name"] == "to_base64", result)) + assert the_func == { + "name": "to_base64", + "description": "Converts binary to base64", + "function_type": "scalar", + "param_names": None, + "param_types": None, + "return_type": None, + } + + config.set_remote_function_list_path(None) + response = await client.get(url=f"{base_url}/functions") + assert response.status_code == 200 + result = response.json() + assert len(result) == DATAFUSION_FUNCTION_COUNT + + +async def test_scalar_function(client, manifest_str: str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT ABS(-1) AS col", + }, + ) + assert response.status_code == 200 + result = response.json() + assert result == { + "columns": ["col"], + "data": [[1]], + "dtypes": {"col": "int32"}, + } + + +async def test_aggregate_function(client, manifest_str: str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT COUNT(*) AS col FROM (SELECT 1) AS temp_table", + }, + ) + assert response.status_code == 200 + result = 
response.json() + assert result == { + "columns": ["col"], + "data": [[1]], + "dtypes": {"col": "int64"}, + } diff --git a/ibis-server/tests/routers/v3/connector/athena/test_query.py b/ibis-server/tests/routers/v3/connector/athena/test_query.py new file mode 100644 index 000000000..c8b980248 --- /dev/null +++ b/ibis-server/tests/routers/v3/connector/athena/test_query.py @@ -0,0 +1,213 @@ +import base64 + +import orjson +import pytest + +from tests.routers.v3.connector.athena.conftest import base_url + +manifest = { + "catalog": "wren", + "schema": "public", + "models": [ + { + "name": "orders", + "tableReference": { + "schema": "test", + "table": "orders", + }, + "columns": [ + {"name": "orderkey", "expression": "o_orderkey", "type": "integer"}, + {"name": "custkey", "expression": "o_custkey", "type": "integer"}, + { + "name": "orderstatus", + "expression": "o_orderstatus", + "type": "varchar", + }, + { + "name": "totalprice", + "expression": "o_totalprice", + "type": "float", + }, + {"name": "orderdate", "expression": "o_orderdate", "type": "date"}, + { + "name": "order_cust_key", + "expression": "concat(cast(o_orderkey as varchar), '_', cast(o_custkey as varchar))", + "type": "varchar", + }, + { + "name": "timestamp", + "expression": "TIMESTAMP '2024-01-01 23:59:59'", + "type": "timestamp", + }, + { + "name": "timestamptz", + "expression": "TIMESTAMP '2024-01-01 23:59:59 UTC'", + "type": "timestamp", + }, + { + "name": "test_null_time", + "expression": "cast(NULL as timestamp)", + "type": "timestamp", + }, + ], + }, + ], +} + + +@pytest.fixture(scope="module") +def manifest_str(): + return base64.b64encode(orjson.dumps(manifest)).decode("utf-8") + + +async def test_query(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders LIMIT 1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + + assert result["data"][0] == [ + 1, + 36901, + "O", + "173665.47", + "1996-01-02 00:00:00.000000", + "1_36901", + "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000", + None, + ] + + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "object", + "totalprice": "object", ### fixme this should be float64 + "orderdate": "object", + "order_cust_key": "object", + "timestamp": "object", + "timestamptz": "object", + "test_null_time": "datetime64[ns]", + } + + +async def test_query_with_limit(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + params={"limit": 1}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + + response = await client.post( + url=f"{base_url}/query", + params={"limit": 1}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders LIMIT 10", + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["data"]) == 1 + + +async def test_query_with_invalid_manifest_str(client, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "manifestStr": "xxx", + "sql": "SELECT * 
FROM wren.public.orders LIMIT 1", + }, + ) + assert response.status_code == 422 + + +async def test_query_without_manifest(client, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": connection_info, + "sql": "SELECT * FROM wren.public.orders LIMIT 1", + }, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "manifestStr"] + assert result["detail"][0]["msg"] == "Field required" + + +async def test_query_without_sql(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + json={"connectionInfo": connection_info, "manifestStr": manifest_str}, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "sql"] + assert result["detail"][0]["msg"] == "Field required" + + +async def test_query_without_connection_info(client, manifest_str): + response = await client.post( + url=f"{base_url}/query", + json={ + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders LIMIT 1", + }, + ) + assert response.status_code == 422 + result = response.json() + assert result["detail"][0] is not None + assert result["detail"][0]["type"] == "missing" + assert result["detail"][0]["loc"] == ["body", "connectionInfo"] + assert result["detail"][0]["msg"] == "Field required" + + +async def test_query_with_dry_run(client, manifest_str, connection_info): + response = await client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM wren.public.orders LIMIT 1", + }, + ) + assert response.status_code == 204 + + +async def test_query_with_dry_run_and_invalid_sql( + client, manifest_str, connection_info +): + response = await client.post( + url=f"{base_url}/query", + params={"dryRun": True}, + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT * FROM X", + }, + ) + assert response.status_code == 422 + assert response.text is not None From d321e3cab82b4b42fe839621c9c9fc77da258431 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Mon, 2 Jun 2025 17:45:56 +0800 Subject: [PATCH 4/9] wren core serde --- wren-core-base/manifest-macro/src/lib.rs | 2 ++ wren-core-base/src/mdl/manifest.rs | 1 + 2 files changed, 3 insertions(+) diff --git a/wren-core-base/manifest-macro/src/lib.rs b/wren-core-base/manifest-macro/src/lib.rs index ba48d6823..6e1cff5b0 100644 --- a/wren-core-base/manifest-macro/src/lib.rs +++ b/wren-core-base/manifest-macro/src/lib.rs @@ -104,6 +104,8 @@ pub fn data_source(python_binding: proc_macro::TokenStream) -> proc_macro::Token MinioFile, #[serde(alias = "oracle")] Oracle, + #[serde(alias = "athena")] + Athena, } }; proc_macro::TokenStream::from(expanded) diff --git a/wren-core-base/src/mdl/manifest.rs b/wren-core-base/src/mdl/manifest.rs index 6bfb8b2a1..d06d818bf 100644 --- a/wren-core-base/src/mdl/manifest.rs +++ b/wren-core-base/src/mdl/manifest.rs @@ -115,6 +115,7 @@ impl Display for DataSource { DataSource::GcsFile => write!(f, "GCS_FILE"), DataSource::MinioFile => write!(f, "MINIO_FILE"), DataSource::Oracle => write!(f, "ORACLE"), + DataSource::Athena => write!(f, "ATHENA"), } } } From 7dd1b5b1d6c9be9afabcd945d7711b98c0cc8666 Mon Sep 17 00:00:00 2001 From:
DouEnergy Date: Mon, 2 Jun 2025 18:05:09 +0800 Subject: [PATCH 5/9] staging dir --- .github/workflows/ibis-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ibis-ci.yml b/.github/workflows/ibis-ci.yml index 76f46ae09..cc6916a28 100644 --- a/.github/workflows/ibis-ci.yml +++ b/.github/workflows/ibis-ci.yml @@ -94,7 +94,7 @@ if: contains(github.event.pull_request.labels.*.name, 'athena') env: WREN_ENGINE_ENDPOINT: http://localhost:8080 - TEST_ATHENA_S3_STAGING_DIR: ${{ secrets.TEST_ATHENA_S3_STAGING_DIR }} + TEST_ATHENA_S3_STAGING_DIR: s3://wren-ibis-athena-dev/results/ TEST_ATHENA_AWS_ACCESS_KEY_ID: ${{ secrets.TEST_ATHENA_AWS_ACCESS_KEY_ID }} TEST_ATHENA_AWS_SECRET_ACCESS_KEY: ${{ secrets.TEST_ATHENA_AWS_SECRET_ACCESS_KEY }} run: just test athena From 969e7e90ec8a284871a7ac9231a65673ae24027f Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Wed, 4 Jun 2025 14:50:41 +0800 Subject: [PATCH 6/9] add glue test case --- .../tests/routers/v2/connector/test_athena.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/ibis-server/tests/routers/v2/connector/test_athena.py b/ibis-server/tests/routers/v2/connector/test_athena.py index b964f4738..f0267d7be 100644 --- a/ibis-server/tests/routers/v2/connector/test_athena.py +++ b/ibis-server/tests/routers/v2/connector/test_athena.py @@ -17,6 +17,16 @@ "region_name": os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1"), "schema_name": "test", } + +# Connection info for the database created from AWS Glue +glue_connection_info = { + "s3_staging_dir": os.getenv("TEST_ATHENA_S3_STAGING_DIR"), + "aws_access_key_id": os.getenv("TEST_ATHENA_AWS_ACCESS_KEY_ID"), + "aws_secret_access_key": os.getenv("TEST_ATHENA_AWS_SECRET_ACCESS_KEY"), + "region_name": os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1"), + "schema_name": "wren-engine-glue-test", +} + manifest = { "catalog": "my_catalog", "schema": "my_schema", @@ -115,6 +125,46 @@ async def test_query(client, manifest_str): } +async def test_query_glue_database(client, manifest_str): + response = await client.post( + url=f"{base_url}/query", + json={ + "connectionInfo": glue_connection_info, # Use the Glue connection info + "manifestStr": manifest_str, + "sql": 'SELECT * FROM "Orders" ORDER BY orderkey LIMIT 1', + }, + ) + assert response.status_code == 200 + result = response.json() + assert len(result["columns"]) == len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + + assert result["data"][0] == [ + 1, + 36901, + "O", + "173665.47", + "1996-01-02 00:00:00.000000", + "1_36901", + "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000", + None, + "616263", + ] + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "object", + "totalprice": "object", ### fixme this should be float64 + "orderdate": "object", + "order_cust_key": "object", + "timestamp": "object", + "timestamptz": "object", + "test_null_time": "datetime64[ns]", + "bytea_column": "object", + } + + async def test_query_with_limit(client, manifest_str): response = await client.post( url=f"{base_url}/query", From b47eaba04d6d081c40cd306eb5eca9d84b91b2a6 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Wed, 4 Jun 2025 14:53:09 +0800 Subject: [PATCH 7/9] remove unrelated link --- ibis-server/app/model/metadata/athena.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ibis-server/app/model/metadata/athena.py b/ibis-server/app/model/metadata/athena.py index b342e831c..7a6dd8ebd 100644 --- a/ibis-server/app/model/metadata/athena.py +++
b/ibis-server/app/model/metadata/athena.py @@ -112,7 +112,6 @@ def _transform_column_type(self, data_type): "longtext": RustWrenEngineColumnType.TEXT, "enum": RustWrenEngineColumnType.VARCHAR, "set": RustWrenEngineColumnType.VARCHAR, - # Numeric Types(https://dev.mysql.com/doc/refman/8.4/en/numeric-types.html) "bit": RustWrenEngineColumnType.TINYINT, "tinyint": RustWrenEngineColumnType.TINYINT, "smallint": RustWrenEngineColumnType.SMALLINT, @@ -128,7 +127,6 @@ def _transform_column_type(self, data_type): "double": RustWrenEngineColumnType.DOUBLE, "decimal": RustWrenEngineColumnType.DECIMAL, "numeric": RustWrenEngineColumnType.NUMERIC, - # Date and Time Types(https://dev.mysql.com/doc/refman/8.4/en/date-and-time-types.html) "date": RustWrenEngineColumnType.DATE, "datetime": RustWrenEngineColumnType.TIMESTAMP, "timestamp": RustWrenEngineColumnType.TIMESTAMPTZ, From efea7e2f51c2a34818120e97da68da1625464eb0 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Wed, 4 Jun 2025 16:20:17 +0800 Subject: [PATCH 8/9] closing athena cursor --- ibis-server/app/model/metadata/athena.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/ibis-server/app/model/metadata/athena.py b/ibis-server/app/model/metadata/athena.py index 7a6dd8ebd..ccf727446 100644 --- a/ibis-server/app/model/metadata/athena.py +++ b/ibis-server/app/model/metadata/athena.py @@ -1,4 +1,5 @@ import re +from contextlib import closing import pandas as pd @@ -44,10 +45,10 @@ def get_table_list(self) -> list[Table]: # We need to use raw_sql here because using the sql method causes Ibis to *create view* first, # which does not work with information_schema queries. - cursor = self.connection.raw_sql(sql) - response = pd.DataFrame( - cursor.fetchall(), columns=[col[0] for col in cursor.description] - ).to_dict(orient="records") + with closing(self.connection.raw_sql(sql)) as cursor: + response = pd.DataFrame( + cursor.fetchall(), columns=[col[0] for col in cursor.description] + ).to_dict(orient="records") def get_column(row) -> Column: return Column( From 02f1f72ae495ed3493562c803162ddcca603a15d Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Wed, 4 Jun 2025 16:25:25 +0800 Subject: [PATCH 9/9] closing other datasource cursor --- ibis-server/app/model/connector.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ibis-server/app/model/connector.py b/ibis-server/app/model/connector.py index 46186b955..a387a5b99 100644 --- a/ibis-server/app/model/connector.py +++ b/ibis-server/app/model/connector.py @@ -1,5 +1,6 @@ import base64 import importlib +from contextlib import closing from functools import cache from json import loads from typing import Any @@ -97,7 +98,7 @@ def dry_run(self, sql: str) -> None: def _describe_sql_for_error_message(self, sql: str) -> str: tsql = sge.convert(sql).sql("mssql") describe_sql = f"SELECT error_message FROM sys.dm_exec_describe_first_result_set({tsql}, NULL, 0)" - with self.connection.raw_sql(describe_sql) as cur: + with closing(self.connection.raw_sql(describe_sql)) as cur: rows = cur.fetchall() if rows is None or len(rows) == 0: return "Unknown reason" @@ -214,8 +215,8 @@ def dry_run(self, sql: str) -> None: @cache def _get_pg_type_names(connection: BaseBackend) -> dict[int, str]: - cur = connection.raw_sql("SELECT oid, typname FROM pg_type") - return dict(cur.fetchall()) + with closing(connection.raw_sql("SELECT oid, typname FROM pg_type")) as cur: + return dict(cur.fetchall()) class QueryDryRunError(UnprocessableEntityError):
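
For reference, a minimal end-to-end sketch of the Athena path this series adds, assuming the modules introduced in the diffs above; the bucket, credential, and region values are placeholders, not working secrets:

    # Build the pydantic connection info added to app/model/__init__.py.
    from app.model import AthenaConnectionInfo
    from app.model.data_source import DataSource
    from app.model.metadata.athena import AthenaMetadata

    info = AthenaConnectionInfo(
        s3_staging_dir="s3://my-bucket/athena-staging/",  # placeholder bucket
        aws_access_key_id="AKIA...",  # placeholder credentials
        aws_secret_access_key="my-secret-key",
        region_name="us-east-1",
        schema_name="default",
    )

    # DataSource.athena.get_connection() ends up in the get_athena_connection()
    # static method above, which unwraps each SecretStr and calls
    # ibis.athena.connect(...).
    con = DataSource.athena.get_connection(info)

    # AthenaMetadata opens the same kind of connection and lists tables via
    # information_schema; get_constraints() is always empty because Athena
    # has no foreign key constraints.
    tables = AthenaMetadata(info).get_table_list()

Every field on AthenaConnectionInfo is a SecretStr, so credentials stay masked if the model is ever logged; .get_secret_value() is only called at the moment the ibis connection is created.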