From 86ed8f848e93cb0293d04c3e5cb2d54f0b8b78ab Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Sun, 2 Nov 2025 09:56:41 +0800 Subject: [PATCH 1/5] default credential chain --- ibis-server/app/model/__init__.py | 52 ++++++++++++++++--- ibis-server/app/model/data_source.py | 48 ++++++++++++++--- .../routers/v3/connector/athena/conftest.py | 17 ++++-- .../routers/v3/connector/athena/test_query.py | 20 +++++++ 4 files changed, 119 insertions(+), 18 deletions(-) diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index c8f62d06b..645142c35 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -118,19 +118,55 @@ class AthenaConnectionInfo(BaseConnectionInfo): description="S3 staging directory for Athena queries", examples=["s3://my-bucket/athena-staging/"], ) - aws_access_key_id: SecretStr = Field( - description="AWS access key ID", examples=["AKIA..."] + + # ── Standard AWS credential chain (optional) ───────────── + aws_access_key_id: SecretStr | None = Field( + description="AWS access key ID. Optional if using IAM role, web identity token, or default credential chain.", + examples=["AKIA..."], + default=None, + ) + aws_secret_access_key: SecretStr | None = Field( + description="AWS secret access key. Optional if using IAM role, web identity token, or default credential chain.", + examples=["my-secret-key"], + default=None, + ) + aws_session_token: SecretStr | None = Field( + description="AWS session token (used for temporary credentials)", + examples=["IQoJb3JpZ2luX2VjEJz//////////wEaCXVzLWVhc3QtMSJHMEUCIQD..."], + default=None, + ) + + # ── Web identity federation (OIDC/JWT-based) ───────────── + web_identity_token: SecretStr | None = Field( + description=( + "OIDC web identity token (JWT) used for AssumeRoleWithWebIdentity authentication. " + "If provided, PyAthena will call STS to exchange it for temporary credentials."
+ ), + examples=["eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9..."], + default=None, ) - aws_secret_access_key: SecretStr = Field( - description="AWS secret access key", examples=["my-secret-key"] + role_arn: SecretStr | None = Field( + description="The ARN of the role to assume with the web identity token.", + examples=["arn:aws:iam::123456789012:role/YourAthenaRole"], + default=None, ) - region_name: SecretStr = Field( - description="AWS region for Athena", examples=["us-west-2", "us-east-1"] + role_session_name: SecretStr | None = Field( + description="The session name when assuming a role (optional).", + examples=["PyAthena-session"], + default=None, + ) + + # ── Regional and database settings ─────────────────────── + region_name: SecretStr | None = Field( + description="AWS region for Athena. Optional; will use default region if not provided.", + examples=["us-west-2", "us-east-1"], + default=None, ) - schema_name: SecretStr = Field( + schema_name: SecretStr | None = Field( alias="schema_name", - description="The database name in Athena", + description="The database name in Athena. 
Defaults to 'default'.", examples=["default"], + default=SecretStr("default"), ) diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index e26c20cf9..5fe88bd73 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -254,13 +254,47 @@ def get_connection(self, info: ConnectionInfo) -> BaseBackend: @staticmethod def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend: - return ibis.athena.connect( - s3_staging_dir=info.s3_staging_dir.get_secret_value(), - aws_access_key_id=info.aws_access_key_id.get_secret_value(), - aws_secret_access_key=info.aws_secret_access_key.get_secret_value(), - region_name=info.region_name.get_secret_value(), - schema_name=info.schema_name.get_secret_value(), - ) + kwargs: dict[str, Any] = { + "s3_staging_dir": info.s3_staging_dir.get_secret_value(), + "schema_name": info.schema_name.get_secret_value(), + } + + # ── Region ──────────────────────────────────────────────── + if info.region_name: + kwargs["region_name"] = info.region_name.get_secret_value() + + # ── Optional WorkGroup & Result Reuse ───────────────────── + if info.work_group: + kwargs["work_group"] = info.work_group.get_secret_value() + if info.result_reuse_enable is not None: + kwargs["result_reuse_enable"] = info.result_reuse_enable + if info.result_reuse_minutes is not None: + kwargs["result_reuse_minutes"] = info.result_reuse_minutes + + # ── 1️⃣ Web Identity Token flow (Google OIDC → AWS STS) ─── + if info.web_identity_token and info.role_arn: + kwargs["web_identity_token"] = info.web_identity_token.get_secret_value() + kwargs["role_arn"] = info.role_arn.get_secret_value() + if info.role_session_name: + kwargs["role_session_name"] = info.role_session_name.get_secret_value() + + # ── 2️⃣ Standard Access/Secret Keys ─────────────────────── + elif info.aws_access_key_id and info.aws_secret_access_key: + kwargs["aws_access_key_id"] = info.aws_access_key_id.get_secret_value() + 
kwargs["aws_secret_access_key"] = ( + info.aws_secret_access_key.get_secret_value() + ) + if info.aws_session_token: + kwargs["aws_session_token"] = info.aws_session_token.get_secret_value() + + # ── 3️⃣ Default AWS credential chain ─────────────────────── + # Nothing needed — PyAthena automatically falls back to: + # - Environment variables + # - ~/.aws/credentials + # - IAM Role (EC2, ECS, Lambda) + + # Now connect via Ibis wrapper + return ibis.athena.connect(**kwargs) @staticmethod def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend: diff --git a/ibis-server/tests/routers/v3/connector/athena/conftest.py b/ibis-server/tests/routers/v3/connector/athena/conftest.py index eeb5f7db1..923a0ee7c 100644 --- a/ibis-server/tests/routers/v3/connector/athena/conftest.py +++ b/ibis-server/tests/routers/v3/connector/athena/conftest.py @@ -21,13 +21,24 @@ def pytest_collection_modifyitems(items): @pytest.fixture(scope="session") -def connection_info(): +def connection_info_oidc(): + """Use web identity token (OIDC → AssumeRoleWithWebIdentity) authentication.""" + token_path = os.getenv("TEST_ATHENA_WEB_IDENTITY_TOKEN_PATH") + role_arn = os.getenv("TEST_ATHENA_ROLE_ARN") + + if not token_path or not role_arn: + pytest.skip("Skipping OIDC test: web identity token or role ARN not set") + + with open(token_path, encoding="utf-8") as f: + web_identity_token = f.read().strip() + return { "s3_staging_dir": os.getenv("TEST_ATHENA_S3_STAGING_DIR"), - "aws_access_key_id": os.getenv("TEST_ATHENA_AWS_ACCESS_KEY_ID"), - "aws_secret_access_key": os.getenv("TEST_ATHENA_AWS_SECRET_ACCESS_KEY"), "region_name": os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1"), "schema_name": "test", + "role_arn": role_arn, + "role_session_name": "pytest-session", + "web_identity_token": web_identity_token, } diff --git a/ibis-server/tests/routers/v3/connector/athena/test_query.py b/ibis-server/tests/routers/v3/connector/athena/test_query.py index 74a48a3a8..3565fc898 100644 --- 
a/ibis-server/tests/routers/v3/connector/athena/test_query.py +++ b/ibis-server/tests/routers/v3/connector/athena/test_query.py @@ -211,3 +211,23 @@ async def test_query_with_dry_run_and_invalid_sql( ) assert response.status_code == 422 assert response.text is not None + + +@pytest.mark.parametrize( + "conn_fixture", ["connection_info_static", "connection_info_oidc"] +) +async def test_query_athena_modes(client, manifest_str, request, conn_fixture): + connection_info = request.getfixturevalue(conn_fixture) + + response = await client.post( + url="/v3/connector/athena/query", + json={ + "connectionInfo": connection_info, + "manifestStr": manifest_str, + "sql": "SELECT 1", + }, + ) + assert response.status_code == 200 + result = response.json() + assert "columns" in result + assert "data" in result From a308f4e1391d7ef2b3428333c9f40325673705fe Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Tue, 4 Nov 2025 14:02:07 +0800 Subject: [PATCH 2/5] fix tests --- ibis-server/app/model/__init__.py | 2 +- ibis-server/app/model/data_source.py | 8 ------ .../tests/routers/v2/connector/test_athena.py | 6 ++-- .../routers/v3/connector/athena/conftest.py | 28 ++++++++++++++++++- .../routers/v3/connector/athena/test_query.py | 6 ++-- 5 files changed, 34 insertions(+), 16 deletions(-) diff --git a/ibis-server/app/model/__init__.py b/ibis-server/app/model/__init__.py index 645142c35..bc900501f 100644 --- a/ibis-server/app/model/__init__.py +++ b/ibis-server/app/model/__init__.py @@ -157,7 +157,7 @@ class AthenaConnectionInfo(BaseConnectionInfo): ) # ── Regional and database settings ─────────────────────── - region_name: SecretStr | None = Field( + region_name: SecretStr = Field( description="AWS region for Athena. 
Optional; will use default region if not provided.", examples=["us-west-2", "us-east-1"], default=None, diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index 5fe88bd73..0785473e7 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -263,14 +263,6 @@ def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend: if info.region_name: kwargs["region_name"] = info.region_name.get_secret_value() - # ── Optional WorkGroup & Result Reuse ───────────────────── - if info.work_group: - kwargs["work_group"] = info.work_group.get_secret_value() - if info.result_reuse_enable is not None: - kwargs["result_reuse_enable"] = info.result_reuse_enable - if info.result_reuse_minutes is not None: - kwargs["result_reuse_minutes"] = info.result_reuse_minutes - # ── 1️⃣ Web Identity Token flow (Google OIDC → AWS STS) ─── if info.web_identity_token and info.role_arn: kwargs["web_identity_token"] = info.web_identity_token.get_secret_value() diff --git a/ibis-server/tests/routers/v2/connector/test_athena.py b/ibis-server/tests/routers/v2/connector/test_athena.py index 027776a54..7cf6d002b 100644 --- a/ibis-server/tests/routers/v2/connector/test_athena.py +++ b/ibis-server/tests/routers/v2/connector/test_athena.py @@ -59,7 +59,7 @@ }, { "name": "timestamptz", - "expression": "TIMESTAMP '2024-01-01 23:59:59 UTC'", + "expression": "CAST(TIMESTAMP '2024-01-01 23:59:59 UTC' AS timestamp)", "type": "timestamp", }, { @@ -113,7 +113,7 @@ async def test_query(client, manifest_str): "orderkey": "int64", "custkey": "int64", "orderstatus": "string", - "totalprice": "decimal128(15, 2)", + "totalprice": "decimal128(38, 9)", "orderdate": "date32[day]", "order_cust_key": "string", "timestamp": "timestamp[us]", @@ -153,7 +153,7 @@ async def test_query_glue_database(client, manifest_str): "orderkey": "int64", "custkey": "int64", "orderstatus": "string", - "totalprice": "decimal128(15, 2)", + "totalprice": 
"decimal128(38, 9)", "orderdate": "date32[day]", "order_cust_key": "string", "timestamp": "timestamp[us]", diff --git a/ibis-server/tests/routers/v3/connector/athena/conftest.py b/ibis-server/tests/routers/v3/connector/athena/conftest.py index 923a0ee7c..353c0738a 100644 --- a/ibis-server/tests/routers/v3/connector/athena/conftest.py +++ b/ibis-server/tests/routers/v3/connector/athena/conftest.py @@ -20,6 +20,33 @@ def pytest_collection_modifyitems(items): item.add_marker(pytestmark) +@pytest.fixture(scope="session") +def connection_info(): + return { + "s3_staging_dir": os.getenv("TEST_ATHENA_S3_STAGING_DIR"), + "aws_access_key_id": os.getenv("TEST_ATHENA_AWS_ACCESS_KEY_ID"), + "aws_secret_access_key": os.getenv("TEST_ATHENA_AWS_SECRET_ACCESS_KEY"), + "region_name": os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1"), + "schema_name": "test", + } + + +@pytest.fixture(scope="session") +def connection_info_default_credential_chain(): + # Use default authentication (e.g., from environment variables, shared config file, or EC2 instance profile) + access_key = os.getenv("AWS_ACCESS_KEY_ID") + secret_key = os.getenv("AWS_SECRET_ACCESS_KEY") + if not access_key or not secret_key: + pytest.skip( + "Skipping default credential chain test: AWS credentials not set in environment" + ) + return { + "s3_staging_dir": os.getenv("TEST_ATHENA_S3_STAGING_DIR"), + "region_name": os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1"), + "schema_name": "test", + } + + @pytest.fixture(scope="session") def connection_info_oidc(): """Use web identity token (OIDC → AssumeRoleWithWebIdentity) authentication.""" @@ -37,7 +64,6 @@ def connection_info_oidc(): "region_name": os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1"), "schema_name": "test", "role_arn": role_arn, - "role_session_name": "pytest-session", "web_identity_token": web_identity_token, } diff --git a/ibis-server/tests/routers/v3/connector/athena/test_query.py b/ibis-server/tests/routers/v3/connector/athena/test_query.py 
index 3565fc898..dba126822 100644 --- a/ibis-server/tests/routers/v3/connector/athena/test_query.py +++ b/ibis-server/tests/routers/v3/connector/athena/test_query.py @@ -90,7 +90,7 @@ async def test_query(client, manifest_str, connection_info): "orderkey": "int64", "custkey": "int64", "orderstatus": "string", - "totalprice": "decimal128(15, 2)", + "totalprice": "decimal128(38, 9)", "orderdate": "date32[day]", "order_cust_key": "string", "timestamp": "timestamp[us]", @@ -214,7 +214,7 @@ async def test_query_with_dry_run_and_invalid_sql( @pytest.mark.parametrize( - "conn_fixture", ["connection_info_static", "connection_info_oidc"] + "conn_fixture", ["connection_info", "connection_info_default_credential_chain"] ) async def test_query_athena_modes(client, manifest_str, request, conn_fixture): connection_info = request.getfixturevalue(conn_fixture) @@ -224,7 +224,7 @@ async def test_query_athena_modes(client, manifest_str, request, conn_fixture): json={ "connectionInfo": connection_info, "manifestStr": manifest_str, - "sql": "SELECT 1", + "sql": "SELECT * FROM wren.public.orders LIMIT 1", }, ) assert response.status_code == 200 From 5b7e9b8db9b7fe4e5bb754826035e2d904f6c584 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Tue, 4 Nov 2025 14:14:34 +0800 Subject: [PATCH 3/5] complete test --- .../routers/v3/connector/athena/test_query.py | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/ibis-server/tests/routers/v3/connector/athena/test_query.py b/ibis-server/tests/routers/v3/connector/athena/test_query.py index dba126822..6325d5f43 100644 --- a/ibis-server/tests/routers/v3/connector/athena/test_query.py +++ b/ibis-server/tests/routers/v3/connector/athena/test_query.py @@ -229,5 +229,29 @@ async def test_query_athena_modes(client, manifest_str, request, conn_fixture): ) assert response.status_code == 200 result = response.json() - assert "columns" in result - assert "data" in result + assert len(result["columns"]) == 
len(manifest["models"][0]["columns"]) + assert len(result["data"]) == 1 + + assert result["data"][0] == [ + 1, + 36901, + "O", + "173665.47", + "1996-01-02", + "1_36901", + "2024-01-01 23:59:59.000000", + "2024-01-01 23:59:59.000000", + None, + ] + + assert result["dtypes"] == { + "orderkey": "int64", + "custkey": "int64", + "orderstatus": "string", + "totalprice": "decimal128(38, 9)", + "orderdate": "date32[day]", + "order_cust_key": "string", + "timestamp": "timestamp[us]", + "timestamptz": "timestamp[us]", + "test_null_time": "timestamp[us]", + } From ffd94a53f27a84d8ceb9409655d040588842877a Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Fri, 14 Nov 2025 09:38:54 +0800 Subject: [PATCH 4/5] oidc test --- ibis-server/app/model/data_source.py | 28 +++++++++++++++---- .../routers/v3/connector/athena/conftest.py | 12 +++----- .../routers/v3/connector/athena/test_query.py | 7 ++++- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index 0785473e7..0bdf1837f 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -8,6 +8,7 @@ from typing import Any from urllib.parse import unquote_plus +import boto3 import ibis from google.cloud import bigquery from google.oauth2 import service_account @@ -263,14 +264,29 @@ def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend: if info.region_name: kwargs["region_name"] = info.region_name.get_secret_value() - # ── 1️⃣ Web Identity Token flow (Google OIDC → AWS STS) ─── + # ── Web Identity Token flow (Google OIDC → AWS STS) ─── if info.web_identity_token and info.role_arn: - kwargs["web_identity_token"] = info.web_identity_token.get_secret_value() - kwargs["role_arn"] = info.role_arn.get_secret_value() - if info.role_session_name: - kwargs["role_session_name"] = info.role_session_name.get_secret_value() + oidc_token = info.web_identity_token.get_secret_value() + role_arn = 
info.role_arn.get_secret_value() + session_name = ( + info.role_session_name.get_secret_value() + if info.role_session_name + else "wren-oidc-session" + ) + sts = boto3.client("sts", region_name=info.region_name.get_secret_value()) + + resp = sts.assume_role_with_web_identity( + RoleArn=role_arn, + RoleSessionName=session_name, + WebIdentityToken=oidc_token, + ) + + creds = resp["Credentials"] + kwargs["aws_access_key_id"] = creds["AccessKeyId"] + kwargs["aws_secret_access_key"] = creds["SecretAccessKey"] + kwargs["aws_session_token"] = creds["SessionToken"] - # ── 2️⃣ Standard Access/Secret Keys ─────────────────────── + # ── Standard Access/Secret Keys ─────────────────────── elif info.aws_access_key_id and info.aws_secret_access_key: kwargs["aws_access_key_id"] = info.aws_access_key_id.get_secret_value() kwargs["aws_secret_access_key"] = ( diff --git a/ibis-server/tests/routers/v3/connector/athena/conftest.py b/ibis-server/tests/routers/v3/connector/athena/conftest.py index 353c0738a..7c4c6febd 100644 --- a/ibis-server/tests/routers/v3/connector/athena/conftest.py +++ b/ibis-server/tests/routers/v3/connector/athena/conftest.py @@ -49,19 +49,15 @@ def connection_info_default_credential_chain(): @pytest.fixture(scope="session") def connection_info_oidc(): - """Use web identity token (OIDC → AssumeRoleWithWebIdentity) authentication.""" - token_path = os.getenv("TEST_ATHENA_WEB_IDENTITY_TOKEN_PATH") + web_identity_token = os.getenv("TEST_ATHENA_WEB_IDENTITY_TOKEN") role_arn = os.getenv("TEST_ATHENA_ROLE_ARN") - if not token_path or not role_arn: + if not web_identity_token or not role_arn: pytest.skip("Skipping OIDC test: web identity token or role ARN not set") - with open(token_path, encoding="utf-8") as f: - web_identity_token = f.read().strip() - return { - "s3_staging_dir": os.getenv("TEST_ATHENA_S3_STAGING_DIR"), - "region_name": os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1"), + "s3_staging_dir": os.getenv("TEST_ATHENA_OIDC_S3_STAGING_DIR"), + 
"region_name": os.getenv("TEST_ATHENA_OIDC_REGION_NAME", "us-west-1"), "schema_name": "test", "role_arn": role_arn, "web_identity_token": web_identity_token, diff --git a/ibis-server/tests/routers/v3/connector/athena/test_query.py b/ibis-server/tests/routers/v3/connector/athena/test_query.py index 6325d5f43..5c9f8ca2a 100644 --- a/ibis-server/tests/routers/v3/connector/athena/test_query.py +++ b/ibis-server/tests/routers/v3/connector/athena/test_query.py @@ -214,7 +214,12 @@ async def test_query_with_dry_run_and_invalid_sql( @pytest.mark.parametrize( - "conn_fixture", ["connection_info", "connection_info_default_credential_chain"] + "conn_fixture", + [ + "connection_info", + "connection_info_default_credential_chain", + "connection_info_oidc", + ], ) async def test_query_athena_modes(client, manifest_str, request, conn_fixture): connection_info = request.getfixturevalue(conn_fixture) From 44655e3d0f13eb0c6f927d4da70a13318e473738 Mon Sep 17 00:00:00 2001 From: DouEnergy Date: Fri, 14 Nov 2025 10:33:43 +0800 Subject: [PATCH 5/5] check region --- ibis-server/app/model/data_source.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ibis-server/app/model/data_source.py b/ibis-server/app/model/data_source.py index 0bdf1837f..31e068e18 100644 --- a/ibis-server/app/model/data_source.py +++ b/ibis-server/app/model/data_source.py @@ -273,7 +273,8 @@ def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend: if info.role_session_name else "wren-oidc-session" ) - sts = boto3.client("sts", region_name=info.region_name.get_secret_value()) + region = info.region_name.get_secret_value() if info.region_name else None + sts = boto3.client("sts", region_name=region) resp = sts.assume_role_with_web_identity( RoleArn=role_arn,