diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index c8f62d06b..bc900501f 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -118,19 +118,55 @@ class AthenaConnectionInfo(BaseConnectionInfo): description="S3 staging directory for Athena queries", examples=["s3://my-bucket/athena-staging/"], ) - aws_access_key_id: SecretStr = Field( - description="AWS access key ID", examples=["AKIA..."] + + # ── Standard AWS credential chain (optional) ───────────── + aws_access_key_id: SecretStr | None = Field( + description="AWS access key ID. Optional if using IAM role, web identity token, or default credential chain.", + examples=["AKIA..."], + default=None, + ) + aws_secret_access_key: SecretStr | None = Field( + description="AWS secret access key. Optional if using IAM role, web identity token, or default credential chain.", + examples=["my-secret-key"], + default=None, + ) + aws_session_token: SecretStr | None = Field( + description="AWS session token (used for temporary credentials)", + examples=["IQoJb3JpZ2luX2VjEJz//////////wEaCXVzLWVhc3QtMSJHMEUCIQD..."], + default=None, + ) + + # ── Web identity federation (OIDC/JWT-based) ───────────── + web_identity_token: SecretStr | None = Field( + description=( + "OIDC web identity token (JWT) used for AssumeRoleWithWebIdentity authentication. " + "If provided, PyAthena will call STS to exchange it for temporary credentials." 
+ ), + examples=["eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9..."], + default=None, ) - aws_secret_access_key: SecretStr = Field( - description="AWS secret access key", examples=["my-secret-key"] + role_arn: SecretStr | None = Field( + description="The ARN of the role to assume with the web identity token.", + examples=["arn:aws:iam::123456789012:role/YourAthenaRole"], + default=None, ) + role_session_name: SecretStr | None = Field( + description="The session name when assuming a role (optional).", + examples=["PyAthena-session"], + default=None, + ) + + # ── Regional and database settings ─────────────────────── region_name: SecretStr = Field( - description="AWS region for Athena", examples=["us-west-2", "us-east-1"] + description="AWS region for Athena. Optional; will use default region if not provided.", + examples=["us-west-2", "us-east-1"], + default=None, ) - schema_name: SecretStr = Field( + schema_name: SecretStr | None = Field( alias="schema_name", - description="The database name in Athena", + description="The database name in Athena. 
@staticmethod
def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend:
    """Open an Ibis Athena backend using the credentials described by *info*.

    Credential sources are considered in this order:

    1. Web identity federation: ``web_identity_token`` and ``role_arn`` are
       exchanged through STS ``AssumeRoleWithWebIdentity`` for temporary
       credentials.
    2. Static keys: ``aws_access_key_id`` and ``aws_secret_access_key``,
       plus ``aws_session_token`` when one is supplied.
    3. Nothing supplied: no credential arguments are passed, leaving
       resolution to the driver's default AWS credential chain
       (environment variables, ``~/.aws/credentials``, IAM role).

    Args:
        info: Athena connection settings. Every credential field is
            optional; ``s3_staging_dir`` is always read.

    Returns:
        The backend returned by ``ibis.athena.connect``.

    Raises:
        ValueError: If only one half of a credential pair is supplied
            (token without role ARN or vice versa, access key without
            secret key or vice versa).
    """
    token = info.web_identity_token
    role = info.role_arn
    access_key = info.aws_access_key_id
    secret_key = info.aws_secret_access_key
    use_web_identity = bool(token) and bool(role)

    # A half-configured pair must not silently fall through to the ambient
    # credential chain: the query would run under a different identity than
    # the one the caller asked for.
    if bool(token) != bool(role):
        raise ValueError(
            "web_identity_token and role_arn must be provided together"
        )
    if not use_web_identity and bool(access_key) != bool(secret_key):
        raise ValueError(
            "aws_access_key_id and aws_secret_access_key must be provided together"
        )

    # The model declares schema_name as optional, so an explicit null has to
    # map to the documented default instead of raising AttributeError.
    schema_name = (
        info.schema_name.get_secret_value()
        if info.schema_name is not None
        else "default"
    )
    kwargs: dict[str, Any] = {
        "s3_staging_dir": info.s3_staging_dir.get_secret_value(),
        "schema_name": schema_name,
    }

    # Region: unwrap once and reuse for both the STS client and Athena.
    region = info.region_name.get_secret_value() if info.region_name else None
    if region is not None:
        kwargs["region_name"] = region

    if use_web_identity:
        # 1. Web identity token flow (OIDC token -> AWS STS).
        session_name = (
            info.role_session_name.get_secret_value()
            if info.role_session_name
            else "wren-oidc-session"
        )
        sts = boto3.client("sts", region_name=region)
        resp = sts.assume_role_with_web_identity(
            RoleArn=role.get_secret_value(),
            RoleSessionName=session_name,
            WebIdentityToken=token.get_secret_value(),
        )
        creds = resp["Credentials"]
        kwargs["aws_access_key_id"] = creds["AccessKeyId"]
        kwargs["aws_secret_access_key"] = creds["SecretAccessKey"]
        kwargs["aws_session_token"] = creds["SessionToken"]
    elif access_key and secret_key:
        # 2. Standard access/secret keys, optionally with a session token.
        kwargs["aws_access_key_id"] = access_key.get_secret_value()
        kwargs["aws_secret_access_key"] = secret_key.get_secret_value()
        if info.aws_session_token:
            kwargs["aws_session_token"] = info.aws_session_token.get_secret_value()
    # 3. Otherwise: default AWS credential chain, nothing to add here.

    return ibis.athena.connect(**kwargs)
@pytest.fixture(scope="session")
def connection_info_default_credential_chain():
    """Connection info without explicit AWS keys.

    The test is skipped unless both ``AWS_ACCESS_KEY_ID`` and
    ``AWS_SECRET_ACCESS_KEY`` are present in the environment, so that the
    server has ambient credentials to pick up.
    """
    required_env = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")
    if not all(os.getenv(name) for name in required_env):
        pytest.skip(
            "Skipping default credential chain test: AWS credentials not set in environment"
        )

    staging_dir = os.getenv("TEST_ATHENA_S3_STAGING_DIR")
    region = os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1")
    return {
        "s3_staging_dir": staging_dir,
        "region_name": region,
        "schema_name": "test",
    }


@pytest.fixture(scope="session")
def connection_info_oidc():
    """Connection info for the web-identity (OIDC) flow.

    The test is skipped unless both ``TEST_ATHENA_WEB_IDENTITY_TOKEN`` and
    ``TEST_ATHENA_ROLE_ARN`` are set in the environment.
    """
    token = os.getenv("TEST_ATHENA_WEB_IDENTITY_TOKEN")
    arn = os.getenv("TEST_ATHENA_ROLE_ARN")
    if not (token and arn):
        pytest.skip("Skipping OIDC test: web identity token or role ARN not set")

    oidc_info = {
        "s3_staging_dir": os.getenv("TEST_ATHENA_OIDC_S3_STAGING_DIR"),
        "region_name": os.getenv("TEST_ATHENA_OIDC_REGION_NAME", "us-west-1"),
        "schema_name": "test",
    }
    oidc_info["role_arn"] = arn
    oidc_info["web_identity_token"] = token
    return oidc_info
@pytest.mark.parametrize(
    "conn_fixture",
    [
        "connection_info",
        "connection_info_default_credential_chain",
        "connection_info_oidc",
    ],
)
async def test_query_athena_modes(client, manifest_str, request, conn_fixture):
    """Run the same one-row query under each supported Athena auth mode.

    The connection-info fixture is resolved by name, so a fixture that calls
    ``pytest.skip`` skips only its own parametrized case.
    """
    payload = {
        "connectionInfo": request.getfixturevalue(conn_fixture),
        "manifestStr": manifest_str,
        "sql": "SELECT * FROM wren.public.orders LIMIT 1",
    }

    response = await client.post(url="/v3/connector/athena/query", json=payload)
    assert response.status_code == 200

    result = response.json()
    expected_column_count = len(manifest["models"][0]["columns"])
    assert len(result["columns"]) == expected_column_count
    assert len(result["data"]) == 1

    expected_row = [
        1,
        36901,
        "O",
        "173665.47",
        "1996-01-02",
        "1_36901",
        "2024-01-01 23:59:59.000000",
        "2024-01-01 23:59:59.000000",
        None,
    ]
    assert result["data"][0] == expected_row

    expected_dtypes = {
        "orderkey": "int64",
        "custkey": "int64",
        "orderstatus": "string",
        "totalprice": "decimal128(38, 9)",
        "orderdate": "date32[day]",
        "order_cust_key": "string",
        "timestamp": "timestamp[us]",
        "timestamptz": "timestamp[us]",
        "test_null_time": "timestamp[us]",
    }
    assert result["dtypes"] == expected_dtypes