Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 69 additions & 48 deletions providers/amazon/src/airflow/providers/amazon/aws/hooks/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@

from __future__ import annotations

import base64
import contextlib
import json
import os
import stat
import sys
import tempfile
from collections.abc import Callable, Generator
Expand All @@ -29,14 +30,12 @@
from functools import partial

from botocore.exceptions import ClientError
from botocore.signers import RequestSigner

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sts import StsHook
from airflow.utils import yaml

DEFAULT_PAGINATION_TOKEN = ""
STS_TOKEN_EXPIRES_IN = 60
AUTHENTICATION_API_VERSION = "client.authentication.k8s.io/v1alpha1"
_POD_USERNAME = "aws"
_CONTEXT_NAME = "aws"
Expand Down Expand Up @@ -79,11 +78,18 @@ class NodegroupStates(Enum):

COMMAND = """
export PYTHON_OPERATORS_VIRTUAL_ENV_MODE=1

# Source credentials from secure file
source {credentials_file}

output=$({python_executable} -m airflow.providers.amazon.aws.utils.eks_get_token \
--cluster-name {eks_cluster_name} {args} 2>&1)
--cluster-name {eks_cluster_name} --sts-url '{sts_url}' {args} 2>&1)

status=$?

# Clear environment variables after use (defense in depth)
unset AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY AWS_SESSION_TOKEN

if [ "$status" -ne 0 ]; then
printf '%s' "$output" >&2
exit "$status"
Expand Down Expand Up @@ -537,11 +543,60 @@ def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> lis

return name_collection

@contextlib.contextmanager
def _secure_credential_context(
self, access_key: str, secret_key: str, session_token: str | None
) -> Generator[str, None, None]:
"""
Context manager for secure temporary credential file.

Creates a temporary file with restrictive permissions (0600) containing AWS credentials.
The file is automatically cleaned up when the context manager exits.

