Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/ibis-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,11 @@ jobs:
SNOWFLAKE_PASSWORD: ${{ secrets.SNOWFLAKE_PASSWORD }}
SNOWFLAKE_ACCOUNT: ${{ secrets.SNOWFLAKE_ACCOUNT }}
run: just test snowflake
- name: Test athena if need
if: contains(github.event.pull_request.labels.*.name, 'athena')
env:
WREN_ENGINE_ENDPOINT: http://localhost:8080
TEST_ATHENA_S3_STAGING_DIR: ${{ secrets.TEST_ATHENA_S3_STAGING_DIR }}
TEST_ATHENA_AWS_ACCESS_KEY_ID: ${{ secrets.TEST_ATHENA_AWS_ACCESS_KEY_ID }}
TEST_ATHENA_AWS_SECRET_ACCESS_KEY: ${{ secrets.TEST_ATHENA_AWS_SECRET_ACCESS_KEY }}
run: just test athena
28 changes: 27 additions & 1 deletion ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class QueryBigQueryDTO(QueryDTO):
connection_info: BigQueryConnectionInfo = connection_info_field


class QueryAthenaDTO(QueryDTO):
connection_info: AthenaConnectionInfo = connection_info_field


class QueryCannerDTO(QueryDTO):
connection_info: ConnectionUrl | CannerConnectionInfo = connection_info_field

Expand Down Expand Up @@ -98,6 +102,27 @@ class BigQueryConnectionInfo(BaseConnectionInfo):
)


class AthenaConnectionInfo(BaseConnectionInfo):
s3_staging_dir: SecretStr = Field(
description="S3 staging directory for Athena queries",
examples=["s3://my-bucket/athena-staging/"],
)
aws_access_key_id: SecretStr = Field(
description="AWS access key ID", examples=["AKIA..."]
)
aws_secret_access_key: SecretStr = Field(
description="AWS secret access key", examples=["my-secret-key"]
)
region_name: SecretStr = Field(
description="AWS region for Athena", examples=["us-west-2", "us-east-1"]
)
schema_name: SecretStr = Field(
alias="schema_name",
description="The database name in Athena",
examples=["default"],
)


