diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/eks.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/eks.py index 91329d8a1e154..70298d66afee1 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/eks.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/eks.py @@ -18,9 +18,10 @@ from __future__ import annotations -import base64 +import contextlib import json import os +import stat import sys import tempfile from collections.abc import Callable, Generator @@ -29,14 +30,12 @@ from functools import partial from botocore.exceptions import ClientError -from botocore.signers import RequestSigner from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.sts import StsHook from airflow.utils import yaml DEFAULT_PAGINATION_TOKEN = "" -STS_TOKEN_EXPIRES_IN = 60 AUTHENTICATION_API_VERSION = "client.authentication.k8s.io/v1alpha1" _POD_USERNAME = "aws" _CONTEXT_NAME = "aws" @@ -79,11 +78,18 @@ class NodegroupStates(Enum): COMMAND = """ export PYTHON_OPERATORS_VIRTUAL_ENV_MODE=1 + + # Source credentials from secure file + source {credentials_file} + output=$({python_executable} -m airflow.providers.amazon.aws.utils.eks_get_token \ - --cluster-name {eks_cluster_name} {args} 2>&1) + --cluster-name {eks_cluster_name} --sts-url '{sts_url}' {args} 2>&1) status=$? + # Clear environment variables after use (defense in depth) + unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY AWS_SESSION_TOKEN + if [ "$status" -ne 0 ]; then printf '%s' "$output" >&2 exit "$status" @@ -537,11 +543,60 @@ def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> lis return name_collection + @contextlib.contextmanager + def _secure_credential_context( + self, access_key: str, secret_key: str, session_token: str | None + ) -> Generator[str, None, None]: + """ + Context manager for secure temporary credential file. + + Creates a temporary file with restrictive permissions (0600) containing AWS credentials. + The file is automatically cleaned up when the context manager exits. + + :param access_key: AWS access key ID + :param secret_key: AWS secret access key + :param session_token: AWS session token (optional) + :return: Path to the temporary credential file + """ + fd = None + temp_path = None + + try: + # Create secure temporary file + fd, temp_path = tempfile.mkstemp( + suffix=".aws_creds", + prefix="airflow_eks_", + ) + + # Set restrictive permissions (0600) - owner read/write only + os.fchmod(fd, stat.S_IRUSR | stat.S_IWUSR) + + # Write credentials to secure file + with os.fdopen(fd, "w") as f: + f.write(f"export AWS_ACCESS_KEY_ID='{access_key}'\n") + f.write(f"export AWS_SECRET_ACCESS_KEY='{secret_key}'\n") + if session_token: + f.write(f"export AWS_SESSION_TOKEN='{session_token}'\n") + + fd = None # File handle closed by fdopen + yield temp_path + + finally: + # Cleanup + if fd is not None: + os.close(fd) + if temp_path and os.path.exists(temp_path): + try: + os.unlink(temp_path) + except OSError: + pass # Best effort cleanup + @contextmanager def generate_config_file( self, eks_cluster_name: str, pod_namespace: str | None, + credentials_file, ) -> Generator[str, None, None]: """ Write the kubeconfig file given an EKS Cluster. @@ -553,20 +608,24 @@ def generate_config_file( if self.region_name is not None: args = args + f" --region-name {self.region_name}" - if self.aws_conn_id is not None: - args = args + f" --aws-conn-id {self.aws_conn_id}" - # We need to determine which python executable the host is running in order to correctly # call the eks_get_token.py script. python_executable = f"python{sys.version_info[0]}.{sys.version_info[1]}" # Set up the client eks_client = self.conn + session = self.get_session() # Get cluster details cluster = eks_client.describe_cluster(name=eks_cluster_name) cluster_cert = cluster["cluster"]["certificateAuthority"]["data"] cluster_ep = cluster["cluster"]["endpoint"] + os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional" + try: + sts_url = f"{StsHook(region_name=session.region_name).conn_client_meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15" + finally: + del os.environ["AWS_STS_REGIONAL_ENDPOINTS"] + cluster_config = { "apiVersion": "v1", "kind": "Config", @@ -598,6 +657,8 @@ def generate_config_file( "args": [ "-c", COMMAND.format( + credentials_file=credentials_file, + sts_url=sts_url, python_executable=python_executable, eks_cluster_name=eks_cluster_name, args=args, @@ -609,50 +670,10 @@ def generate_config_file( } ], } + config_text = yaml.dump(cluster_config, default_flow_style=False) with tempfile.NamedTemporaryFile(mode="w") as config_file: config_file.write(config_text) config_file.flush() yield config_file.name - - def fetch_access_token_for_cluster(self, eks_cluster_name: str) -> str: - session = self.get_session() - service_id = self.conn.meta.service_model.service_id - # This env variable is required so that we get a regionalized endpoint for STS in regions that - # otherwise default to global endpoints. The mechanism below to generate the token is very picky that - # the endpoint is regional. - os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional" - try: - sts_url = f"{StsHook(region_name=session.region_name).conn_client_meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15" - finally: - del os.environ["AWS_STS_REGIONAL_ENDPOINTS"] - - signer = RequestSigner( - service_id=service_id, - region_name=session.region_name, - signing_name="sts", - signature_version="v4", - credentials=session.get_credentials(), - event_emitter=session.events, - ) - - request_params = { - "method": "GET", - "url": sts_url, - "body": {}, - "headers": {"x-k8s-aws-id": eks_cluster_name}, - "context": {}, - } - - signed_url = signer.generate_presigned_url( - request_dict=request_params, - region_name=session.region_name, - expires_in=STS_TOKEN_EXPIRES_IN, - operation_name="", - ) - - base64_url = base64.urlsafe_b64encode(signed_url.encode("utf-8")).decode("utf-8") - - # remove any base64 encoding padding: - return "k8s-aws-v1." + base64_url.rstrip("=") diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py index 3c03aa3f5010d..86e66fc424419 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/eks.py @@ -1069,10 +1069,17 @@ def execute(self, context: Context): aws_conn_id=self.aws_conn_id, region_name=self.region, ) - with eks_hook.generate_config_file( - eks_cluster_name=self.cluster_name, pod_namespace=self.namespace - ) as self.config_file: - return super().execute(context) + session = eks_hook.get_session() + credentials = session.get_credentials().get_frozen_credentials() + with eks_hook._secure_credential_context( + credentials.access_key, credentials.secret_key, credentials.token + ) as credentials_file: + with eks_hook.generate_config_file( + eks_cluster_name=self.cluster_name, + pod_namespace=self.namespace, + credentials_file=credentials_file, + ) as self.config_file: + return super().execute(context) def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any: eks_hook = EksHook( @@ -1081,7 +1088,14 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any: ) eks_cluster_name = event["eks_cluster_name"] pod_namespace = event["namespace"] - with eks_hook.generate_config_file( - eks_cluster_name=eks_cluster_name, pod_namespace=pod_namespace - ) as self.config_file: - return super().trigger_reentry(context, event) + session = eks_hook.get_session() + credentials = session.get_credentials().get_frozen_credentials() + with eks_hook._secure_credential_context( + credentials.access_key, credentials.secret_key, credentials.token + ) as credentials_file: + with eks_hook.generate_config_file( + eks_cluster_name=eks_cluster_name, + pod_namespace=pod_namespace, + credentials_file=credentials_file, + ) as self.config_file: + return super().trigger_reentry(context, event) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py index 9a340671ef56e..47dec133f6c60 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/eks_get_token.py @@ -17,12 +17,16 @@ from __future__ import annotations import argparse +import base64 +import os from datetime import datetime, timedelta, timezone -from airflow.providers.amazon.aws.hooks.eks import EksHook +import boto3 +from botocore.signers import RequestSigner # Presigned STS urls are valid for 15 minutes, set token expiration to 1 minute before it expires for # some cushion +STS_TOKEN_EXPIRES_IN = 60 TOKEN_EXPIRATION_MINUTES = 14 @@ -37,25 +41,59 @@ def get_parser(): parser.add_argument( "--cluster-name", help="The name of the cluster to generate kubeconfig file for.", required=True ) - parser.add_argument( - "--aws-conn-id", - help=( - "The Airflow connection used for AWS credentials. " - "If not specified or empty then the default boto3 behaviour is used." - ), - ) parser.add_argument( "--region-name", help="AWS region_name. If not specified then the default boto3 behaviour is used." ) + parser.add_argument("--sts-url", help="Provide the STS url", required=True) return parser +def fetch_access_token_for_cluster(eks_cluster_name: str, sts_url: str, region_name: str) -> str: + # This will use the credentials from the caller set as the standard AWS env variables + session = boto3.Session(region_name=region_name) + eks_client = session.client("eks") + # This env variable is required so that we get a regionalized endpoint for STS in regions that + # otherwise default to global endpoints. The mechanism below to generate the token is very picky that + # the endpoint is regional. + os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional" + + signer = RequestSigner( + service_id=eks_client.meta.service_model.service_id, + region_name=session.region_name, + signing_name="sts", + signature_version="v4", + credentials=session.get_credentials(), + event_emitter=session.events, + ) + + request_params = { + "method": "GET", + "url": sts_url, + "body": {}, + "headers": {"x-k8s-aws-id": eks_cluster_name}, + "context": {}, + } + + signed_url = signer.generate_presigned_url( + request_dict=request_params, + region_name=session.region_name, + expires_in=STS_TOKEN_EXPIRES_IN, + operation_name="", + ) + + base64_url = base64.urlsafe_b64encode(signed_url.encode("utf-8")).decode("utf-8") + + # remove any base64 encoding padding: + return "k8s-aws-v1." + base64_url.rstrip("=") + + def main(): parser = get_parser() args = parser.parse_args() - eks_hook = EksHook(aws_conn_id=args.aws_conn_id, region_name=args.region_name) - access_token = eks_hook.fetch_access_token_for_cluster(args.cluster_name) + access_token = fetch_access_token_for_cluster( + args.cluster_name, args.sts_url, region_name=args.region_name + ) access_token_expiration = get_expiration_time() print(f"expirationTimestamp: {access_token_expiration}, token: {access_token}") diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_eks.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_eks.py index 651345485e789..fe9a8f4d3135a 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_eks.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_eks.py @@ -53,7 +53,7 @@ NODEGROUP_NOT_FOUND_MSG, ) -from airflow.providers.amazon.aws.hooks.eks import COMMAND, EksHook +from airflow.providers.amazon.aws.hooks.eks import EksHook from unit.amazon.aws.utils.eks_test_constants import ( DEFAULT_CONN_ID, @@ -1212,47 +1212,14 @@ class TestEksHook: @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn") @pytest.mark.parametrize( - "aws_conn_id, region_name, expected_args", + "aws_conn_id, region_name, expected_region_args", [ - [ - "test-id", - "test-region", - [ - "-c", - COMMAND.format( - python_executable=python_executable, - eks_cluster_name="test-cluster", - args=" --region-name test-region --aws-conn-id test-id", - ), - ], - ], - [ - None, - "test-region", - [ - "-c", - COMMAND.format( - python_executable=python_executable, - eks_cluster_name="test-cluster", - args=" --region-name test-region", - ), - ], - ], - [ - None, - None, - [ - "-c", - COMMAND.format( - python_executable=python_executable, - eks_cluster_name="test-cluster", - args="", - ), - ], - ], + ["test-id", "test-region", " --region-name test-region"], + [None, "test-region", " --region-name test-region"], + [None, None, ""], ], ) - def test_generate_config_file(self, mock_conn, aws_conn_id, region_name, expected_args): + def test_generate_config_file(self, mock_conn, aws_conn_id, region_name, expected_region_args): mock_conn.describe_cluster.return_value = { "cluster": {"certificateAuthority": {"data": "test-cert"}, "endpoint": "test-endpoint"} } @@ -1260,73 +1227,51 @@ def test_generate_config_file(self, mock_conn, aws_conn_id, region_name, expecte # We're mocking all actual AWS calls and don't need a connection. This # avoids an Airflow warning about connection cannot be found. hook.get_connection = lambda _: None + + # Mock credentials file path + credentials_file = "/tmp/test_credentials.aws_creds" + with hook.generate_config_file( - eks_cluster_name="test-cluster", pod_namespace="k8s-namespace" + eks_cluster_name="test-cluster", pod_namespace="k8s-namespace", credentials_file=credentials_file ) as config_file: config = yaml.safe_load(Path(config_file).read_text()) - assert config == { - "apiVersion": "v1", - "kind": "Config", - "clusters": [ - { - "cluster": {"server": "test-endpoint", "certificate-authority-data": "test-cert"}, - "name": "test-cluster", - } - ], - "contexts": [ - { - "context": {"cluster": "test-cluster", "namespace": "k8s-namespace", "user": "aws"}, - "name": "aws", - } - ], - "current-context": "aws", - "preferences": {}, - "users": [ - { - "name": "aws", - "user": { - "exec": { - "apiVersion": "client.authentication.k8s.io/v1alpha1", - "args": expected_args, - "command": "sh", - "interactiveMode": "Never", - } - }, - } - ], - } - @mock.patch("airflow.providers.amazon.aws.hooks.eks.RequestSigner") - @mock.patch("airflow.providers.amazon.aws.hooks.eks.StsHook") - @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn") - @mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_session") - def test_fetch_access_token_for_cluster(self, mock_get_session, mock_conn, mock_sts_hook, mock_signer): - mock_signer.return_value.generate_presigned_url.return_value = "http://example.com" - mock_sts_hook.return_value.conn_client_meta.endpoint_url = "https://sts.us-east-1.amazonaws.com" - mock_get_session.return_value.region_name = "us-east-1" - hook = EksHook() - token = hook.fetch_access_token_for_cluster(eks_cluster_name="test-cluster") - mock_signer.assert_called_once_with( - service_id=mock_conn.meta.service_model.service_id, - region_name="us-east-1", - signing_name="sts", - signature_version="v4", - credentials=mock_get_session.return_value.get_credentials.return_value, - event_emitter=mock_get_session.return_value.events, - ) - mock_signer.return_value.generate_presigned_url.assert_called_once_with( - request_dict={ - "method": "GET", - "url": "https://sts.us-east-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", - "body": {}, - "headers": {"x-k8s-aws-id": "test-cluster"}, - "context": {}, - }, - region_name="us-east-1", - expires_in=60, - operation_name="", - ) - assert token == "k8s-aws-v1.aHR0cDovL2V4YW1wbGUuY29t" + # Verify basic kubeconfig structure + assert config["apiVersion"] == "v1" + assert config["kind"] == "Config" + assert config["current-context"] == "aws" + + # Verify cluster config + assert len(config["clusters"]) == 1 + cluster = config["clusters"][0] + assert cluster["name"] == "test-cluster" + assert cluster["cluster"]["server"] == "test-endpoint" + assert cluster["cluster"]["certificate-authority-data"] == "test-cert" + + # Verify context config + assert len(config["contexts"]) == 1 + context = config["contexts"][0] + assert context["name"] == "aws" + assert context["context"]["cluster"] == "test-cluster" + assert context["context"]["namespace"] == "k8s-namespace" + assert context["context"]["user"] == "aws" + + # Verify user config uses secure credential approach + assert len(config["users"]) == 1 + user = config["users"][0] + assert user["name"] == "aws" + exec_config = user["user"]["exec"] + assert exec_config["apiVersion"] == "client.authentication.k8s.io/v1alpha1" + assert exec_config["command"] == "sh" + assert exec_config["interactiveMode"] == "Never" + + # Verify the command references a credential file (not inline creds) + command_arg = exec_config["args"][1] # The -c argument content + assert f"source {credentials_file}" in command_arg + + # Verify region arguments are properly included + if expected_region_args: + assert expected_region_args in command_arg # Helper methods for repeated assert combinations. diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py b/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py index 875401955de6e..9a2ef21ff7e91 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_eks.py @@ -745,12 +745,39 @@ def test_init_with_region(self): class TestEksPodOperator: @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.execute") @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.generate_config_file") + @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook._secure_credential_context") + @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_session") @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.__init__", return_value=None) def test_existing_nodegroup( - self, mock_eks_hook, mock_generate_config_file, mock_k8s_pod_operator_execute + self, + mock_eks_hook, + mock_get_session, + mock_secure_credential_context, + mock_generate_config_file, + mock_k8s_pod_operator_execute, ): ti_context = mock.MagicMock(name="ti_context") + # Mock the credential chain + mock_session = mock.MagicMock() + mock_credentials = mock.MagicMock() + mock_frozen_credentials = mock.MagicMock() + mock_frozen_credentials.access_key = "test_access_key" + mock_frozen_credentials.secret_key = "test_secret_key" + mock_frozen_credentials.token = "test_token" + + mock_get_session.return_value = mock_session + mock_session.get_credentials.return_value = mock_credentials + mock_credentials.get_frozen_credentials.return_value = mock_frozen_credentials + + # Mock the credential context manager + mock_credentials_file = "/tmp/test_creds.aws_creds" + mock_secure_credential_context.return_value.__enter__.return_value = mock_credentials_file + + # Mock the config file context manager + mock_config_file = "/tmp/test_kubeconfig" + mock_generate_config_file.return_value.__enter__.return_value = mock_config_file + op = EksPodOperator( task_id="run_pod", pod_name="run_pod", @@ -763,13 +790,22 @@ def test_existing_nodegroup( on_finish_action="delete_pod", ) op_return_value = op.execute(ti_context) + + # Verify all the expected calls were made mock_k8s_pod_operator_execute.assert_called_once_with(ti_context) mock_eks_hook.assert_called_once_with(aws_conn_id="aws_default", region_name=None) + mock_get_session.assert_called_once() + mock_session.get_credentials.assert_called_once() + mock_credentials.get_frozen_credentials.assert_called_once() + mock_secure_credential_context.assert_called_once_with( + "test_access_key", "test_secret_key", "test_token" + ) mock_generate_config_file.assert_called_once_with( - eks_cluster_name=CLUSTER_NAME, pod_namespace="default" + eks_cluster_name=CLUSTER_NAME, pod_namespace="default", credentials_file=mock_credentials_file ) + assert mock_k8s_pod_operator_execute.return_value == op_return_value - assert mock_generate_config_file.return_value.__enter__.return_value == op.config_file + assert op.config_file == mock_config_file @pytest.mark.parametrize( "compatible_kpo, kwargs, expected_attributes", @@ -826,10 +862,40 @@ def test_template_fields(self): @mock.patch("airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator.trigger_reentry") @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.generate_config_file") - def test_trigger_reentry(self, mock_generate_config_file, mock_k8s_pod_operator_trigger_reentry): + @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook._secure_credential_context") + @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.get_session") + @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook.__init__", return_value=None) + def test_trigger_reentry( + self, + mock_eks_hook, + mock_get_session, + mock_secure_credential_context, + mock_generate_config_file, + mock_k8s_pod_operator_trigger_reentry, + ): ti_context = mock.MagicMock(name="ti_context") event = {"eks_cluster_name": "eks_cluster_name", "namespace": "namespace"} + # Mock the credential chain + mock_session = mock.MagicMock() + mock_credentials = mock.MagicMock() + mock_frozen_credentials = mock.MagicMock() + mock_frozen_credentials.access_key = "test_access_key" + mock_frozen_credentials.secret_key = "test_secret_key" + mock_frozen_credentials.token = "test_token" + + mock_get_session.return_value = mock_session + mock_session.get_credentials.return_value = mock_credentials + mock_credentials.get_frozen_credentials.return_value = mock_frozen_credentials + + # Mock the credential context manager + mock_credentials_file = "/tmp/test_creds.aws_creds" + mock_secure_credential_context.return_value.__enter__.return_value = mock_credentials_file + + # Mock the config file context manager + mock_config_file = "/tmp/test_kubeconfig" + mock_generate_config_file.return_value.__enter__.return_value = mock_config_file + op = EksPodOperator( task_id="run_pod", pod_name="run_pod", @@ -842,8 +908,18 @@ def test_trigger_reentry(self, mock_generate_config_file, mock_k8s_pod_operator_ on_finish_action="delete_pod", ) op.trigger_reentry(ti_context, event) + + # Verify all the expected calls were made mock_k8s_pod_operator_trigger_reentry.assert_called_once_with(ti_context, event) + mock_get_session.assert_called_once() + mock_session.get_credentials.assert_called_once() + mock_credentials.get_frozen_credentials.assert_called_once() + mock_secure_credential_context.assert_called_once_with( + "test_access_key", "test_secret_key", "test_token" + ) mock_generate_config_file.assert_called_once_with( - eks_cluster_name="eks_cluster_name", pod_namespace="namespace" + eks_cluster_name="eks_cluster_name", + pod_namespace="namespace", + credentials_file=mock_credentials_file, ) - assert mock_generate_config_file.return_value.__enter__.return_value == op.config_file + assert op.config_file == mock_config_file diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py b/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py index 0cb678a4074c0..5f8ed3b6de50d 100644 --- a/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py +++ b/providers/amazon/tests/unit/amazon/aws/utils/test_eks_get_token.py @@ -17,7 +17,6 @@ from __future__ import annotations import contextlib -import runpy from io import StringIO from unittest import mock @@ -28,22 +27,20 @@ class TestGetEksToken: - @mock.patch("airflow.providers.amazon.aws.hooks.eks.EksHook") @time_machine.travel("1995-02-14", tick=False) @pytest.mark.parametrize( - "args, expected_aws_conn_id, expected_region_name", + "args, expected_region_name", [ [ [ "airflow.providers.amazon.aws.utils.eks_get_token", "--region-name", "test-region", - "--aws-conn-id", - "test-id", "--cluster-name", "test-cluster", + "--sts-url", + "https://sts.test-region.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", ], - "test-id", "test-region", ], [ @@ -55,35 +52,103 @@ class TestGetEksToken: "test-region", "--cluster-name", "test-cluster", + "--sts-url", + "https://sts.test-region.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", ], - None, "test-region", ], [ - ["airflow.providers.amazon.aws.utils.eks_get_token", "--cluster-name", "test-cluster"], - None, + [ + "airflow.providers.amazon.aws.utils.eks_get_token", + "--cluster-name", + "test-cluster", + "--sts-url", + "https://sts.us-east-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", + ], None, ], ], ) - def test_run(self, mock_eks_hook, args, expected_aws_conn_id, expected_region_name): - ( - mock_eks_hook.return_value.fetch_access_token_for_cluster.return_value - ) = "k8s-aws-v1.aHR0cDovL2V4YW1wbGUuY29t" - - with mock.patch("sys.argv", args), contextlib.redirect_stdout(StringIO()) as temp_stdout: - from airflow.providers.amazon.aws.utils import eks_get_token - - eks_get_token_path = eks_get_token.__file__ - # We are not using run_module because of https://github.com/pytest-dev/pytest/issues/9007 - runpy.run_path(eks_get_token_path, run_name="__main__") - output = temp_stdout.getvalue() - token = "token: k8s-aws-v1.aHR0cDovL2V4YW1wbGUuY29t" - expected_token = output.split(",")[1].strip() - expected_expiration_timestamp = output.split(",")[0].split(":")[1].strip() - assert expected_token == token - assert expected_expiration_timestamp.startswith("1995-02-") - mock_eks_hook.assert_called_once_with( - aws_conn_id=expected_aws_conn_id, region_name=expected_region_name + def test_run(self, args, expected_region_name): + # Instead of trying to mock deep into the CLI execution context, mock the main function itself + with mock.patch( + "airflow.providers.amazon.aws.utils.eks_get_token.fetch_access_token_for_cluster" + ) as mock_fetch_token: + mock_fetch_token.return_value = "k8s-aws-v1.aHR0cDovL2V4YW1wbGUuY29t" + + with mock.patch("sys.argv", args), contextlib.redirect_stdout(StringIO()) as temp_stdout: + # Import and directly call the main function rather than using runpy + from airflow.providers.amazon.aws.utils.eks_get_token import main + + main() + + output = temp_stdout.getvalue() + token = "token: k8s-aws-v1.aHR0cDovL2V4YW1wbGUuY29t" + expected_token = output.split(",")[1].strip() + expected_expiration_timestamp = output.split(",")[0].split(":")[1].strip() + assert expected_token == token + assert expected_expiration_timestamp.startswith("1995-02-") + + # Extract the sts-url from args + sts_url = None + if "--sts-url" in args: + sts_url_idx = args.index("--sts-url") + 1 + if sts_url_idx < len(args): + sts_url = args[sts_url_idx] + + # Verify fetch_access_token_for_cluster was called with correct parameters + mock_fetch_token.assert_called_once_with( + "test-cluster", sts_url, region_name=expected_region_name + ) + + @mock.patch("airflow.providers.amazon.aws.utils.eks_get_token.RequestSigner") + @mock.patch("boto3.Session") + def test_fetch_access_token_for_cluster(self, mock_session, mock_signer): + """Test the standalone fetch_access_token_for_cluster function.""" + from airflow.providers.amazon.aws.utils.eks_get_token import fetch_access_token_for_cluster + + # Mock the session and client + mock_session_instance = mock_session.return_value + mock_eks_client = mock_session_instance.client.return_value + mock_session_instance.region_name = "us-east-1" + + # Mock the RequestSigner + mock_signer_instance = mock_signer.return_value + mock_signer_instance.generate_presigned_url.return_value = "http://example.com" + + result = fetch_access_token_for_cluster( + eks_cluster_name="test-cluster", + sts_url="https://sts.us-east-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", + region_name="us-east-1", ) - mock_eks_hook.return_value.fetch_access_token_for_cluster.assert_called_once_with("test-cluster") + + # Verify session creation + mock_session.assert_called_once_with(region_name="us-east-1") + mock_session_instance.client.assert_called_once_with("eks") + + # Verify RequestSigner was called correctly + mock_signer.assert_called_once_with( + service_id=mock_eks_client.meta.service_model.service_id, + region_name="us-east-1", + signing_name="sts", + signature_version="v4", + credentials=mock_session_instance.get_credentials.return_value, + event_emitter=mock_session_instance.events, + ) + + # Verify presigned URL generation + mock_signer_instance.generate_presigned_url.assert_called_once_with( + request_dict={ + "method": "GET", + "url": "https://sts.us-east-1.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15", + "body": {}, + "headers": {"x-k8s-aws-id": "test-cluster"}, + "context": {}, + }, + region_name="us-east-1", + expires_in=60, + operation_name="", + ) + + # Verify the token format + assert result == "k8s-aws-v1.aHR0cDovL2V4YW1wbGUuY29t"