:param access_key: AWS access key ID
:param secret_key: AWS secret access key
:param session_token: AWS session token (optional)
:return: Path to the temporary credential file
"""
fd = None
temp_path = None

try:
# Create secure temporary file
fd, temp_path = tempfile.mkstemp(
suffix=".aws_creds",
prefix="airflow_eks_",
)

# Set restrictive permissions (0600) - owner read/write only
os.fchmod(fd, stat.S_IRUSR | stat.S_IWUSR)

# Write credentials to secure file
with os.fdopen(fd, "w") as f:
f.write(f"export AWS_ACCESS_KEY_ID='{access_key}'\n")
f.write(f"export AWS_SECRET_ACCESS_KEY='{secret_key}'\n")
if session_token:
f.write(f"export AWS_SESSION_TOKEN='{session_token}'\n")

fd = None # File handle closed by fdopen
yield temp_path

finally:
# Cleanup
if fd is not None:
os.close(fd)
if temp_path and os.path.exists(temp_path):
try:
os.unlink(temp_path)
except OSError:
pass # Best effort cleanup

@contextmanager
def generate_config_file(
self,
eks_cluster_name: str,
pod_namespace: str | None,
credentials_file,
) -> Generator[str, None, None]:
"""
Write the kubeconfig file given an EKS Cluster.
Expand All @@ -553,20 +608,24 @@ def generate_config_file(
if self.region_name is not None:
args = args + f" --region-name {self.region_name}"

if self.aws_conn_id is not None:
args = args + f" --aws-conn-id {self.aws_conn_id}"

# We need to determine which python executable the host is running in order to correctly
# call the eks_get_token.py script.
python_executable = f"python{sys.version_info[0]}.{sys.version_info[1]}"
# Set up the client
eks_client = self.conn
session = self.get_session()

# Get cluster details
cluster = eks_client.describe_cluster(name=eks_cluster_name)
cluster_cert = cluster["cluster"]["certificateAuthority"]["data"]
cluster_ep = cluster["cluster"]["endpoint"]

os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"
try:
sts_url = f"{StsHook(region_name=session.region_name).conn_client_meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15"
finally:
del os.environ["AWS_STS_REGIONAL_ENDPOINTS"]

cluster_config = {
"apiVersion": "v1",
"kind": "Config",
Expand Down Expand Up @@ -598,6 +657,8 @@ def generate_config_file(
"args": [
"-c",
COMMAND.format(
credentials_file=credentials_file,
sts_url=sts_url,
python_executable=python_executable,
eks_cluster_name=eks_cluster_name,
args=args,
Expand All @@ -609,50 +670,10 @@ def generate_config_file(
}
],
}

config_text = yaml.dump(cluster_config, default_flow_style=False)

with tempfile.NamedTemporaryFile(mode="w") as config_file:
config_file.write(config_text)
config_file.flush()
yield config_file.name

def fetch_access_token_for_cluster(self, eks_cluster_name: str) -> str:
    """
    Build a token for an EKS cluster from a presigned STS ``GetCallerIdentity`` URL.

    The URL is signed with the credentials of the hook's session and carries the cluster name
    in the ``x-k8s-aws-id`` header. Any pre-existing value of the ``AWS_STS_REGIONAL_ENDPOINTS``
    environment variable is restored once the STS endpoint has been resolved.

    :param eks_cluster_name: Name of the cluster the token is generated for.
    :return: ``k8s-aws-v1.`` followed by the URL-safe base64 encoding of the presigned URL,
        with the base64 padding removed.
    """
    session = self.get_session()
    service_id = self.conn.meta.service_model.service_id

    # This env variable is required so that we get a regionalized endpoint for STS in regions that
    # otherwise default to global endpoints. The mechanism below to generate the token is very picky that
    # the endpoint is regional.
    env_key = "AWS_STS_REGIONAL_ENDPOINTS"
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = "regional"
    try:
        sts_url = f"{StsHook(region_name=session.region_name).conn_client_meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15"
    finally:
        # Put the environment back exactly as it was instead of unconditionally deleting the
        # variable, which would discard a value configured by the user.
        if previous_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value

    signer = RequestSigner(
        service_id=service_id,
        region_name=session.region_name,
        signing_name="sts",
        signature_version="v4",
        credentials=session.get_credentials(),
        event_emitter=session.events,
    )

    request_params = {
        "method": "GET",
        "url": sts_url,
        "body": {},
        "headers": {"x-k8s-aws-id": eks_cluster_name},
        "context": {},
    }

    signed_url = signer.generate_presigned_url(
        request_dict=request_params,
        region_name=session.region_name,
        expires_in=STS_TOKEN_EXPIRES_IN,
        operation_name="",
    )

    base64_url = base64.urlsafe_b64encode(signed_url.encode("utf-8")).decode("utf-8")

    # remove any base64 encoding padding:
    return "k8s-aws-v1." + base64_url.rstrip("=")
Original file line number Diff line number Diff line change
Expand Up @@ -1069,10 +1069,17 @@ def execute(self, context: Context):
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)
with eks_hook.generate_config_file(
eks_cluster_name=self.cluster_name, pod_namespace=self.namespace
) as self.config_file:
return super().execute(context)
session = eks_hook.get_session()
credentials = session.get_credentials().get_frozen_credentials()
with eks_hook._secure_credential_context(
credentials.access_key, credentials.secret_key, credentials.token
) as credentials_file:
with eks_hook.generate_config_file(
eks_cluster_name=self.cluster_name,
pod_namespace=self.namespace,
credentials_file=credentials_file,
) as self.config_file:
return super().execute(context)

def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
eks_hook = EksHook(
Expand All @@ -1081,7 +1088,14 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
)
eks_cluster_name = event["eks_cluster_name"]
pod_namespace = event["namespace"]
with eks_hook.generate_config_file(
eks_cluster_name=eks_cluster_name, pod_namespace=pod_namespace
) as self.config_file:
return super().trigger_reentry(context, event)
session = eks_hook.get_session()
credentials = session.get_credentials().get_frozen_credentials()
with eks_hook._secure_credential_context(
credentials.access_key, credentials.secret_key, credentials.token
) as credentials_file:
with eks_hook.generate_config_file(
eks_cluster_name=eks_cluster_name,
pod_namespace=pod_namespace,
credentials_file=credentials_file,
) as self.config_file:
return super().trigger_reentry(context, event)
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
from __future__ import annotations

import argparse
import base64
import os
from datetime import datetime, timedelta, timezone

from airflow.providers.amazon.aws.hooks.eks import EksHook
import boto3
from botocore.signers import RequestSigner

# Presigned STS URLs are valid for 15 minutes; TOKEN_EXPIRATION_MINUTES (below) sets the token
# expiration to 1 minute before that for some cushion.
STS_TOKEN_EXPIRES_IN = 60
TOKEN_EXPIRATION_MINUTES = 14


Expand All @@ -37,25 +41,59 @@ def get_parser():
parser.add_argument(
"--cluster-name", help="The name of the cluster to generate kubeconfig file for.", required=True
)
parser.add_argument(
"--aws-conn-id",
help=(
"The Airflow connection used for AWS credentials. "
"If not specified or empty then the default boto3 behaviour is used."
),
)
parser.add_argument(
"--region-name", help="AWS region_name. If not specified then the default boto3 behaviour is used."
)
parser.add_argument("--sts-url", help="Provide the STS url", required=True)

return parser


def fetch_access_token_for_cluster(eks_cluster_name: str, sts_url: str, region_name: str) -> str:
    """
    Build a token for an EKS cluster by presigning the given STS URL.

    :param eks_cluster_name: Name of the cluster; sent as the ``x-k8s-aws-id`` header of the
        signed request.
    :param sts_url: The STS URL to presign.
    :param region_name: Region used for the boto3 session and for signing.
    :return: ``k8s-aws-v1.`` followed by the URL-safe base64 encoding of the presigned URL,
        with the base64 padding removed.
    """
    # No explicit credentials are passed to the session, so boto3 resolves them itself; the
    # caller is expected to provide them through the standard AWS env variables.
    boto_session = boto3.Session(region_name=region_name)
    eks = boto_session.client("eks")

    # Ask for regionalized STS endpoints in regions that otherwise default to global ones; the
    # token generation mechanism is very picky that the endpoint is regional.
    os.environ["AWS_STS_REGIONAL_ENDPOINTS"] = "regional"

    presigned_url = RequestSigner(
        service_id=eks.meta.service_model.service_id,
        region_name=boto_session.region_name,
        signing_name="sts",
        signature_version="v4",
        credentials=boto_session.get_credentials(),
        event_emitter=boto_session.events,
    ).generate_presigned_url(
        request_dict={
            "method": "GET",
            "url": sts_url,
            "body": {},
            "headers": {"x-k8s-aws-id": eks_cluster_name},
            "context": {},
        },
        region_name=boto_session.region_name,
        expires_in=STS_TOKEN_EXPIRES_IN,
        operation_name="",
    )

    encoded = base64.urlsafe_b64encode(presigned_url.encode("utf-8")).decode("utf-8")

    # Strip the trailing base64 padding characters before prefixing.
    return "k8s-aws-v1." + encoded.rstrip("=")


def main():
    """
    Generate an EKS access token from the command-line arguments and print it.

    Parses the arguments defined by ``get_parser()``, builds the token with the module-level
    ``fetch_access_token_for_cluster`` and prints a single line containing the expiration
    timestamp returned by ``get_expiration_time()`` and the token.
    """
    args = get_parser().parse_args()
    # The token comes only from the module-level helper; the earlier EksHook-based assignment
    # was overwritten immediately and only added an extra hook construction and token request.
    access_token = fetch_access_token_for_cluster(
        args.cluster_name, args.sts_url, region_name=args.region_name
    )
    access_token_expiration = get_expiration_time()
    print(f"expirationTimestamp: {access_token_expiration}, token: {access_token}")

Expand Down
Loading