class CannerConnectionInfo(BaseConnectionInfo):
host: SecretStr = Field(
description="the hostname of your database", examples=["localhost"]
Expand Down Expand Up @@ -339,7 +364,8 @@ class GcsFileConnectionInfo(BaseConnectionInfo):


ConnectionInfo = (
BigQueryConnectionInfo
AthenaConnectionInfo
| BigQueryConnectionInfo
| CannerConnectionInfo
| ConnectionUrl
| MSSqlConnectionInfo
Expand Down
14 changes: 14 additions & 0 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ibis import BaseBackend

from app.model import (
AthenaConnectionInfo,
BigQueryConnectionInfo,
CannerConnectionInfo,
ClickHouseConnectionInfo,
Expand All @@ -19,6 +20,7 @@
MySqlConnectionInfo,
OracleConnectionInfo,
PostgresConnectionInfo,
QueryAthenaDTO,
QueryBigQueryDTO,
QueryCannerDTO,
QueryClickHouseDTO,
Expand All @@ -40,6 +42,7 @@


class DataSource(StrEnum):
athena = auto()
bigquery = auto()
canner = auto()
clickhouse = auto()
Expand Down Expand Up @@ -68,6 +71,7 @@ def get_dto_type(self):


class DataSourceExtension(Enum):
athena = QueryAthenaDTO
bigquery = QueryBigQueryDTO
canner = QueryCannerDTO
clickhouse = QueryClickHouseDTO
Expand Down Expand Up @@ -97,6 +101,16 @@ def get_connection(self, info: ConnectionInfo) -> BaseBackend:
except KeyError:
raise NotImplementedError(f"Unsupported data source: {self}")

@staticmethod
def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend:
return ibis.athena.connect(
s3_staging_dir=info.s3_staging_dir.get_secret_value(),
aws_access_key_id=info.aws_access_key_id.get_secret_value(),
aws_secret_access_key=info.aws_secret_access_key.get_secret_value(),
region_name=info.region_name.get_secret_value(),
schema_name=info.schema_name.get_secret_value(),
)

@staticmethod
def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend:
credits_json = loads(
Expand Down
139 changes: 139 additions & 0 deletions ibis-server/app/model/metadata/athena.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import re

import pandas as pd

from app.model import AthenaConnectionInfo
from app.model.data_source import DataSource
from app.model.metadata.dto import (
Column,
Constraint,
RustWrenEngineColumnType,
Table,
TableProperties,
)
from app.model.metadata.metadata import Metadata


class AthenaMetadata(Metadata):
def __init__(self, connection_info: AthenaConnectionInfo):
super().__init__(connection_info)
self.connection = DataSource.athena.get_connection(connection_info)

def get_table_list(self) -> list[Table]:
schema_name = self.connection_info.schema_name.get_secret_value()

sql = f"""
SELECT
t.table_catalog,
t.table_schema,
t.table_name,
c.column_name,
c.ordinal_position,
c.is_nullable,
c.data_type
FROM
information_schema.tables AS t
JOIN
information_schema.columns AS c
ON t.table_catalog = c.table_catalog
AND t.table_schema = c.table_schema
AND t.table_name = c.table_name
WHERE t.table_schema = '{schema_name}'
ORDER BY t.table_name
"""
Comment thread
douenergy marked this conversation as resolved.

# We need to use raw_sql here because using the sql method causes Ibis to *create view* first,
# which does not work with information_schema queries.
cursor = self.connection.raw_sql(sql)
response = pd.DataFrame(
cursor.fetchall(), columns=[col[0] for col in cursor.description]
).to_dict(orient="records")
Comment thread
douenergy marked this conversation as resolved.
Outdated

def get_column(row) -> Column:
return Column(
name=row["column_name"],
type=self._transform_column_type(row["data_type"]),
notNull=row["is_nullable"].lower() == "no",
description="", # Athena doesn't provide column descriptions in information_schema
properties=None,
)

def get_table(row) -> Table:
return Table(
name=self._format_athena_compact_table_name(
row["table_schema"], row["table_name"]
),
description="", # Athena doesn't provide table descriptions in information_schema
columns=[],
properties=TableProperties(
schema=row["table_schema"],
catalog=row["table_catalog"],
table=row["table_name"],
),
primaryKey="",
)

unique_tables = {}

for column_metadata in response:
# generate unique table name
table_name = self._format_athena_compact_table_name(
column_metadata["table_schema"], column_metadata["table_name"]
)
# init table if not exists
if table_name not in unique_tables:
unique_tables[table_name] = get_table(column_metadata)

current_table = unique_tables[table_name]
# add column to table
current_table.columns.append(get_column(column_metadata))

return list(unique_tables.values())

def get_constraints(self) -> list[Constraint]:
# Athena doesn't support foreign key constraints
return []

def get_version(self) -> str:
return "AWS Athena - Follow AWS service versioning"

def _format_athena_compact_table_name(self, schema: str, table: str) -> str:
return f"{schema}.{table}"

def _transform_column_type(self, data_type):
data_type = re.sub(r"\(.*\)", "", data_type).strip()
switcher = {
# String Types (ignore Binary and Spatial Types for now)
"char": RustWrenEngineColumnType.CHAR,
"varchar": RustWrenEngineColumnType.VARCHAR,
"tinytext": RustWrenEngineColumnType.TEXT,
"text": RustWrenEngineColumnType.TEXT,
"mediumtext": RustWrenEngineColumnType.TEXT,
"longtext": RustWrenEngineColumnType.TEXT,
"enum": RustWrenEngineColumnType.VARCHAR,
"set": RustWrenEngineColumnType.VARCHAR,
# Numeric Types(https://dev.mysql.com/doc/refman/8.4/en/numeric-types.html)
Comment thread
douenergy marked this conversation as resolved.
Outdated
"bit": RustWrenEngineColumnType.TINYINT,
"tinyint": RustWrenEngineColumnType.TINYINT,
"smallint": RustWrenEngineColumnType.SMALLINT,
"mediumint": RustWrenEngineColumnType.INTEGER,
"int": RustWrenEngineColumnType.INTEGER,
"integer": RustWrenEngineColumnType.INTEGER,
"bigint": RustWrenEngineColumnType.BIGINT,
# boolean
"bool": RustWrenEngineColumnType.BOOL,
"boolean": RustWrenEngineColumnType.BOOL,
# Decimal
"float": RustWrenEngineColumnType.FLOAT4,
"double": RustWrenEngineColumnType.DOUBLE,
"decimal": RustWrenEngineColumnType.DECIMAL,
"numeric": RustWrenEngineColumnType.NUMERIC,
# Date and Time Types(https://dev.mysql.com/doc/refman/8.4/en/date-and-time-types.html)
"date": RustWrenEngineColumnType.DATE,
"datetime": RustWrenEngineColumnType.TIMESTAMP,
"timestamp": RustWrenEngineColumnType.TIMESTAMPTZ,
# JSON Type
"json": RustWrenEngineColumnType.JSON,
}

return switcher.get(data_type.lower(), RustWrenEngineColumnType.UNKNOWN)
Comment thread
douenergy marked this conversation as resolved.
2 changes: 2 additions & 0 deletions ibis-server/app/model/metadata/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from app.model.data_source import DataSource
from app.model.metadata.athena import AthenaMetadata
from app.model.metadata.bigquery import BigQueryMetadata
from app.model.metadata.canner import CannerMetadata
from app.model.metadata.clickhouse import ClickHouseMetadata
Expand All @@ -17,6 +18,7 @@
from app.model.metadata.trino import TrinoMetadata

mapping = {
DataSource.athena: AthenaMetadata,
DataSource.bigquery: BigQueryMetadata,
DataSource.canner: CannerMetadata,
DataSource.clickhouse: ClickHouseMetadata,
Expand Down
Loading
Loading