Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .github/workflows/ibis-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ jobs:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION: ${{ secrets.AWS_REGION }}
AWS_S3_BUCKET: ${{ secrets.AWS_S3_BUCKET }}
run: poetry run pytest -m "not bigquery and not snowflake and not canner and not s3_file and not gcs_file"
run: poetry run pytest -m "not bigquery and not snowflake and not canner and not s3_file and not gcs_file and not athena"
- name: Test bigquery if need
if: contains(github.event.pull_request.labels.*.name, 'bigquery')
env:
Expand All @@ -90,3 +90,11 @@ jobs:
SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }}
SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }}
run: just test snowflake
- name: Test athena if need
if: contains(github.event.pull_request.labels.*.name, 'athena')
env:
WREN_ENGINE_ENDPOINT: http://localhost:8080
TEST_ATHENA_S3_STAGING_DIR: s3://wren-ibis-athena-dev/results/
TEST_ATHENA_AWS_ACCESS_KEY_ID: ${{ secrets.TEST_ATHENA_AWS_ACCESS_KEY_ID }}
TEST_ATHENA_AWS_SECRET_ACCESS_KEY: ${{ secrets.TEST_ATHENA_AWS_SECRET_ACCESS_KEY }}
run: just test athena
28 changes: 27 additions & 1 deletion ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class QueryBigQueryDTO(QueryDTO):
connection_info: BigQueryConnectionInfo = connection_info_field


class QueryAthenaDTO(QueryDTO):
    """Query request DTO for Athena: base query fields plus Athena connection info."""

    connection_info: AthenaConnectionInfo = connection_info_field


class QueryCannerDTO(QueryDTO):
connection_info: ConnectionUrl | CannerConnectionInfo = connection_info_field

Expand Down Expand Up @@ -98,6 +102,27 @@ class BigQueryConnectionInfo(BaseConnectionInfo):
)


class AthenaConnectionInfo(BaseConnectionInfo):
    """Connection settings for AWS Athena.

    Every value is wrapped in ``SecretStr`` so it is masked in logs and
    reprs; consumers read it via ``.get_secret_value()``.
    """

    s3_staging_dir: SecretStr = Field(
        description="S3 staging directory for Athena queries",
        examples=["s3://my-bucket/athena-staging/"],
    )
    aws_access_key_id: SecretStr = Field(
        description="AWS access key ID", examples=["AKIA..."]
    )
    aws_secret_access_key: SecretStr = Field(
        description="AWS secret access key", examples=["my-secret-key"]
    )
    region_name: SecretStr = Field(
        description="AWS region for Athena", examples=["us-west-2", "us-east-1"]
    )
    # No alias needed: an alias equal to the field name is a no-op in pydantic.
    schema_name: SecretStr = Field(
        description="The database name in Athena",
        examples=["default"],
    )


