Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/cli/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,13 @@ def _run_list_command(service_client: client.ServiceClient, args: argparse.Names
Args:
args: Parsed command line arguments.
"""
params = {'name': args.name, 'user': args.user, 'all_users': args.all, 'buckets': args.bucket,
'count': args.count, 'order': args.order.upper()}
params = {'all_users': args.all, 'count': args.count, 'order': args.order.upper()}
if args.name:
params['name'] = args.name
if args.user:
params['user'] = args.user
if args.bucket:
params['buckets'] = args.bucket
result = service_client.request(
client.RequestMethod.GET,
'api/bucket/list_dataset',
Expand Down
39 changes: 26 additions & 13 deletions src/lib/data/storage/backends/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,16 @@ def data_auth(
self._validate_bucket_access(data_cred=data_cred)
return

# Ambient credentials (IRSA, EC2 instance profile, EKS pod identity, etc.) rely on
# resource-based policies (bucket policies) in addition to identity-based IAM policies.
# SimulatePrincipalPolicy only evaluates identity policies, so it produces false
# negatives when access is granted exclusively via a bucket policy. For ambient
# credentials we fall back to a real head_bucket check, which is the true arbiter
# of whether the caller can reach the bucket.
if isinstance(data_cred, credentials.DefaultDataCredential):
self._validate_bucket_access(data_cred=data_cred)
return
Comment thread
fernandol-nvidia marked this conversation as resolved.

action = []
if access_type == common.AccessType.READ:
action.append('s3:GetObject')
Expand All @@ -493,25 +503,28 @@ def data_auth(
elif access_type == common.AccessType.DELETE:
action.append('s3:DeleteObject')

match data_cred:
case credentials.StaticDataCredential():
session = boto3.Session(
aws_access_key_id=data_cred.access_key_id,
aws_secret_access_key=data_cred.access_key.get_secret_value(),
region_name=self.region(data_cred),
)
case credentials.DefaultDataCredential():
session = boto3.Session(
region_name=self.region(data_cred),
)
case _ as unreachable:
assert_never(unreachable)
session = boto3.Session(
aws_access_key_id=data_cred.access_key_id,
aws_secret_access_key=data_cred.access_key.get_secret_value(),
region_name=self.region(data_cred),
)

iam_client: mypy_boto3_iam.client.IAMClient = session.client('iam')
sts_client: mypy_boto3_sts.client.STSClient = session.client('sts')

def _validate_auth():
arn = sts_client.get_caller_identity()['Arn']
# SimulatePrincipalPolicy requires an IAM role ARN, not an STS
# assumed-role session ARN. Convert when running under IRSA or
# instance roles where STS returns an assumed-role ARN.
# arn:<partition>:sts::<acct>:assumed-role/<path>/<role>/<session>
# -> arn:<partition>:iam::<acct>:role/<path>/<role>
if ':assumed-role/' in arn:
parts = arn.split(':')
partition = parts[1]
account_id = parts[4]
role_path = parts[5].split('assumed-role/', 1)[1].rsplit('/', 1)[0]
arn = f'arn:{partition}:iam::{account_id}:role/{role_path}'
path = f'{self.container}/{self.path if self.path else "*"}'

if path.endswith('/'):
Expand Down
47 changes: 47 additions & 0 deletions src/lib/data/storage/backends/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,53 @@ def test_workflow_config_with_null_credential(self):
# Assert
self.assertIsNone(config.workflow_data.credential)

def test_workflow_config_data_with_default_credential(self):
"""Test WorkflowConfig DataConfig accepts DefaultDataCredential."""
default_cred = credentials.DefaultDataCredential(
endpoint='s3://bucket.io/workflows',
region='us-west-2',
)

config = postgres.WorkflowConfig(
workflow_data=postgres.DataConfig(credential=default_cred),
)

credential = config.workflow_data.credential
self.assertIsInstance(credential, credentials.DefaultDataCredential)
assert isinstance(credential, credentials.DefaultDataCredential)
self.assertEqual(credential.endpoint, 's3://bucket.io/workflows')
self.assertEqual(credential.region, 'us-west-2')

def test_workflow_config_log_with_default_credential(self):
"""Test WorkflowConfig LogConfig accepts DefaultDataCredential."""
default_cred = credentials.DefaultDataCredential(
endpoint='s3://log-bucket.io/logs',
region='us-east-1',
)

config = postgres.WorkflowConfig(
workflow_log=postgres.LogConfig(credential=default_cred),
)

self.assertIsInstance(
config.workflow_log.credential,
credentials.DefaultDataCredential,
)

def test_default_credential_to_decrypted_dict_no_keys(self):
"""Test DefaultDataCredential.to_decrypted_dict has no access keys."""
default_cred = credentials.DefaultDataCredential(
endpoint='s3://bucket.io/data',
region='us-west-2',
)

result = default_cred.to_decrypted_dict()

self.assertEqual(result['endpoint'], 's3://bucket.io/data')
self.assertEqual(result['region'], 'us-west-2')
self.assertNotIn('access_key_id', result)
self.assertNotIn('access_key', result)


if __name__ == '__main__':
unittest.main()
24 changes: 19 additions & 5 deletions src/runtime/pkg/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ const (
DatasetOperation string = "Dataset"
)

func setOrUnsetEnv(key, value string) {
if value != "" {
os.Setenv(key, value)
} else {
os.Unsetenv(key)
}
}

type VersionInfo struct {
Size int
Checksum string
Expand Down Expand Up @@ -410,12 +418,18 @@ func MountURL(downloadType string, credentialInfo ConfigInfo, urlPath string,
storageBackend := ParseStorageBackend(urlPath)

dataCredential, ok := credentialInfo.Auth.Data[storageBackend.GetProfile()]
if !ok {
osmoChan <- fmt.Sprintf("Missing data credential for %s.", storageBackend.GetProfile())
return isEmpty
if ok {
setOrUnsetEnv("AWS_ACCESS_KEY_ID", dataCredential.AccessKeyId)
setOrUnsetEnv("AWS_SECRET_ACCESS_KEY", dataCredential.AccessKey)
setOrUnsetEnv("AWS_REGION", dataCredential.Region)
} else {
// No explicit credential — clear any stale values and let the
// SDK resolve ambient credentials (IRSA, pod identity, etc.).
os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
os.Unsetenv("AWS_REGION")
}
os.Setenv("AWS_ACCESS_KEY_ID", dataCredential.AccessKeyId)
os.Setenv("AWS_SECRET_ACCESS_KEY", dataCredential.AccessKey)
os.Unsetenv("AWS_SESSION_TOKEN")

var commandArgs []string

Expand Down
3 changes: 1 addition & 2 deletions src/service/core/data/data_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,7 @@ def get_bucket_info(default_only: bool = False,
path=bucket_info.dataset_path,
description=bucket_info.description,
mode=bucket_info.mode,
default_cred=bucket_info.default_credential is not None\
and bucket_info.default_credential.access_key_id != '')\
default_cred=bucket_info.default_credential is not None)\
for bucket_name, bucket_info in dataset_configs.buckets.items()
}

Expand Down
44 changes: 25 additions & 19 deletions src/utils/connectors/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,7 @@ def func(new_encrypted: str):
self.execute_commit_command(cmd, (cmd_args[0], new_encrypted) + cmd_args[1:])
return func

def get_data_cred(self, user: str, profile: str) -> credentials.StaticDataCredential | None:
def get_data_cred(self, user: str, profile: str) -> credentials.DataCredential | None:
""" Fetch data credentials by profile. """
select_data_cmd = PostgresSelectCommand(
table='credential',
Expand All @@ -1522,13 +1522,7 @@ def get_data_cred(self, user: str, profile: str) -> credentials.StaticDataCreden
bucket_info = storage.construct_storage_backend(bucket.dataset_path)
if bucket_info.profile == profile:
if bucket.default_credential:
return credentials.StaticDataCredential(
region=bucket.region,
access_key_id=bucket.default_credential.access_key_id,
access_key=bucket.default_credential.access_key,
endpoint=bucket_info.profile,
override_url=bucket.default_credential.override_url,
)
return _resolve_bucket_credential(bucket, bucket_info.profile)
break

return None
Expand All @@ -1541,7 +1535,7 @@ def get_all_data_creds(self, user: str) -> Dict[str, credentials.StaticDataCrede
condition_args=[user, CredentialType.DATA.value])
rows = self.execute_fetch_command(*select_data_cmd.get_args())

user_creds = {
user_creds: Dict[str, credentials.StaticDataCredential] = {
cred.profile: credentials.StaticDataCredential(
endpoint=cred.profile,
**self.decrypt_credential(cred),
Expand All @@ -1553,13 +1547,8 @@ def get_all_data_creds(self, user: str) -> Dict[str, credentials.StaticDataCrede
for bucket in self.get_dataset_configs().buckets.values():
bucket_info = storage.construct_storage_backend(bucket.dataset_path)
if bucket_info.profile not in user_creds and bucket.default_credential:
user_creds[bucket_info.profile] = credentials.StaticDataCredential(
region=bucket.region,
access_key_id=bucket.default_credential.access_key_id,
access_key=bucket.default_credential.access_key,
endpoint=bucket_info.profile,
override_url=bucket.default_credential.override_url,
)
user_creds[bucket_info.profile] = _resolve_bucket_credential(
bucket, bucket_info.profile)
return user_creds

def get_generic_cred(self, user: str, cred_name: str) -> Any:
Expand Down Expand Up @@ -2686,7 +2675,7 @@ def construct_path(endpoint: str, bucket: str, path: str):

class LogConfig(ExtraArgBaseModel):
""" Config for storing information about data. """
credential: credentials.StaticDataCredential | None = None
credential: credentials.DataCredential | None = None


class WorkflowInfo(ExtraArgBaseModel):
Expand All @@ -2703,7 +2692,7 @@ def validate_name(self, name: str):

class DataConfig(ExtraArgBaseModel):
""" Config for storing information about data. """
credential: credentials.StaticDataCredential | None = None
credential: credentials.DataCredential | None = None

base_url: str = ''
# Timeout in mins for osmo-ctrl to retry connecting to the OSMO service until exiting the task
Expand All @@ -2727,6 +2716,23 @@ class BucketMode(enum.Enum):
READ_WRITE = 'read-write'


def _resolve_bucket_credential(
bucket: 'BucketConfig',
profile: str,
) -> credentials.StaticDataCredential:
"""Resolve a bucket's default_credential, rebinding it to the bucket's profile and region."""
credential = bucket.default_credential
if credential is None:
raise ValueError(f'No default credential configured for bucket with profile {profile}')
return credentials.StaticDataCredential(
endpoint=profile,
region=bucket.region,
access_key_id=credential.access_key_id,
access_key=credential.access_key,
override_url=credential.override_url,
)


class BucketConfig(ExtraArgBaseModel):
"""
Class to store the name of the bucket and the dataset path
Expand All @@ -2736,7 +2742,7 @@ class BucketConfig(ExtraArgBaseModel):
description: str = ''
# Mode for read-only or read-write or write-only
mode: str = BucketMode.READ_WRITE.value
# Default cred to use doesn't have one
# Default cred to use is a static credential
# Only applies to workflow operations, NOT user cli since we cannot forward the credential
# to the user
default_credential: credentials.StaticDataCredential | None = None
Expand Down
9 changes: 6 additions & 3 deletions src/utils/job/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import re
import secrets
import time
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from urllib.parse import urlencode

Expand Down Expand Up @@ -102,7 +103,7 @@ def create_login_dict(user: str,


def create_config_dict(
data_info: dict[str, credentials.StaticDataCredential],
data_info: Mapping[str, credentials.StaticDataCredential | credentials.DefaultDataCredential],
) -> dict:
'''
Creates the config dict where the input should be a dict containing key values like:
Expand Down Expand Up @@ -2699,7 +2700,7 @@ def convert_to_pod_spec(
service_config: connectors.ServiceConfig | None = None,
dataset_config: connectors.DatasetConfig | None = None,
pool_info: connectors.Pool | None = None,
data_endpoints: Dict[str, credentials.StaticDataCredential] | None = None,
data_endpoints: Mapping[str, credentials.StaticDataCredential] | None = None,
skip_refresh_token: bool = False,
auth_token: str | None = None,
) -> Tuple[Dict, Dict[str, kb_objects.FileMount], Optional[Tuple[str, str]]]:
Expand Down Expand Up @@ -3193,13 +3194,15 @@ def decode_hstore(tasks: str) -> Set[str]:

def fetch_creds(
user: str,
data_creds: dict[str, credentials.StaticDataCredential],
data_creds: Mapping[str, credentials.StaticDataCredential],
path: str,
disabled_data: list[str] | None = None,
) -> credentials.StaticDataCredential | None:
backend_info = storage.construct_storage_backend(path)

if backend_info.profile not in data_creds:
if backend_info.supports_environment_auth:
return None
if not disabled_data or backend_info.scheme not in disabled_data:
raise osmo_errors.OSMOCredentialError(
f'Could not find {backend_info.profile} credential for user {user}.')
Expand Down
1 change: 1 addition & 0 deletions src/utils/job/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ py_test(
tags = ["manual"],
deps = [
"//src/lib/utils:common",
"//src/lib/utils:credentials",
"//src/utils/connectors",
"//src/utils/job",
]
Expand Down
Loading
Loading