Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions ibis-server/app/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,55 @@ class AthenaConnectionInfo(BaseConnectionInfo):
description="S3 staging directory for Athena queries",
examples=["s3://my-bucket/athena-staging/"],
)
aws_access_key_id: SecretStr = Field(
description="AWS access key ID", examples=["AKIA..."]

# ── Standard AWS credential chain (optional) ─────────────
aws_access_key_id: SecretStr | None = Field(
description="AWS access key ID. Optional if using IAM role, web identity token, or default credential chain.",
examples=["AKIA..."],
default=None,
)
aws_secret_access_key: SecretStr | None = Field(
description="AWS secret access key. Optional if using IAM role, web identity token, or default credential chain.",
examples=["my-secret-key"],
default=None,
)
aws_session_token: SecretStr | None = Field(
description="AWS session token (used for temporary credentials)",
examples=["IQoJb3JpZ2luX2VjEJz//////////wEaCXVzLWVhc3QtMSJHMEUCIQD..."],
default=None,
)

# ── Web identity federation (OIDC/JWT-based) ─────────────
web_identity_token: SecretStr | None = Field(
description=(
"OIDC web identity token (JWT) used for AssumeRoleWithWebIdentity authentication. "
"If provided, PyAthena will call STS to exchange it for temporary credentials."
),
examples=["eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9..."],
default=None,
)
aws_secret_access_key: SecretStr = Field(
description="AWS secret access key", examples=["my-secret-key"]
role_arn: SecretStr | None = Field(
description="The ARN of the role to assume with the web identity token.",
examples=["arn:aws:iam::123456789012:role/YourAthenaRole"],
default=None,
)
role_session_name: SecretStr | None = Field(
description="The session name when assuming a role (optional).",
examples=["PyAthena-session"],
default=None,
)

# ── Regional and database settings ───────────────────────
region_name: SecretStr = Field(
description="AWS region for Athena", examples=["us-west-2", "us-east-1"]
description="AWS region for Athena. Optional; will use default region if not provided.",
examples=["us-west-2", "us-east-1"],
default=None,
)
schema_name: SecretStr = Field(
schema_name: SecretStr | None = Field(
alias="schema_name",
description="The database name in Athena",
description="The database name in Athena. Defaults to 'default'.",
examples=["default"],
default=SecretStr("default"),
)


Expand Down
57 changes: 50 additions & 7 deletions ibis-server/app/model/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any
from urllib.parse import unquote_plus

import boto3
import ibis
from google.cloud import bigquery
from google.oauth2 import service_account
Expand Down Expand Up @@ -254,13 +255,55 @@ def get_connection(self, info: ConnectionInfo) -> BaseBackend:

@staticmethod
def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend:
    """Open an Athena backend using the credentials described by ``info``.

    Credentials are selected in this order:

    1. Web identity federation - when both ``web_identity_token`` and
       ``role_arn`` are set, the token is exchanged through STS
       ``AssumeRoleWithWebIdentity`` and the returned temporary
       credentials are passed to the connection.
    2. Static keys - when both ``aws_access_key_id`` and
       ``aws_secret_access_key`` are set; ``aws_session_token`` is
       forwarded as well when present.
    3. Otherwise no credential arguments are passed, leaving credential
       resolution to the underlying client (environment variables,
       ``~/.aws/credentials``, IAM role).

    Args:
        info: Athena connection settings. ``s3_staging_dir`` is required;
            every other field may be unset.

    Returns:
        The backend returned by ``ibis.athena.connect``.
    """
    # schema_name is declared ``SecretStr | None``; an explicit null must not
    # crash on ``.get_secret_value()``. Fall back to the field's own default.
    schema_name = (
        info.schema_name.get_secret_value()
        if info.schema_name is not None
        else "default"
    )
    kwargs: dict[str, Any] = {
        "s3_staging_dir": info.s3_staging_dir.get_secret_value(),
        "schema_name": schema_name,
    }

    # -- Region ---------------------------------------------------------
    # Resolved once and shared by the STS client and the Athena connection.
    region = info.region_name.get_secret_value() if info.region_name else None
    if region is not None:
        kwargs["region_name"] = region

    # -- 1. Web identity token flow (OIDC token -> AWS STS) --------------
    if info.web_identity_token and info.role_arn:
        session_name = (
            info.role_session_name.get_secret_value()
            if info.role_session_name
            else "wren-oidc-session"
        )
        sts = boto3.client("sts", region_name=region)
        resp = sts.assume_role_with_web_identity(
            RoleArn=info.role_arn.get_secret_value(),
            RoleSessionName=session_name,
            WebIdentityToken=info.web_identity_token.get_secret_value(),
        )

        # NOTE(review): these are temporary credentials; presumably a
        # long-lived connection stops working once they expire - confirm.
        creds = resp["Credentials"]
        kwargs["aws_access_key_id"] = creds["AccessKeyId"]
        kwargs["aws_secret_access_key"] = creds["SecretAccessKey"]
        kwargs["aws_session_token"] = creds["SessionToken"]

    # -- 2. Standard access/secret keys ----------------------------------
    elif info.aws_access_key_id and info.aws_secret_access_key:
        kwargs["aws_access_key_id"] = info.aws_access_key_id.get_secret_value()
        kwargs["aws_secret_access_key"] = (
            info.aws_secret_access_key.get_secret_value()
        )
        if info.aws_session_token:
            kwargs["aws_session_token"] = info.aws_session_token.get_secret_value()

    # -- 3. Default AWS credential chain ---------------------------------
    # Nothing to add: with no credential kwargs the underlying client falls
    # back to environment variables, ~/.aws/credentials, or an IAM role
    # (EC2, ECS, Lambda).

    return ibis.athena.connect(**kwargs)

@staticmethod
def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend:
Expand Down
6 changes: 3 additions & 3 deletions ibis-server/tests/routers/v2/connector/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
},
{
"name": "timestamptz",
"expression": "TIMESTAMP '2024-01-01 23:59:59 UTC'",
"expression": "CAST(TIMESTAMP '2024-01-01 23:59:59 UTC' AS timestamp)",
"type": "timestamp",
},
{
Expand Down Expand Up @@ -113,7 +113,7 @@ async def test_query(client, manifest_str):
"orderkey": "int64",
"custkey": "int64",
"orderstatus": "string",
"totalprice": "decimal128(15, 2)",
"totalprice": "decimal128(38, 9)",
"orderdate": "date32[day]",
"order_cust_key": "string",
"timestamp": "timestamp[us]",
Expand Down Expand Up @@ -153,7 +153,7 @@ async def test_query_glue_database(client, manifest_str):
"orderkey": "int64",
"custkey": "int64",
"orderstatus": "string",
"totalprice": "decimal128(15, 2)",
"totalprice": "decimal128(38, 9)",
"orderdate": "date32[day]",
"order_cust_key": "string",
"timestamp": "timestamp[us]",
Expand Down
33 changes: 33 additions & 0 deletions ibis-server/tests/routers/v3/connector/athena/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,39 @@ def connection_info():
}


@pytest.fixture(scope="session")
def connection_info_default_credential_chain():
    """Connection info without explicit credentials.

    The returned payload carries only the staging directory, region and
    schema, so authentication has to come from the default credential
    chain. The fixture skips unless both AWS key variables are present in
    the environment.
    """
    env_credentials = (
        os.getenv("AWS_ACCESS_KEY_ID"),
        os.getenv("AWS_SECRET_ACCESS_KEY"),
    )
    if not all(env_credentials):
        pytest.skip(
            "Skipping default credential chain test: AWS credentials not set in environment"
        )

    staging_dir = os.getenv("TEST_ATHENA_S3_STAGING_DIR")
    region = os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1")
    return {
        "s3_staging_dir": staging_dir,
        "region_name": region,
        "schema_name": "test",
    }


@pytest.fixture(scope="session")
def connection_info_oidc():
    """Connection info for the web-identity (OIDC) authentication flow.

    Skips unless both the web identity token and the role ARN are supplied
    through the ``TEST_ATHENA_WEB_IDENTITY_TOKEN`` and
    ``TEST_ATHENA_ROLE_ARN`` environment variables.
    """
    token = os.getenv("TEST_ATHENA_WEB_IDENTITY_TOKEN")
    arn = os.getenv("TEST_ATHENA_ROLE_ARN")

    if not (token and arn):
        pytest.skip("Skipping OIDC test: web identity token or role ARN not set")

    info = {
        "s3_staging_dir": os.getenv("TEST_ATHENA_OIDC_S3_STAGING_DIR"),
        "region_name": os.getenv("TEST_ATHENA_OIDC_REGION_NAME", "us-west-1"),
        "schema_name": "test",
    }
    info["role_arn"] = arn
    info["web_identity_token"] = token
    return info


@pytest.fixture(autouse=True)
def set_remote_function_list_path():
config = get_config()
Expand Down
51 changes: 50 additions & 1 deletion ibis-server/tests/routers/v3/connector/athena/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def test_query(client, manifest_str, connection_info):
"orderkey": "int64",
"custkey": "int64",
"orderstatus": "string",
"totalprice": "decimal128(15, 2)",
"totalprice": "decimal128(38, 9)",
"orderdate": "date32[day]",
"order_cust_key": "string",
"timestamp": "timestamp[us]",
Expand Down Expand Up @@ -211,3 +211,52 @@ async def test_query_with_dry_run_and_invalid_sql(
)
assert response.status_code == 422
assert response.text is not None


@pytest.mark.parametrize(
    "conn_fixture",
    [
        "connection_info",
        "connection_info_default_credential_chain",
        "connection_info_oidc",
    ],
)
async def test_query_athena_modes(client, manifest_str, request, conn_fixture):
    """Run the same single-row query once per connection-info fixture.

    Each parametrized case resolves its fixture by name, so a fixture that
    skips (missing environment configuration) skips only that case.
    """
    expected_row = [
        1,
        36901,
        "O",
        "173665.47",
        "1996-01-02",
        "1_36901",
        "2024-01-01 23:59:59.000000",
        "2024-01-01 23:59:59.000000",
        None,
    ]
    expected_dtypes = {
        "orderkey": "int64",
        "custkey": "int64",
        "orderstatus": "string",
        "totalprice": "decimal128(38, 9)",
        "orderdate": "date32[day]",
        "order_cust_key": "string",
        "timestamp": "timestamp[us]",
        "timestamptz": "timestamp[us]",
        "test_null_time": "timestamp[us]",
    }

    payload = {
        "connectionInfo": request.getfixturevalue(conn_fixture),
        "manifestStr": manifest_str,
        "sql": "SELECT * FROM wren.public.orders LIMIT 1",
    }
    resp = await client.post(url="/v3/connector/athena/query", json=payload)

    assert resp.status_code == 200
    body = resp.json()

    # One result column per column declared on the first manifest model.
    assert len(body["columns"]) == len(manifest["models"][0]["columns"])
    assert len(body["data"]) == 1
    assert body["data"][0] == expected_row
    assert body["dtypes"] == expected_dtypes