class CannerConnectionInfo(BaseConnectionInfo):
host: SecretStr = Field(
description="the hostname of your database", examples=["localhost"]
Expand Down Expand Up @@ -339,7 +364,8 @@ class GcsFileConnectionInfo(BaseConnectionInfo):


ConnectionInfo = (
BigQueryConnectionInfo
AthenaConnectionInfo
| BigQueryConnectionInfo
| CannerConnectionInfo
| ConnectionUrl
| MSSqlConnectionInfo
Expand Down
7 changes: 4 additions & 3 deletions ibis-server/app/model/connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import importlib
from contextlib import closing
from functools import cache
from json import loads
from typing import Any
Expand Down Expand Up @@ -97,7 +98,7 @@ def dry_run(self, sql: str) -> None:
def _describe_sql_for_error_message(self, sql: str) -> str:
tsql = sge.convert(sql).sql("mssql")
describe_sql = f"SELECT error_message FROM sys.dm_exec_describe_first_result_set({tsql}, NULL, 0)"
with self.connection.raw_sql(describe_sql) as cur:
with closing(self.connection.raw_sql(describe_sql)) as cur:
rows = cur.fetchall()
if rows is None or len(rows) == 0:
return "Unknown reason"
Expand Down Expand Up @@ -214,8 +215,8 @@ def dry_run(self, sql: str) -> None:

@cache
def _get_pg_type_names(connection: BaseBackend) -> dict[int, str]:
cur = connection.raw_sql("SELECT oid, typname FROM pg_type")
return dict(cur.fetchall())
with closing(connection.raw_sql("SELECT oid, typname FROM pg_type")) as cur:
return dict(cur.fetchall())


class QueryDryRunError(UnprocessableEntityError):
Expand Down
14 changes: 14 additions & 0 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ibis import BaseBackend

from app.model import (
AthenaConnectionInfo,
BigQueryConnectionInfo,
CannerConnectionInfo,
ClickHouseConnectionInfo,
Expand All @@ -19,6 +20,7 @@
MySqlConnectionInfo,
OracleConnectionInfo,
PostgresConnectionInfo,
QueryAthenaDTO,
QueryBigQueryDTO,
QueryCannerDTO,
QueryClickHouseDTO,
Expand All @@ -40,6 +42,7 @@


class DataSource(StrEnum):
athena = auto()
bigquery = auto()
canner = auto()
clickhouse = auto()
Expand Down Expand Up @@ -68,6 +71,7 @@ def get_dto_type(self):


class DataSourceExtension(Enum):
athena = QueryAthenaDTO
bigquery = QueryBigQueryDTO
canner = QueryCannerDTO
clickhouse = QueryClickHouseDTO
Expand Down Expand Up @@ -97,6 +101,16 @@ def get_connection(self, info: ConnectionInfo) -> BaseBackend:
except KeyError:
raise NotImplementedError(f"Unsupported data source: {self}")

@staticmethod
def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend:
    """Build an ibis Athena backend from the given connection info."""
    # Unwrap every SecretStr once, then hand the plain values to ibis.
    connect_args = {
        "s3_staging_dir": info.s3_staging_dir.get_secret_value(),
        "aws_access_key_id": info.aws_access_key_id.get_secret_value(),
        "aws_secret_access_key": info.aws_secret_access_key.get_secret_value(),
        "region_name": info.region_name.get_secret_value(),
        "schema_name": info.schema_name.get_secret_value(),
    }
    return ibis.athena.connect(**connect_args)

@staticmethod
def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend:
credits_json = loads(
Expand Down
138 changes: 138 additions & 0 deletions ibis-server/app/model/metadata/athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import re
from contextlib import closing

import pandas as pd

from app.model import AthenaConnectionInfo
from app.model.data_source import DataSource
from app.model.metadata.dto import (
Column,
Constraint,
RustWrenEngineColumnType,
Table,
TableProperties,
)
from app.model.metadata.metadata import Metadata


class AthenaMetadata(Metadata):
    """Metadata reader for AWS Athena.

    Discovers tables and columns through ``information_schema`` and maps
    Athena data types onto ``RustWrenEngineColumnType``.
    """

    def __init__(self, connection_info: AthenaConnectionInfo):
        super().__init__(connection_info)
        # Reuse the shared factory so connection handling stays in one place.
        self.connection = DataSource.athena.get_connection(connection_info)

    def get_table_list(self) -> list[Table]:
        """Return all tables in the configured schema, each with its columns."""
        schema_name = self.connection_info.schema_name.get_secret_value()
        # The schema name is interpolated into a quoted SQL literal below;
        # double any single quotes so an embedded quote cannot break the
        # statement or inject extra SQL.
        escaped_schema = schema_name.replace("'", "''")

        sql = f"""
        SELECT
            t.table_catalog,
            t.table_schema,
            t.table_name,
            c.column_name,
            c.ordinal_position,
            c.is_nullable,
            c.data_type
        FROM
            information_schema.tables AS t
        JOIN
            information_schema.columns AS c
            ON t.table_catalog = c.table_catalog
            AND t.table_schema = c.table_schema
            AND t.table_name = c.table_name
        WHERE t.table_schema = '{escaped_schema}'
        ORDER BY t.table_name
        """

        # We need to use raw_sql here because using the sql method causes Ibis
        # to *create view* first, which does not work with information_schema
        # queries.
        with closing(self.connection.raw_sql(sql)) as cursor:
            response = pd.DataFrame(
                cursor.fetchall(), columns=[col[0] for col in cursor.description]
            ).to_dict(orient="records")

        def get_column(row) -> Column:
            return Column(
                name=row["column_name"],
                type=self._transform_column_type(row["data_type"]),
                notNull=row["is_nullable"].lower() == "no",
                description="",  # Athena's information_schema exposes no column comments
                properties=None,
            )

        def get_table(row) -> Table:
            return Table(
                name=self._format_athena_compact_table_name(
                    row["table_schema"], row["table_name"]
                ),
                description="",  # Athena's information_schema exposes no table comments
                columns=[],
                properties=TableProperties(
                    schema=row["table_schema"],
                    catalog=row["table_catalog"],
                    table=row["table_name"],
                ),
                primaryKey="",
            )

        # The query yields one row per column; fold them into one Table per
        # unique "<schema>.<table>" name.
        unique_tables = {}
        for column_metadata in response:
            table_name = self._format_athena_compact_table_name(
                column_metadata["table_schema"], column_metadata["table_name"]
            )
            if table_name not in unique_tables:
                unique_tables[table_name] = get_table(column_metadata)
            unique_tables[table_name].columns.append(get_column(column_metadata))

        return list(unique_tables.values())

    def get_constraints(self) -> list[Constraint]:
        # Athena has no enforced foreign-key constraints to report.
        return []

    def get_version(self) -> str:
        # Athena is a managed service with no queryable engine version.
        return "AWS Athena - Follow AWS service versioning"

    def _format_athena_compact_table_name(self, schema: str, table: str) -> str:
        """Build the unique "<schema>.<table>" display name."""
        return f"{schema}.{table}"

    def _transform_column_type(self, data_type) -> RustWrenEngineColumnType:
        """Map an ``information_schema`` data type onto the engine's enum.

        Parameterized forms such as ``decimal(10, 2)`` or ``varchar(255)``
        are normalized by stripping the parenthesized part first; unmapped
        types fall back to ``UNKNOWN``.
        """
        data_type = re.sub(r"\(.*\)", "", data_type).strip()
        switcher = {
            # String Types (ignore Binary and Spatial Types for now)
            "char": RustWrenEngineColumnType.CHAR,
            "varchar": RustWrenEngineColumnType.VARCHAR,
            "string": RustWrenEngineColumnType.VARCHAR,  # Hive/Athena DDL alias
            "tinytext": RustWrenEngineColumnType.TEXT,
            "text": RustWrenEngineColumnType.TEXT,
            "mediumtext": RustWrenEngineColumnType.TEXT,
            "longtext": RustWrenEngineColumnType.TEXT,
            "enum": RustWrenEngineColumnType.VARCHAR,
            "set": RustWrenEngineColumnType.VARCHAR,
            # Integer types
            "bit": RustWrenEngineColumnType.TINYINT,
            "tinyint": RustWrenEngineColumnType.TINYINT,
            "smallint": RustWrenEngineColumnType.SMALLINT,
            "mediumint": RustWrenEngineColumnType.INTEGER,
            "int": RustWrenEngineColumnType.INTEGER,
            "integer": RustWrenEngineColumnType.INTEGER,
            "bigint": RustWrenEngineColumnType.BIGINT,
            # boolean
            "bool": RustWrenEngineColumnType.BOOL,
            "boolean": RustWrenEngineColumnType.BOOL,
            # Decimal
            "float": RustWrenEngineColumnType.FLOAT4,
            "double": RustWrenEngineColumnType.DOUBLE,
            "decimal": RustWrenEngineColumnType.DECIMAL,
            "numeric": RustWrenEngineColumnType.NUMERIC,
            # Date / time
            "date": RustWrenEngineColumnType.DATE,
            "datetime": RustWrenEngineColumnType.TIMESTAMP,
            # NOTE(review): Athena timestamps carry no zone; mapping to
            # TIMESTAMPTZ mirrors the original behavior — confirm intent.
            "timestamp": RustWrenEngineColumnType.TIMESTAMPTZ,
            # JSON Type
            "json": RustWrenEngineColumnType.JSON,
        }

        return switcher.get(data_type.lower(), RustWrenEngineColumnType.UNKNOWN)
2 changes: 2 additions & 0 deletions ibis-server/app/model/metadata/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from app.model.data_source import DataSource
from app.model.metadata.athena import AthenaMetadata
from app.model.metadata.bigquery import BigQueryMetadata
from app.model.metadata.canner import CannerMetadata
from app.model.metadata.clickhouse import ClickHouseMetadata
Expand All @@ -17,6 +18,7 @@
from app.model.metadata.trino import TrinoMetadata

mapping = {
DataSource.athena: AthenaMetadata,
DataSource.bigquery: BigQueryMetadata,
DataSource.canner: CannerMetadata,
DataSource.clickhouse: ClickHouseMetadata,
Expand Down
Loading
Loading